From 66fab8722c910692ba104417d19c9a3037cd427e Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Fri, 27 Mar 2026 11:53:51 +0100 Subject: [PATCH 01/14] refactor: use EnumText for credential_type in TriggerSubscription (#34174) Co-authored-by: Asuka Minato Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/tools/tool_manager.py | 3 +-- api/models/tools.py | 8 ++++++-- api/models/trigger.py | 6 ++++-- api/services/tools/builtin_tools_manage_service.py | 4 ++-- api/services/tools/tools_transform_service.py | 2 +- api/services/trigger/trigger_provider_service.py | 2 +- .../services/plugin/test_plugin_parameter_service.py | 3 ++- 7 files changed, 17 insertions(+), 11 deletions(-) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 4870adb7b58..4a10c7e23e5 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -37,7 +37,6 @@ from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.position_helper import is_filtered -from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import Tool from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort @@ -326,7 +325,7 @@ class ToolManager: tenant_id=tenant_id, user_id=user_id, credentials=dict(decrypted_credentials), - credential_type=CredentialType.of(builtin_provider.credential_type), + credential_type=builtin_provider.credential_type, runtime_parameters={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, diff --git a/api/models/tools.py b/api/models/tools.py index 63b27b94131..d8731fb8a8a 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -11,6 +11,7 @@ from deprecated import deprecated from sqlalchemy import ForeignKey, String, func, select from sqlalchemy.orm import Mapped, mapped_column +from core.plugin.entities.plugin_daemon import CredentialType from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( @@ -109,8 +110,11 @@ class BuiltinToolProvider(TypeBase): ) is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False) # credential type, e.g., "api-key", "oauth2" - credential_type: Mapped[str] = mapped_column( - String(32), nullable=False, server_default=sa.text("'api-key'"), default="api-key" + credential_type: Mapped[CredentialType] = mapped_column( + EnumText(CredentialType, length=32), + nullable=False, + server_default=sa.text("'api-key'"), + default=CredentialType.API_KEY, ) expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"), default=-1) diff --git a/api/models/trigger.py b/api/models/trigger.py index 627b854060c..5233a6e2711 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -102,7 +102,9 @@ class TriggerSubscription(TypeBase): credentials: Mapped[TriggerCredentials] = mapped_column( sa.JSON, nullable=False, comment="Subscription credentials JSON" ) - credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key") + credential_type: Mapped[CredentialType] = mapped_column( + EnumText(CredentialType, length=50), nullable=False, 
comment="oauth or api_key" + ) credential_expires_at: Mapped[int] = mapped_column( Integer, default=-1, comment="OAuth token expiration timestamp, -1 for never" ) @@ -144,7 +146,7 @@ class TriggerSubscription(TypeBase): endpoint=generate_plugin_trigger_endpoint_url(self.endpoint_id), parameters=self.parameters, properties=self.properties, - credential_type=CredentialType(self.credential_type), + credential_type=self.credential_type, credentials=self.credentials, workflows_in_use=-1, ) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 6797a67ddef..8e3c36e0998 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -275,7 +275,7 @@ class BuiltinToolManageService: user_id=user_id, provider=provider, encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), - credential_type=api_type.value, + credential_type=api_type, name=name, expires_at=expires_at if expires_at is not None else -1, ) @@ -314,7 +314,7 @@ class BuiltinToolManageService: .filter_by( tenant_id=tenant_id, provider=provider, - credential_type=credential_type.value, + credential_type=credential_type, ) .order_by(BuiltinToolProvider.created_at.desc()) .all() diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index b6e5367023c..b2761460664 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -423,7 +423,7 @@ class ToolTransformService: id=provider.id, name=provider.name, provider=provider.provider, - credential_type=CredentialType.of(provider.credential_type), + credential_type=provider.credential_type, is_default=provider.is_default, credentials=credentials, ) diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 688993c7987..008d8bdb8af 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -198,7 +198,7 @@ class TriggerProviderService: credentials=dict(credential_encrypter.encrypt(dict(credentials))) if credential_encrypter else {}, - credential_type=credential_type.value, + credential_type=credential_type, credential_expires_at=credential_expires_at, expires_at=expires_at, ) diff --git a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py index 38851372210..ce9f10e207d 100644 --- a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py @@ -12,6 +12,7 @@ from uuid import uuid4 import pytest +from core.plugin.entities.plugin_daemon import CredentialType from models.tools import BuiltinToolProvider from services.plugin.plugin_parameter_service import PluginParameterService @@ -66,7 +67,7 @@ class TestGetDynamicSelectOptionsTool: provider="google", name="API KEY 1", encrypted_credentials=json.dumps({"api_key": "encrypted"}), - credential_type="api_key", + credential_type=CredentialType.API_KEY, ) db_session_with_containers.add(db_record) db_session_with_containers.commit() From 32d394d65b9ec6d3668c134f134227c3072ec9b1 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Fri, 27 Mar 2026 15:00:26 +0100 Subject: [PATCH 02/14] refactor: select 
in core/ops trace manager and trace providers (#34197) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/ops/aliyun_trace/utils.py | 4 +- .../arize_phoenix_trace.py | 4 +- api/core/ops/langfuse_trace/langfuse_trace.py | 4 +- .../ops/langsmith_trace/langsmith_trace.py | 4 +- api/core/ops/mlflow_trace/mlflow_trace.py | 25 ++------- api/core/ops/opik_trace/opik_trace.py | 4 +- api/core/ops/ops_trace_manager.py | 24 ++++---- api/core/ops/weave_trace/weave_trace.py | 4 +- .../aliyun_trace/test_aliyun_trace_utils.py | 10 +--- .../ops/langfuse_trace/test_langfuse_trace.py | 4 +- .../langsmith_trace/test_langsmith_trace.py | 4 +- .../ops/mlflow_trace/test_mlflow_trace.py | 28 +++++----- .../core/ops/opik_trace/test_opik_trace.py | 4 +- .../core/ops/test_ops_trace_manager.py | 55 +++++++++---------- .../core/ops/weave_trace/test_weave_trace.py | 14 ++--- 15 files changed, 77 insertions(+), 115 deletions(-) diff --git a/api/core/ops/aliyun_trace/utils.py b/api/core/ops/aliyun_trace/utils.py index 43b204b78c7..956fc60191f 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/core/ops/aliyun_trace/utils.py @@ -27,9 +27,7 @@ DEFAULT_FRAMEWORK_NAME = "dify" def get_user_id_from_message_data(message_data) -> str: user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: user_id = end_user_data.session_id return user_id diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 724127c31c7..a1ea182f66d 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -410,9 +410,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): # Add end user data if available if trace_info.message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, trace_info.message_data.from_end_user_id) if end_user_data is not None: metadata["end_user_id"] = end_user_data.session_id diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 4a634e2e57f..3bf01eb81c6 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -241,9 +241,7 @@ class LangFuseDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: user_id = end_user_data.session_id metadata["user_id"] = user_id diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 9f7d73b4cad..d960038f154 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -259,9 +259,7 @@ class LangSmithDataTrace(BaseTraceInstance): metadata["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - 
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id metadata["end_user_id"] = end_user_id diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index 8ec69e3542c..8bf2e5dc138 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -9,6 +9,7 @@ from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey from mlflow.tracing.fluent import start_span_no_context, update_current_trace from mlflow.tracing.provider import detach_span_from_context, set_span_in_context +from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig @@ -320,7 +321,7 @@ class MLflowDataTrace(BaseTraceInstance): def _get_message_user_id(self, metadata: dict) -> str | None: if (end_user_id := metadata.get("from_end_user_id")) and ( - end_user_data := db.session.query(EndUser).where(EndUser.id == end_user_id).first() + end_user_data := db.session.get(EndUser, end_user_id) ): return end_user_data.session_id @@ -447,25 +448,11 @@ class MLflowDataTrace(BaseTraceInstance): def _get_workflow_nodes(self, workflow_run_id: str): """Helper method to get workflow nodes""" - workflow_nodes = ( - db.session.query( - WorkflowNodeExecutionModel.id, - WorkflowNodeExecutionModel.tenant_id, - WorkflowNodeExecutionModel.app_id, - WorkflowNodeExecutionModel.title, - WorkflowNodeExecutionModel.node_type, - WorkflowNodeExecutionModel.status, - WorkflowNodeExecutionModel.inputs, - WorkflowNodeExecutionModel.outputs, - WorkflowNodeExecutionModel.created_at, - WorkflowNodeExecutionModel.elapsed_time, - WorkflowNodeExecutionModel.process_data, - WorkflowNodeExecutionModel.execution_metadata, - ) - .filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) + workflow_nodes = db.session.scalars( + select(WorkflowNodeExecutionModel) + .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) .order_by(WorkflowNodeExecutionModel.created_at) - .all() - ) + ).all() return workflow_nodes def _get_node_span_type(self, node_type: str) -> str: diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index a3ead548bbe..b98cc3ce598 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -288,9 +288,7 @@ class OpikDataTrace(BaseTraceInstance): metadata["file_list"] = file_list if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id metadata["end_user_id"] = end_user_id diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 0a2a0642f1a..9c36d57c6f5 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -420,10 +420,10 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig | None = ( - db.session.query(TraceAppConfig) + trace_config_data: TraceAppConfig | None = db.session.scalar( + 
select(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) - .first() + .limit(1) ) if not trace_config_data: @@ -463,7 +463,7 @@ class OpsTraceManager: if isinstance(app_id, str) and app_id.startswith("tenant-"): return None - app: App | None = db.session.query(App).where(App.id == app_id).first() + app = db.session.get(App, app_id) if app is None: return None @@ -537,7 +537,7 @@ class OpsTraceManager: except KeyError: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: App | None = db.session.query(App).where(App.id == app_id).first() + app_config: App | None = db.session.get(App, app_id) if not app_config: raise ValueError("App not found") app_config.tracing = json.dumps( @@ -555,7 +555,7 @@ class OpsTraceManager: :param app_id: app id :return: """ - app: App | None = db.session.query(App).where(App.id == app_id).first() + app: App | None = db.session.get(App, app_id) if not app: raise ValueError("App not found") if not app.tracing: @@ -883,7 +883,7 @@ class TraceTask: inputs = message_data.message # get message file data - message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1)) file_list = [] if message_file_data and message_file_data.url is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" @@ -972,8 +972,8 @@ class TraceTask: # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = ( - db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + workflow_app_log_data = db.session.scalar( + select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1) ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None @@ -1015,8 +1015,8 @@ class TraceTask: # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = ( - db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + workflow_app_log_data = db.session.scalar( + select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1) ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None @@ -1171,7 +1171,7 @@ class TraceTask: metadata["node_execution_id"] = node_execution_id file_url = "" - message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1)) if message_file_data: message_file_id = message_file_data.id if message_file_data else None type = message_file_data.type diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index a55505822ad..f79544f1c7b 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -245,9 +245,7 @@ class WeaveDataTrace(BaseTraceInstance): attributes["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id 
attributes["end_user_id"] = end_user_id diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py index fa885e93208..e4d8f2d5ea0 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py @@ -45,11 +45,8 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch): end_user_data = MagicMock(spec=EndUser) end_user_data.session_id = "session_id" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = end_user_data - mock_session = MagicMock() - mock_session.query.return_value = mock_query + mock_session.get.return_value = end_user_data from core.ops.aliyun_trace.utils import db @@ -63,11 +60,8 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): message_data.from_account_id = "account_id" message_data.from_end_user_id = "end_user_id" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = None - mock_session = MagicMock() - mock_session.query.return_value = mock_query + mock_session.get.return_value = None from core.ops.aliyun_trace.utils import db diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py index fdf66d4d405..8ebf4419211 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -365,9 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock() trace_instance.add_generation = MagicMock() diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py index e89359c25b6..34c64c54a1f 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -319,9 +319,7 @@ def test_message_trace(trace_instance, monkeypatch): # Mock EndUser lookup mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_run = MagicMock() diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py index 7ff6f7dcfd8..afc5726ede2 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py @@ -330,7 +330,7 @@ class TestTraceDispatcher: class TestWorkflowTrace: def test_basic_workflow_no_nodes(self, trace_instance, mock_tracing, mock_db): - 
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -343,7 +343,7 @@ class TestWorkflowTrace: span.end.assert_called_once() def test_workflow_filters_sys_inputs_and_adds_query(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -374,7 +374,7 @@ class TestWorkflowTrace: ), outputs='{"text": "hello world"}', ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [llm_node] + mock_db.session.scalars.return_value.all.return_value = [llm_node] workflow_span = MagicMock() node_span = MagicMock() @@ -397,7 +397,7 @@ class TestWorkflowTrace: } ), ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [qc_node] + mock_db.session.scalars.return_value.all.return_value = [qc_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -411,7 +411,7 @@ class TestWorkflowTrace: node_type=BuiltinNodeTypes.HTTP_REQUEST, process_data='{"url": "https://api.com"}', ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [http_node] + mock_db.session.scalars.return_value.all.return_value = [http_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -434,7 +434,7 @@ class TestWorkflowTrace: } ), ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [kr_node] + mock_db.session.scalars.return_value.all.return_value = [kr_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -448,7 +448,7 @@ class TestWorkflowTrace: def test_workflow_with_failed_node(self, trace_instance, mock_tracing, mock_db): failed_node = _make_node(status="failed") - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [failed_node] + mock_db.session.scalars.return_value.all.return_value = [failed_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -459,7 +459,7 @@ class TestWorkflowTrace: node_span.add_event.assert_called_once() def test_workflow_with_workflow_error(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] workflow_span = MagicMock() mock_tracing["start"].return_value = workflow_span mock_tracing["set"].return_value = "token" @@ -473,7 +473,7 @@ class TestWorkflowTrace: def test_workflow_node_no_inputs_no_outputs(self, trace_instance, mock_tracing, mock_db): node = _make_node(inputs=None, outputs=None) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [node] + mock_db.session.scalars.return_value.all.return_value = [node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -486,7 +486,7 @@ class 
TestWorkflowTrace: assert end_call.kwargs["outputs"] == {} def test_workflow_no_user_id_no_conversation_id(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -501,7 +501,7 @@ class TestWorkflowTrace: def test_workflow_empty_query(self, trace_instance, mock_tracing, mock_db): """When query is empty string, it's falsy so no query key added.""" - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -680,12 +680,12 @@ class TestGetMessageUserId: def test_returns_end_user_session_id(self, trace_instance, mock_db): end_user = MagicMock() end_user.session_id = "session-1" - mock_db.session.query.return_value.where.return_value.first.return_value = end_user + mock_db.session.get.return_value = end_user result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1"}) assert result == "session-1" def test_returns_account_id_when_no_end_user(self, trace_instance, mock_db): - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1", "from_account_id": "acc-1"}) assert result == "acc-1" @@ -834,7 +834,7 @@ class TestGenerateNameTrace: class TestGetWorkflowNodes: def test_queries_db(self, trace_instance, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = ["n1", "n2"] + mock_db.session.scalars.return_value.all.return_value = ["n1", "n2"] result = trace_instance._get_workflow_nodes("run-1") assert result == ["n1", "n2"] diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py index 6625cb719fe..c02ac413f27 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -373,9 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) trace_instance.add_span = MagicMock() diff --git a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py index f81806c9418..e47df0121ea 100644 --- a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py +++ b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py @@ -157,17 +157,19 @@ def make_workflow_run(): ) -def configure_db_query(session, *, message_file=None, workflow_app_log=None): - def _side_effect(model): - query = MagicMock() - query.filter_by.return_value.first.return_value = None - if message_file and model.__name__ == "MessageFile": - query.filter_by.return_value.first.return_value = message_file - if 
workflow_app_log and model.__name__ == "WorkflowAppLog": - query.filter_by.return_value.first.return_value = workflow_app_log - return query +def configure_db_scalar(session, *, message_file=None, workflow_app_log=None): + """Configure session.scalar to return appropriate values for MessageFile/WorkflowAppLog lookups.""" + original_scalar = session.scalar - session.query.side_effect = _side_effect + def _side_effect(stmt): + stmt_str = str(stmt) + if "message_file" in stmt_str.lower(): + return message_file + if "workflow_app_log" in stmt_str.lower(): + return workflow_app_log + return original_scalar(stmt) + + session.scalar.side_effect = _side_effect class DummySessionContext: @@ -263,7 +265,7 @@ def workflow_repo_fixture(monkeypatch): def trace_task_message(monkeypatch, mock_db): message_data = make_message_data() monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data) - configure_db_query(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) + configure_db_scalar(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) return message_data @@ -307,56 +309,53 @@ def test_obfuscated_decrypt_token(encryption_mocks): def test_get_decrypted_tracing_config_returns_config(encryption_mocks, mock_db): trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc", "other_value": "info"}) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data app = SimpleNamespace(id="app-id", tenant_id="tenant") - mock_db.scalar.return_value = app + mock_db.scalar.side_effect = [trace_config_data, app] decrypted = OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") assert decrypted["other_value"] == "info" def test_get_decrypted_tracing_config_missing_trace_config(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.scalar.return_value = None assert OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") is None def test_get_decrypted_tracing_config_raises_for_missing_app(mock_db): trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc"}) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data - mock_db.scalar.return_value = None + mock_db.scalar.side_effect = [trace_config_data, None] with pytest.raises(ValueError, match="App not found"): OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") def test_get_decrypted_tracing_config_raises_for_none_config(mock_db): trace_config_data = SimpleNamespace(tracing_config=None) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data - mock_db.scalar.return_value = SimpleNamespace(tenant_id="tenant") + mock_db.scalar.side_effect = [trace_config_data, SimpleNamespace(tenant_id="tenant")] with pytest.raises(ValueError, match="Tracing config cannot be None"): OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") def test_get_ops_trace_instance_handles_none_app(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_returns_none_when_disabled(mock_db, monkeypatch): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": False})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app assert OpsTraceManager.get_ops_trace_instance("app-id") is None def 
test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "missing"})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app monkeypatch.setattr("core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({})) assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_success(monkeypatch, mock_db): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app monkeypatch.setattr( "core.ops.ops_trace_manager.OpsTraceManager.get_decrypted_tracing_config", classmethod(lambda cls, aid, provider: {"secret_value": "decrypted", "other_value": "info"}), @@ -390,7 +389,7 @@ def test_get_app_config_through_message_id_app_model_config(mock_db): def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None with pytest.raises(ValueError, match="Invalid tracing provider"): OpsTraceManager.update_app_tracing_config("app", True, "bad") with pytest.raises(ValueError, match="App not found"): @@ -399,26 +398,26 @@ def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch): def test_update_app_tracing_config_success(mock_db): app = SimpleNamespace(id="app-id", tracing="{}") - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app OpsTraceManager.update_app_tracing_config("app-id", True, "dummy") assert app.tracing is not None mock_db.commit.assert_called_once() def test_get_app_tracing_config_errors_when_missing(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None with pytest.raises(ValueError, match="App not found"): OpsTraceManager.get_app_tracing_config("app") def test_get_app_tracing_config_returns_defaults(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=None) + mock_db.get.return_value = SimpleNamespace(tracing=None) assert OpsTraceManager.get_app_tracing_config("app-id") == {"enabled": False, "tracing_provider": None} def test_get_app_tracing_config_returns_payload(mock_db): payload = {"enabled": True, "tracing_provider": "dummy"} - mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=json.dumps(payload)) + mock_db.get.return_value = SimpleNamespace(tracing=json.dumps(payload)) assert OpsTraceManager.get_app_tracing_config("app-id") == payload @@ -501,7 +500,7 @@ def test_trace_task_dataset_retrieval_trace(trace_task_message): def test_trace_task_tool_trace(monkeypatch, mock_db): custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))]) monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message) - configure_db_query(mock_db, message_file=FakeMessageFile()) + configure_db_scalar(mock_db, message_file=FakeMessageFile()) task = TraceTask(trace_type=TraceTaskName.TOOL_TRACE, message_id="msg-id") timer = {"start": 1, "end": 5} result = task.tool_trace("msg-id", timer, tool_name="tool-a", tool_inputs={"foo": 1}, tool_outputs="result") diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py 
b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py index 8987b6682cd..531c7de05f9 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -802,8 +802,8 @@ class TestMessageTrace: def test_basic_message_trace(self, trace_instance, monkeypatch): """message_trace creates message run and llm child run.""" monkeypatch.setattr( - "core.ops.weave_trace.weave_trace.db.session.query", - lambda model: MagicMock(where=lambda: MagicMock(first=lambda: None)), + "core.ops.weave_trace.weave_trace.db.session.get", + lambda model, pk: None, ) trace_instance.start_call = MagicMock() @@ -823,7 +823,7 @@ class TestMessageTrace: trace_instance.file_base_url = "http://files.test" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -845,7 +845,7 @@ class TestMessageTrace: end_user.session_id = "session-xyz" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = end_user + mock_db.session.get.return_value = end_user monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -865,7 +865,7 @@ class TestMessageTrace: def test_message_trace_no_end_user(self, trace_instance, monkeypatch): """message_trace handles when from_end_user_id is None.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -883,7 +883,7 @@ class TestMessageTrace: def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch): """trace_id falls back to message_id when trace_id is None.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -898,7 +898,7 @@ class TestMessageTrace: def test_message_trace_file_list_none(self, trace_instance, monkeypatch): """message_trace handles file_list=None gracefully.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() From 40591a7c5067f8777eb5cdf660d01e1b7991bb2a Mon Sep 17 00:00:00 2001 From: 99 Date: Sat, 28 Mar 2026 05:05:32 +0800 Subject: [PATCH 03/14] refactor(api): use standalone graphon package (#34209) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .github/CODEOWNERS | 1 - api/.importlinter | 150 -- api/controllers/common/fields.py | 2 +- api/controllers/console/app/app.py | 4 +- api/controllers/console/app/audio.py | 2 +- api/controllers/console/app/completion.py | 2 +- api/controllers/console/app/generator.py | 2 +- api/controllers/console/app/message.py | 2 +- api/controllers/console/app/workflow.py | 8 +- .../console/app/workflow_app_log.py | 2 +- .../console/app/workflow_draft_variable.py | 8 +- api/controllers/console/app/workflow_run.py | 4 +- api/controllers/console/auth/oauth_server.py | 2 +- 
api/controllers/console/datasets/datasets.py | 2 +- .../console/datasets/datasets_document.py | 4 +- .../console/datasets/datasets_segments.py | 2 +- .../console/datasets/hit_testing_base.py | 2 +- .../datasets/rag_pipeline/datasource_auth.py | 4 +- .../rag_pipeline_draft_variable.py | 2 +- .../rag_pipeline/rag_pipeline_workflow.py | 2 +- api/controllers/console/explore/audio.py | 2 +- api/controllers/console/explore/completion.py | 2 +- api/controllers/console/explore/message.py | 2 +- api/controllers/console/explore/trial.py | 4 +- api/controllers/console/explore/workflow.py | 4 +- api/controllers/console/remote_files.py | 2 +- .../console/workspace/agent_providers.py | 2 +- api/controllers/console/workspace/endpoint.py | 2 +- .../workspace/load_balancing_config.py | 4 +- .../console/workspace/model_providers.py | 6 +- api/controllers/console/workspace/models.py | 6 +- api/controllers/console/workspace/plugin.py | 2 +- .../console/workspace/tool_providers.py | 2 +- .../console/workspace/trigger_providers.py | 2 +- api/controllers/inner_api/plugin/plugin.py | 2 +- api/controllers/mcp/mcp.py | 2 +- api/controllers/service_api/app/audio.py | 2 +- api/controllers/service_api/app/completion.py | 2 +- api/controllers/service_api/app/workflow.py | 6 +- .../service_api/dataset/dataset.py | 2 +- .../service_api/dataset/segment.py | 2 +- .../service_api/workspace/models.py | 2 +- api/controllers/web/audio.py | 2 +- api/controllers/web/completion.py | 2 +- api/controllers/web/message.py | 2 +- api/controllers/web/remote_files.py | 2 +- api/controllers/web/workflow.py | 4 +- api/core/agent/base_agent_runner.py | 28 +- api/core/agent/cot_agent_runner.py | 17 +- api/core/agent/cot_chat_agent_runner.py | 3 +- api/core/agent/cot_completion_agent_runner.py | 3 +- api/core/agent/fc_agent_runner.py | 15 +- .../agent/output_parser/cot_output_parser.py | 3 +- .../model_config/converter.py | 7 +- .../easy_ui_based_app/model_config/manager.py | 3 +- .../prompt_template/manager.py | 3 +- .../easy_ui_based_app/variables/manager.py | 3 +- api/core/app/app_config/entities.py | 6 +- .../features/file_upload/manager.py | 3 +- .../variables/manager.py | 3 +- .../app/apps/advanced_chat/app_generator.py | 9 +- api/core/app/apps/advanced_chat/app_runner.py | 12 +- .../advanced_chat/generate_task_pipeline.py | 12 +- api/core/app/apps/agent_chat/app_generator.py | 2 +- api/core/app/apps/agent_chat/app_runner.py | 6 +- .../base_app_generate_response_converter.py | 3 +- api/core/app/apps/base_app_generator.py | 6 +- api/core/app/apps/base_app_queue_manager.py | 2 +- api/core/app/apps/base_app_runner.py | 23 +- api/core/app/apps/chat/app_generator.py | 2 +- api/core/app/apps/chat/app_runner.py | 4 +- .../common/graph_runtime_state_support.py | 3 +- .../common/workflow_response_converter.py | 26 +- api/core/app/apps/completion/app_generator.py | 2 +- api/core/app/apps/completion/app_runner.py | 4 +- .../app/apps/pipeline/pipeline_generator.py | 4 +- api/core/app/apps/pipeline/pipeline_runner.py | 15 +- api/core/app/apps/workflow/app_generator.py | 8 +- api/core/app/apps/workflow/app_runner.py | 11 +- .../apps/workflow/generate_task_pipeline.py | 6 +- api/core/app/apps/workflow_app_runner.py | 68 +- api/core/app/entities/app_invoke_entities.py | 4 +- api/core/app/entities/queue_entities.py | 8 +- api/core/app/entities/task_entities.py | 8 +- .../hosting_moderation/hosting_moderation.py | 3 +- .../conversation_variable_persist_layer.py | 5 +- .../app/layers/pause_state_persist_layer.py | 5 +- 
api/core/app/layers/suspend_layer.py | 5 +- api/core/app/layers/timeslice_layer.py | 6 +- api/core/app/layers/trigger_post_layer.py | 5 +- api/core/app/llm/model_access.py | 9 +- api/core/app/llm/quota.py | 2 +- .../based_generate_task_pipeline.py | 2 +- .../easy_ui_based_generate_task_pipeline.py | 14 +- .../app/task_pipeline/message_file_utils.py | 5 +- api/core/app/workflow/file_runtime.py | 9 +- api/core/app/workflow/layers/llm_quota.py | 11 +- api/core/app/workflow/layers/observability.py | 8 +- api/core/app/workflow/layers/persistence.py | 17 +- .../base/tts/app_generator_tts_publisher.py | 5 +- api/core/datasource/datasource_manager.py | 8 +- api/core/datasource/entities/api_entities.py | 2 +- .../datasource/utils/message_transformer.py | 3 +- api/core/entities/execution_extra_content.py | 2 +- api/core/entities/mcp_provider.py | 2 +- api/core/entities/model_entities.py | 3 +- api/core/entities/provider_configuration.py | 20 +- api/core/entities/provider_entities.py | 2 +- .../helper/code_executor/code_executor.py | 2 +- api/core/helper/moderation.py | 7 +- api/core/hosting_configuration.py | 2 +- api/core/indexing_runner.py | 2 +- api/core/llm_generator/llm_generator.py | 10 +- .../output_parser/structured_output.py | 10 +- api/core/mcp/server/streamable_http.py | 3 +- api/core/mcp/utils.py | 2 +- api/core/memory/token_buffer_memory.py | 18 +- api/core/model_manager.py | 17 +- .../openai_moderation/openai_moderation.py | 3 +- api/core/ops/aliyun_trace/aliyun_trace.py | 4 +- api/core/ops/aliyun_trace/utils.py | 4 +- .../arize_phoenix_trace.py | 2 +- api/core/ops/langfuse_trace/langfuse_trace.py | 2 +- .../ops/langsmith_trace/langsmith_trace.py | 2 +- api/core/ops/mlflow_trace/mlflow_trace.py | 2 +- api/core/ops/opik_trace/opik_trace.py | 2 +- api/core/ops/tencent_trace/span_builder.py | 7 +- api/core/ops/tencent_trace/tencent_trace.py | 8 +- api/core/ops/weave_trace/weave_trace.py | 2 +- api/core/plugin/backwards_invocation/model.py | 27 +- api/core/plugin/backwards_invocation/node.py | 14 +- api/core/plugin/entities/marketplace.py | 2 +- api/core/plugin/entities/plugin.py | 2 +- api/core/plugin/entities/plugin_daemon.py | 4 +- api/core/plugin/entities/request.py | 19 +- api/core/plugin/impl/base.py | 16 +- api/core/plugin/impl/model.py | 13 +- api/core/plugin/impl/model_runtime.py | 14 +- api/core/plugin/impl/model_runtime_factory.py | 3 +- api/core/plugin/utils/converter.py | 3 +- api/core/prompt/advanced_prompt_transform.py | 18 +- .../prompt/agent_history_prompt_transform.py | 11 +- api/core/prompt/prompt_transform.py | 5 +- api/core/prompt/simple_prompt_transform.py | 15 +- api/core/prompt/utils/prompt_message_util.py | 3 +- api/core/provider_manager.py | 16 +- .../data_post_processor.py | 4 +- api/core/rag/datasource/retrieval_service.py | 2 +- api/core/rag/datasource/vdb/vector_factory.py | 2 +- api/core/rag/docstore/dataset_docstore.py | 2 +- api/core/rag/embedding/cached_embedding.py | 4 +- .../processor/paragraph_index_processor.py | 21 +- api/core/rag/models/document.py | 3 +- api/core/rag/rerank/rerank_model.py | 5 +- api/core/rag/rerank/weight_rerank.py | 2 +- api/core/rag/retrieval/dataset_retrieval.py | 10 +- .../multi_dataset_function_call_router.py | 5 +- .../router/multi_dataset_react_route.py | 7 +- api/core/rag/splitter/fixed_text_splitter.py | 3 +- .../celery_workflow_execution_repository.py | 2 +- ...lery_workflow_node_execution_repository.py | 2 +- api/core/repositories/factory.py | 2 +- .../repositories/human_input_repository.py | 4 +- 
...qlalchemy_workflow_execution_repository.py | 6 +- ...hemy_workflow_node_execution_repository.py | 8 +- .../builtin_tool/providers/audio/tools/asr.py | 7 +- .../builtin_tool/providers/audio/tools/tts.py | 3 +- api/core/tools/builtin_tool/tool.py | 5 +- api/core/tools/custom_tool/tool.py | 2 +- api/core/tools/entities/api_entities.py | 2 +- api/core/tools/mcp_tool/tool.py | 3 +- api/core/tools/tool_engine.py | 3 +- api/core/tools/tool_file_manager.py | 2 +- api/core/tools/tool_manager.py | 5 +- .../dataset_multi_retriever_tool.py | 2 +- api/core/tools/utils/message_transformer.py | 2 +- .../tools/utils/model_invocation_utils.py | 7 +- .../utils/workflow_configuration_sync.py | 5 +- api/core/tools/workflow_as_tool/provider.py | 2 +- api/core/tools/workflow_as_tool/tool.py | 4 +- api/core/trigger/debug/event_selectors.py | 2 +- api/core/workflow/human_input_compat.py | 5 +- api/core/workflow/node_factory.py | 32 +- api/core/workflow/node_runtime.py | 57 +- api/core/workflow/nodes/agent/agent_node.py | 5 +- api/core/workflow/nodes/agent/entities.py | 4 +- .../nodes/agent/message_transformer.py | 16 +- .../workflow/nodes/agent/runtime_support.py | 4 +- .../nodes/datasource/datasource_node.py | 17 +- .../workflow/nodes/datasource/entities.py | 5 +- .../nodes/knowledge_index/entities.py | 4 +- .../knowledge_index/knowledge_index_node.py | 12 +- .../nodes/knowledge_retrieval/entities.py | 3 +- .../knowledge_retrieval_node.py | 13 +- .../nodes/knowledge_retrieval/retrieval.py | 4 +- .../workflow/nodes/trigger_plugin/entities.py | 4 +- .../trigger_plugin/trigger_event_node.py | 8 +- .../nodes/trigger_schedule/entities.py | 4 +- .../trigger_schedule/trigger_schedule_node.py | 8 +- .../nodes/trigger_webhook/entities.py | 6 +- .../workflow/nodes/trigger_webhook/node.py | 12 +- api/core/workflow/template_rendering.py | 3 +- api/core/workflow/workflow_entry.py | 28 +- api/enterprise/telemetry/draft_trace.py | 3 +- ...rameters_cache_when_sync_draft_workflow.py | 5 +- ...oin_when_app_published_workflow_updated.py | 2 +- api/extensions/ext_sentry.py | 3 +- ..._api_workflow_node_execution_repository.py | 2 +- .../logstore_api_workflow_run_repository.py | 2 +- .../logstore_workflow_execution_repository.py | 4 +- ...tore_workflow_node_execution_repository.py | 8 +- api/extensions/otel/parser/base.py | 10 +- api/extensions/otel/parser/llm.py | 4 +- api/extensions/otel/parser/retrieval.py | 6 +- api/extensions/otel/parser/tool.py | 8 +- api/factories/file_factory/builders.py | 3 +- api/factories/file_factory/message_files.py | 3 +- api/factories/file_factory/storage_keys.py | 2 +- api/factories/variable_factory.py | 11 +- api/fields/conversation_fields.py | 3 +- api/fields/member_fields.py | 3 +- api/fields/message_fields.py | 2 +- api/fields/raws.py | 1 - api/fields/workflow_fields.py | 2 +- api/graphon/README.md | 135 -- api/graphon/__init__.py | 0 api/graphon/entities/__init__.py | 11 - api/graphon/entities/base_node_data.py | 178 --- api/graphon/entities/exc.py | 10 - api/graphon/entities/graph_config.py | 23 - api/graphon/entities/graph_init_params.py | 24 - api/graphon/entities/pause_reason.py | 42 - api/graphon/entities/workflow_execution.py | 71 - .../entities/workflow_node_execution.py | 141 -- api/graphon/entities/workflow_start_reason.py | 8 - api/graphon/enums.py | 262 ---- api/graphon/errors.py | 16 - api/graphon/file/__init__.py | 22 - api/graphon/file/constants.py | 48 - api/graphon/file/enums.py | 57 - api/graphon/file/file_factory.py | 39 - api/graphon/file/file_manager.py | 129 -- 
api/graphon/file/helpers.py | 48 - api/graphon/file/models.py | 215 --- api/graphon/file/protocols.py | 56 - api/graphon/file/runtime.py | 71 - api/graphon/file/tool_file_parser.py | 9 - api/graphon/graph/__init__.py | 11 - api/graphon/graph/edge.py | 15 - api/graphon/graph/graph.py | 438 ------ api/graphon/graph/graph_template.py | 20 - api/graphon/graph/validation.py | 125 -- api/graphon/graph_engine/__init__.py | 4 - api/graphon/graph_engine/_engine_utils.py | 15 - .../graph_engine/command_channels/README.md | 33 - .../graph_engine/command_channels/__init__.py | 6 - .../command_channels/in_memory_channel.py | 53 - .../command_channels/redis_channel.py | 153 -- .../command_processing/__init__.py | 16 - .../command_processing/command_handlers.py | 56 - .../command_processing/command_processor.py | 79 - api/graphon/graph_engine/config.py | 16 - api/graphon/graph_engine/domain/__init__.py | 14 - .../graph_engine/domain/graph_execution.py | 242 --- .../graph_engine/domain/node_execution.py | 45 - api/graphon/graph_engine/entities/__init__.py | 0 api/graphon/graph_engine/entities/commands.py | 56 - api/graphon/graph_engine/error_handler.py | 213 --- .../graph_engine/event_management/__init__.py | 14 - .../event_management/event_handlers.py | 367 ----- .../event_management/event_manager.py | 186 --- api/graphon/graph_engine/graph_engine.py | 377 ----- .../graph_engine/graph_state_manager.py | 290 ---- .../graph_engine/graph_traversal/__init__.py | 14 - .../graph_traversal/edge_processor.py | 201 --- .../graph_traversal/skip_propagator.py | 96 -- api/graphon/graph_engine/layers/README.md | 55 - api/graphon/graph_engine/layers/__init__.py | 16 - api/graphon/graph_engine/layers/base.py | 128 -- .../graph_engine/layers/debug_logging.py | 247 --- .../graph_engine/layers/execution_limits.py | 150 -- api/graphon/graph_engine/manager.py | 79 - .../graph_engine/orchestration/__init__.py | 14 - .../graph_engine/orchestration/dispatcher.py | 143 -- .../orchestration/execution_coordinator.py | 104 -- .../graph_engine/protocols/command_channel.py | 41 - .../graph_engine/ready_queue/__init__.py | 12 - .../graph_engine/ready_queue/factory.py | 37 - .../graph_engine/ready_queue/in_memory.py | 140 -- .../graph_engine/ready_queue/protocol.py | 104 -- .../response_coordinator/__init__.py | 10 - .../response_coordinator/coordinator.py | 697 --------- .../graph_engine/response_coordinator/path.py | 35 - .../response_coordinator/session.py | 66 - api/graphon/graph_engine/worker.py | 204 --- .../worker_management/__init__.py | 12 - .../worker_management/worker_pool.py | 277 ---- api/graphon/graph_events/__init__.py | 84 - api/graphon/graph_events/agent.py | 17 - api/graphon/graph_events/base.py | 31 - api/graphon/graph_events/graph.py | 57 - api/graphon/graph_events/human_input.py | 0 api/graphon/graph_events/iteration.py | 40 - api/graphon/graph_events/loop.py | 40 - api/graphon/graph_events/node.py | 106 -- api/graphon/model_runtime/README.md | 51 - api/graphon/model_runtime/README_CN.md | 64 - api/graphon/model_runtime/__init__.py | 0 .../model_runtime/callbacks/__init__.py | 0 .../model_runtime/callbacks/base_callback.py | 159 -- .../callbacks/logging_callback.py | 180 --- .../model_runtime/entities/__init__.py | 43 - .../model_runtime/entities/common_entities.py | 16 - .../model_runtime/entities/defaults.py | 130 -- .../model_runtime/entities/llm_entities.py | 219 --- .../entities/message_entities.py | 279 ---- .../model_runtime/entities/model_entities.py | 242 --- .../entities/provider_entities.py | 179 --- 
.../model_runtime/entities/rerank_entities.py | 27 - .../entities/text_embedding_entities.py | 47 - api/graphon/model_runtime/errors/__init__.py | 0 api/graphon/model_runtime/errors/invoke.py | 41 - api/graphon/model_runtime/errors/validate.py | 6 - api/graphon/model_runtime/memory/__init__.py | 3 - .../memory/prompt_message_memory.py | 18 - .../model_providers/__base/__init__.py | 0 .../model_providers/__base/ai_model.py | 247 --- .../__base/large_language_model.py | 638 -------- .../__base/moderation_model.py | 33 - .../model_providers/__base/rerank_model.py | 76 - .../__base/speech2text_model.py | 31 - .../__base/text_embedding_model.py | 98 -- .../__base/tokenizers/gpt2_tokenizer.py | 53 - .../model_providers/__base/tts_model.py | 58 - .../model_runtime/model_providers/__init__.py | 0 .../model_providers/_position.yaml | 43 - .../model_providers/model_provider_factory.py | 173 --- api/graphon/model_runtime/runtime.py | 159 -- .../schema_validators/__init__.py | 0 .../schema_validators/common_validator.py | 92 -- .../model_credential_schema_validator.py | 27 - .../provider_credential_schema_validator.py | 19 - api/graphon/model_runtime/utils/__init__.py | 0 api/graphon/model_runtime/utils/encoders.py | 218 --- api/graphon/node_events/__init__.py | 48 - api/graphon/node_events/agent.py | 18 - api/graphon/node_events/base.py | 40 - api/graphon/node_events/iteration.py | 36 - api/graphon/node_events/loop.py | 36 - api/graphon/node_events/node.py | 72 - api/graphon/nodes/__init__.py | 3 - api/graphon/nodes/answer/__init__.py | 0 api/graphon/nodes/answer/answer_node.py | 70 - api/graphon/nodes/answer/entities.py | 67 - api/graphon/nodes/base/__init__.py | 10 - api/graphon/nodes/base/entities.py | 87 -- api/graphon/nodes/base/node.py | 787 ---------- api/graphon/nodes/base/template.py | 150 -- .../nodes/base/usage_tracking_mixin.py | 28 - .../nodes/base/variable_template_parser.py | 130 -- api/graphon/nodes/code/__init__.py | 3 - api/graphon/nodes/code/code_node.py | 493 ------ api/graphon/nodes/code/entities.py | 57 - api/graphon/nodes/code/exc.py | 16 - api/graphon/nodes/code/limits.py | 13 - .../nodes/document_extractor/__init__.py | 4 - .../nodes/document_extractor/entities.py | 16 - api/graphon/nodes/document_extractor/exc.py | 14 - api/graphon/nodes/document_extractor/node.py | 782 ---------- api/graphon/nodes/end/__init__.py | 0 api/graphon/nodes/end/end_node.py | 47 - api/graphon/nodes/end/entities.py | 27 - api/graphon/nodes/http_request/__init__.py | 22 - api/graphon/nodes/http_request/config.py | 33 - api/graphon/nodes/http_request/entities.py | 241 --- api/graphon/nodes/http_request/exc.py | 26 - api/graphon/nodes/http_request/executor.py | 488 ------ api/graphon/nodes/http_request/node.py | 261 ---- api/graphon/nodes/human_input/__init__.py | 3 - api/graphon/nodes/human_input/entities.py | 208 --- api/graphon/nodes/human_input/enums.py | 55 - .../nodes/human_input/human_input_node.py | 299 ---- api/graphon/nodes/if_else/__init__.py | 3 - api/graphon/nodes/if_else/entities.py | 29 - api/graphon/nodes/if_else/if_else_node.py | 124 -- api/graphon/nodes/iteration/__init__.py | 5 - api/graphon/nodes/iteration/entities.py | 67 - api/graphon/nodes/iteration/exc.py | 26 - api/graphon/nodes/iteration/iteration_node.py | 686 --------- .../nodes/iteration/iteration_start_node.py | 22 - api/graphon/nodes/list_operator/__init__.py | 3 - api/graphon/nodes/list_operator/entities.py | 71 - api/graphon/nodes/list_operator/exc.py | 16 - api/graphon/nodes/list_operator/node.py | 345 ----- 
 883 files changed, 1779 insertions(+), 47377 deletions(-)
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 3f53811f85e..94e857f93a5 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -36,7 +36,6 @@
 /api/core/workflow/graph/ @laipz8200 @QuantumGhost
 /api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
 /api/core/workflow/node_events/ @laipz8200 @QuantumGhost
-/api/graphon/model_runtime/ @laipz8200 @WH-2099
 
 # Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
 /api/core/workflow/nodes/agent/ @Nov1c444
diff --git a/api/.importlinter b/api/.importlinter
index c2841f64d2c..5e06947d941 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -3,7 +3,6 @@ root_packages =
     core
     constants
     context
-    graphon
     configs
     controllers
     extensions
@@ -13,152 +12,3 @@ root_packages =
     tasks
     services
 include_external_packages = True
-
-[importlinter:contract:workflow]
-name = Workflow
-type=layers
-layers =
-    graph_engine
-    graph_events
-    graph
-    nodes
-    node_events
-    runtime
-    entities
-containers =
-    graphon
-ignore_imports =
-    graphon.nodes.base.node -> graphon.graph_events
-    graphon.nodes.iteration.iteration_node -> graphon.graph_events
-    graphon.nodes.loop.loop_node -> graphon.graph_events
-
-    graphon.nodes.iteration.iteration_node -> graphon.graph_engine
-    graphon.nodes.loop.loop_node -> graphon.graph_engine
-    # TODO(QuantumGhost): fix the import violation later
-    graphon.entities.pause_reason -> graphon.nodes.human_input.entities
-
-[importlinter:contract:workflow-external-imports]
-name = Workflow External Imports
-type = forbidden
-source_modules =
-    graphon
-forbidden_modules =
-    constants
-    configs
-    context
-    controllers
-    extensions
-    factories
-    libs
-    models
-    services
-    tasks
-    core.agent
-    core.app
-    core.base
-    core.callback_handler
-    core.datasource
-    core.db
-    core.entities
-    core.errors
-    core.extension
-    core.external_data_tool
-    core.file
-    core.helper
-    core.hosting_configuration
-    core.indexing_runner
-    core.llm_generator
-    core.logging
-    core.mcp
-    core.memory
-    core.moderation
-    core.ops
-    core.plugin
-    core.prompt
-    core.provider_manager
-    core.rag
-    core.repositories
-    core.schemas
-    core.tools
-    core.trigger
-    core.variables
-
-[importlinter:contract:workflow-third-party-imports]
-name = Workflow Third-Party Imports
-type = forbidden
-source_modules =
-    graphon
-forbidden_modules =
-    sqlalchemy
-
-[importlinter:contract:rsc]
-name = RSC
-type = layers
-layers =
-    graph_engine
-    response_coordinator
-containers =
-    graphon.graph_engine
-
-[importlinter:contract:worker]
-name = Worker
-type = layers
-layers =
-    graph_engine
-    worker
-containers =
-    graphon.graph_engine
-
-[importlinter:contract:graph-engine-architecture]
-name = Graph Engine Architecture
-type = layers
-layers =
-    graph_engine
-    orchestration
-    command_processing
-    event_management
-    error_handler
-    graph_traversal
-    graph_state_manager
-    worker_management
-    domain
-containers =
-    graphon.graph_engine
-
-[importlinter:contract:domain-isolation]
-name = Domain Model Isolation
-type = forbidden
-source_modules =
-    graphon.graph_engine.domain
-forbidden_modules =
-    graphon.graph_engine.worker_management
-    graphon.graph_engine.command_channels
-    graphon.graph_engine.layers
-    graphon.graph_engine.protocols
-
-[importlinter:contract:worker-management]
-name = Worker Management
-type = forbidden
-source_modules =
-    graphon.graph_engine.worker_management
-forbidden_modules =
-    graphon.graph_engine.orchestration
-    graphon.graph_engine.command_processing
-    graphon.graph_engine.event_management
-
-
-[importlinter:contract:graph-traversal-components]
-name = Graph Traversal Components
-type = layers
-layers =
-    edge_processor
-    skip_propagator
-containers =
-    graphon.graph_engine.graph_traversal
-
-[importlinter:contract:command-channels]
-name = Command Channels Independence
-type = independence
-modules =
-    graphon.graph_engine.command_channels.in_memory_channel
-    graphon.graph_engine.command_channels.redis_channel
diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py
index 515a6a51258..7348ef62aad 100644
--- a/api/controllers/common/fields.py
+++ b/api/controllers/common/fields.py
@@ -2,9 +2,9 @@ from __future__ import annotations
 
 from typing import Any, TypeAlias
 
+from graphon.file import helpers as file_helpers
 from pydantic import BaseModel, ConfigDict, computed_field
 
-from graphon.file import helpers as file_helpers
 from models.model import IconType
 
 JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 357697ed30f..738e77b3715 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -5,6 +5,8 @@ from typing import Any, Literal, TypeAlias
 
 from flask import request
 from flask_restx import Resource
+from graphon.enums import WorkflowExecutionStatus
+from graphon.file import helpers as file_helpers
 from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -27,8 +29,6 @@ from core.ops.ops_trace_manager import OpsTraceManager
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.trigger.constants import TRIGGER_NODE_TYPES
 from extensions.ext_database import db
-from graphon.enums import WorkflowExecutionStatus
-from graphon.file import helpers as file_helpers
 from libs.login import current_account_with_tenant, login_required
 from models import App, DatasetPermissionEnum, Workflow
 from models.model import IconType
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 91fbe4a85ae..78ddb904e14 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -2,6 +2,7 @@ import logging
 
 from flask import request
 from flask_restx import Resource, fields
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import InternalServerError
 
@@ -22,7 +23,6 @@ from controllers.console.app.error import (
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from models import App, AppMode
 from services.audio_service import AudioService
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index fe274e4c9a8..d83925d173a 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -3,6 +3,7 @@ from typing import Any, Literal
 
 from flask import request
 from flask_restx import Resource
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import InternalServerError, NotFound
 
@@ -26,7 +27,6 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from core.helper.trace_id_helper import get_external_trace_id
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import uuid_value
 from libs.login import current_user, login_required
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index c720a5e074b..7101d5df7b4 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -1,6 +1,7 @@
 from collections.abc import Sequence
 
 from flask_restx import Resource
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field
 
 from controllers.console import console_ns
@@ -19,7 +20,6 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
 from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
 from core.llm_generator.llm_generator import LLMGenerator
 from extensions.ext_database import db
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs.login import current_account_with_tenant, login_required
 from models import App
 from services.workflow_service import WorkflowService
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index dc752939ae9..2afe2767427 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -3,6 +3,7 @@ from typing import Literal
 
 from flask import request
 from flask_restx import Resource, fields, marshal_with
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import exists, func, select
 from werkzeug.exceptions import InternalServerError, NotFound
@@ -26,7 +27,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from extensions.ext_database import db
 from fields.raws import FilesContainedField
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs.helper import TimestampField, uuid_value
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from libs.login import current_account_with_tenant, login_required
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 2737dd1dfdf..1f5a84c0b2b 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -5,6 +5,10 @@ from typing import Any
 
 from flask import abort, request
 from flask_restx import Resource, fields, marshal_with
+from graphon.enums import NodeType
+from graphon.file import File
+from graphon.graph_engine.manager import GraphEngineManager
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
@@ -35,10 +39,6 @@ from extensions.ext_redis import redis_client
 from factories import file_factory, variable_factory
 from fields.member_fields import simple_account_fields
 from fields.workflow_fields import workflow_fields, workflow_pagination_fields
-from graphon.enums import NodeType
-from graphon.file.models import File
-from graphon.graph_engine.manager import GraphEngineManager
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs import helper
 from libs.datetime_utils import naive_utc_now
 from libs.helper import TimestampField, uuid_value
diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py
index 8cf0004b092..f0e26c86a5a 100644
--- a/api/controllers/console/app/workflow_app_log.py
+++ b/api/controllers/console/app/workflow_app_log.py
@@ -3,6 +3,7 @@ from datetime import datetime
 from dateutil.parser import isoparse
 from flask import request
 from flask_restx import Resource, marshal_with
+from graphon.enums import WorkflowExecutionStatus
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy.orm import Session
 
@@ -14,7 +15,6 @@
     build_workflow_app_log_pagination_model,
     build_workflow_archived_log_pagination_model,
 )
-from graphon.enums import WorkflowExecutionStatus
 from libs.login import login_required
 from models import App
 from models.model import AppMode
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py
index 657b0724903..4052897e9a4 100644
--- a/api/controllers/console/app/workflow_draft_variable.py
+++ b/api/controllers/console/app/workflow_draft_variable.py
@@ -5,6 +5,10 @@ from typing import Any, NoReturn, ParamSpec, TypeVar
 
 from flask import Response, request
 from flask_restx import Resource, fields, marshal, marshal_with
+from graphon.file import helpers as file_helpers
+from graphon.variables.segment_group import SegmentGroup
+from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
+from graphon.variables.types import SegmentType
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session
 
@@ -20,10 +24,6 @@ from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTE
 from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
-from graphon.file import helpers as file_helpers
-from graphon.variables.segment_group import SegmentGroup
-from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
-from graphon.variables.types import SegmentType
 from libs.login import current_user, login_required
 from models import App, AppMode
 from models.workflow import WorkflowDraftVariable
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index d1df7227293..83e8bedc110 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -3,6 +3,8 @@ from typing import Literal, TypedDict, cast
 
 from flask import request
 from flask_restx import Resource, fields, marshal_with
+from graphon.entities.pause_reason import HumanInputRequired
+from graphon.enums import WorkflowExecutionStatus
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import sessionmaker
@@ -26,8 +28,6 @@ from fields.workflow_run_fields import (
     workflow_run_node_execution_list_fields,
     workflow_run_pagination_fields,
 )
-from graphon.entities.pause_reason import HumanInputRequired
-from graphon.enums import WorkflowExecutionStatus
 from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
 from libs.custom_inputs import time_duration
 from libs.helper import uuid_value
diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py
index 665a80802d1..686b865871d 100644
--- a/api/controllers/console/auth/oauth_server.py
+++ b/api/controllers/console/auth/oauth_server.py
@@ -4,11 +4,11 @@ from typing import Concatenate, ParamSpec, TypeVar
 
 from flask import jsonify, request
 from flask_restx import Resource
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel
 from werkzeug.exceptions import BadRequest, NotFound
 
 from controllers.console.wraps import account_initialization_required, setup_required
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_account_with_tenant, login_required
 from models import Account
 from models.model import OAuthProviderApp
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 5d704b6224b..f23c7eb4310 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -2,6 +2,7 @@ from typing import Any, cast
 
 from flask import request
 from flask_restx import Resource, fields, marshal, marshal_with
+from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import func, select
 from werkzeug.exceptions import Forbidden, NotFound
@@ -51,7 +52,6 @@ from fields.dataset_fields import (
     weighted_score_fields,
 )
 from fields.document_fields import document_status_fields
-from graphon.model_runtime.entities.model_entities import ModelType
 from libs.login import current_account_with_tenant, login_required
 from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
 from models.dataset import DatasetPermission, DatasetPermissionEnum
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index edb738aad80..ab367d84838 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -9,6 +9,8 @@ from uuid import UUID
 import sqlalchemy as sa
 from flask import request, send_file
 from flask_restx import Resource, fields, marshal, marshal_with
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from pydantic import BaseModel, Field
 from sqlalchemy import asc, desc, func, select
 from werkzeug.exceptions import Forbidden, NotFound
@@ -37,8 +39,6 @@ from fields.document_fields import (
     document_status_fields,
     document_with_segments_fields,
 )
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
 from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 2fd84303d79..c5f4e3a6e26 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -2,6 +2,7 @@ import uuid
 
 from flask import request
 from flask_restx import Resource, marshal
+from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, Field
 from sqlalchemy import String, cast, func, or_, select
 from sqlalchemy.dialects.postgresql import JSONB
@@ -30,7 +31,6 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.segment_fields import child_chunk_fields, segment_fields
-from graphon.model_runtime.entities.model_entities import ModelType
 from libs.helper import escape_like_pattern
 from libs.login import current_account_with_tenant, login_required
 from models.dataset import ChildChunk, DocumentSegment
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index 699fa599c8a..8fb3699849e 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -2,6 +2,7 @@ import logging
 from typing import Any
 
 from flask_restx import marshal
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
@@ -20,7 +21,6 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from fields.hit_testing_fields import hit_testing_record_fields
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs.login import current_user
 from models.account import Account
 from services.dataset_service import DatasetService
diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
index 946fa599e6b..1976a6bc8ac 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
@@ -2,6 +2,8 @@ from typing import Any
 
 from flask import make_response, redirect, request
 from flask_restx import Resource
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, NotFound
 
@@ -10,8 +12,6 @@ from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from core.plugin.impl.oauth import OAuthHandler
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_account_with_tenant, login_required
 from models.provider_ids import DatasourceProviderID
 from services.datasource_provider_service import DatasourceProviderService
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
index 977ae93c03f..f12cbd34959 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
@@ -3,6 +3,7 @@ from typing import Any, NoReturn
 
 from flask import Response, request
 from flask_restx import Resource, marshal, marshal_with
+from graphon.variables.types import SegmentType
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
@@ -26,7 +27,6 @@ from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTE
 from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
-from graphon.variables.types import SegmentType
 from libs.login import current_user, login_required
 from models import Account
 from models.dataset import Pipeline
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 9079fbc29a1..8e44bd68738 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -4,6 +4,7 @@ from typing import Any, Literal, cast
 
 from flask import abort, request
 from flask_restx import Resource, marshal_with  # type: ignore
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic
import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound @@ -39,7 +40,6 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from factories import variable_factory -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.helper import TimestampField, UUIDStrOrEmpty from libs.login import current_account_with_tenant, current_user, login_required diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index bc78ee6d2d2..b1b01b5f51c 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -1,6 +1,7 @@ import logging from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -19,7 +20,6 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index ccdccceaa67..eacd7332fe8 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -2,6 +2,7 @@ import logging from typing import Any, Literal from uuid import UUID +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -25,7 +26,6 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_database import db -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index a72cf6328ad..fcbefcda33b 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -2,6 +2,7 @@ import logging from typing import Literal from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound @@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from libs.login import current_account_with_tenant diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 26aa086aacb..e432574434d 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -3,6 +3,8 @@ from typing import Any, Literal, cast from flask 
import request from flask_restx import Resource, fields, marshal, marshal_with +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel from sqlalchemy import select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -59,8 +61,6 @@ from fields.workflow_fields import ( workflow_fields, workflow_partial_fields, ) -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 17dbbdd5344..42cafc71932 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,6 +1,8 @@ import logging from typing import Any +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel from werkzeug.exceptions import InternalServerError @@ -22,8 +24,6 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_redis import redis_client -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.login import current_account_with_tenant from models.model import AppMode, InstalledApp diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 2a46d2250a0..551c86fd827 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,6 +2,7 @@ import urllib.parse import httpx from flask_restx import Resource +from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field import services @@ -15,7 +16,6 @@ from controllers.console import console_ns from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo -from graphon.file import helpers as file_helpers from libs.login import current_account_with_tenant, login_required from services.file_service import FileService diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 764f4887558..3fdcbc47108 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,8 +1,8 @@ from flask_restx import Resource, fields +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index f45b72f3904..b6b9deb1f92 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -2,13 +2,13 @@ from typing import Any from flask import request from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import 
console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginPermissionDeniedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 2a6f37aec81..e4cfca9fa4c 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,12 +1,12 @@ from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index b22b91706e2..8e0aefc9e3e 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -3,13 +3,13 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 3c7b97d7fce..2ec1a9435a2 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -3,14 +3,14 @@ from typing import Any, cast from flask import request from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from 
graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index b3e344ccea8..aa674a63b30 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -4,6 +4,7 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden @@ -14,7 +15,6 @@ from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginDaemonClientSideError -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 1273b85bc36..02eb0adc944 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -5,6 +5,7 @@ from urllib.parse import urlparse from flask import make_response, redirect, request, send_file from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -27,7 +28,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index feedf074b73..265b6ecd9a0 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -3,6 +3,7 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden @@ -15,7 +16,6 @@ from core.plugin.impl.oauth import OAuthHandler from core.trigger.entities.entities import SubscriptionBuilderUpdater from core.trigger.trigger_manager import TriggerManager from extensions.ext_database import db -from 
graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_user, login_required from models.account import Account from models.provider_ids import TriggerProviderID diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 72cab3de737..83c8fa02fee 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,4 +1,5 @@ from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns @@ -29,7 +30,6 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.signature import get_signed_file_url_for_plugin -from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import length_prefixed_response from models import Account, Tenant from models.model import EndUser diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 869fb73cf55..3d00f77e79f 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -2,6 +2,7 @@ from typing import Any, Union from flask import Response from flask_restx import Resource +from graphon.variables.input_entities import VariableEntity from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import Session @@ -10,7 +11,6 @@ from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request from extensions.ext_database import db -from graphon.variables.input_entities import VariableEntity from libs import helper from models.enums import AppMCPServerStatus from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 86d88ddafbc..6228cfc25be 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -21,7 +22,6 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 31f2797d66e..3142e5118e9 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,6 +4,7 @@ from uuid import UUID from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -28,7 +29,6 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty 
from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 94afd47f7fb..17590751395 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -4,6 +4,9 @@ from typing import Any, Literal from dateutil.parser import isoparse from flask import request from flask_restx import Namespace, Resource, fields +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -30,9 +33,6 @@ from core.helper.trace_id_helper import get_external_trace_id from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model -from graphon.enums import WorkflowExecutionStatus -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import OptionalTimestampField, TimestampField from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index dcf788f7a8f..80205b283bc 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -2,6 +2,7 @@ from typing import Any, Literal, cast from flask import request from flask_restx import marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import Forbidden, NotFound @@ -18,7 +19,6 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage from core.rag.index_processor.constant.index_type import IndexTechniqueType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag -from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 28fa9151179..b4cc9874b63 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -2,6 +2,7 @@ from typing import Any from flask import request from flask_restx import marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import select from werkzeug.exceptions import NotFound @@ -21,7 +22,6 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields -from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 5ac65fc4e6d..c0a6cb0a763 100644 --- 
a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -1,9 +1,9 @@ from flask_login import current_user from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token -from graphon.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 8081dee0bd2..9ba1dc4a3ac 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import fields, marshal_with +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, field_validator from werkzeug.exceptions import InternalServerError @@ -20,7 +21,6 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 0528184d79e..e37f9af5f0a 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,7 @@ import logging from typing import Any, Literal +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -25,7 +26,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.model import AppMode diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 4274b8c9ab2..c5505dd60de 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -2,6 +2,7 @@ import logging from typing import Literal from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -22,7 +23,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.enums import FeedbackRating diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index fe31e9d4acd..38aeccc642b 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,6 +1,7 @@ import urllib.parse import httpx +from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field, HttpUrl import services @@ -13,7 +14,6 @@ from controllers.common.errors import ( from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import 
FileWithSignedUrl, RemoteFileInfo -from graphon.file import helpers as file_helpers from services.file_service import FileService from ..common.schema import register_schema_models diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index ccef6e5b7f4..7f5521f9f58 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,6 +1,8 @@ import logging from typing import Any +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -23,8 +25,6 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_redis import redis_client -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index a846cf4b0f2..ff8f40407fa 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -4,6 +4,20 @@ import uuid from decimal import Decimal from typing import Union, cast +from graphon.file import file_manager +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + LLMUsage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from core.agent.entities import AgentEntity, AgentToolEntity @@ -29,20 +43,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db from factories import file_factory -from graphon.file import file_manager -from graphon.model_runtime.entities import ( - AssistantPromptMessage, - LLMUsage, - PromptMessage, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from graphon.model_runtime.entities.model_entities import ModelFeature -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 0a0fdfdd29a..11e2aa062d2 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -4,6 +4,15 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping, Sequence from typing import Any +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + ToolPromptMessage, + UserPromptMessage, +) + from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import 
AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError @@ -15,14 +24,6 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) from models.model import Message logger = logging.getLogger(__name__) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index b3fc8d42e60..a4c438e9296 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,6 +1,5 @@ import json -from core.agent.cot_agent_runner import CotAgentRunner from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -12,6 +11,8 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.model_runtime.utils.encoders import jsonable_encoder +from core.agent.cot_agent_runner import CotAgentRunner + class CotChatAgentRunner(CotAgentRunner): def _organize_system_prompt(self) -> SystemPromptMessage: diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 51a30998ae2..d4c52a8eb16 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,6 +1,5 @@ import json -from core.agent.cot_agent_runner import CotAgentRunner from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -9,6 +8,8 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.utils.encoders import jsonable_encoder +from core.agent.cot_agent_runner import CotAgentRunner + class CotCompletionAgentRunner(CotAgentRunner): def _organize_instruction_prompt(self) -> str: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index d38d24d1e79..fdffde85d01 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -4,13 +4,6 @@ from collections.abc import Generator from copy import deepcopy from typing import Any, Union -from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.errors import AgentMaxIterationError -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -26,6 +19,14 @@ from graphon.model_runtime.entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes + +from core.agent.base_agent_runner import BaseAgentRunner +from core.agent.errors import AgentMaxIterationError +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import 
QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool_engine import ToolEngine from models.model import Message logger = logging.getLogger(__name__) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index c3e56fe0118..46c1f1230d0 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -3,9 +3,10 @@ import re from collections.abc import Generator from typing import Union -from core.agent.entities import AgentScratchpadUnit from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from core.agent.entities import AgentScratchpadUnit + class CotAgentOutputParser: @classmethod diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index dbd7527fc64..b7dd55632e2 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -1,13 +1,14 @@ from typing import cast +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class ModelConfigConverter: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index f279f769aa8..5cc385c3781 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -1,9 +1,10 @@ from collections.abc import Mapping from typing import Any +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType + from core.app.app_config.entities import ModelConfigEntity from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 7715a5330a9..76196e7034e 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,5 +1,7 @@ from typing import Any +from graphon.model_runtime.entities.message_entities import PromptMessageRole + from core.app.app_config.entities import ( AdvancedChatMessageEntity, 
AdvancedChatPromptTemplateEntity, @@ -7,7 +9,6 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.simple_prompt_transform import ModelMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole from models.model import AppMode, AppModelConfigDict diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 6d63ae04d34..f0b71c58016 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,9 +1,10 @@ import re from typing import cast +from graphon.variables.input_entities import VariableEntity, VariableEntityType + from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import AppModelConfigDict _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index c67412cc291..536617edba4 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -2,13 +2,13 @@ from collections.abc import Sequence from enum import StrEnum, auto from typing import Any, Literal -from pydantic import BaseModel, Field - -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from graphon.file import FileUploadConfig from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.entities.message_entities import PromptMessageRole from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity +from pydantic import BaseModel, Field + +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from models.model import AppMode diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 9092c1a17dc..e96517c4264 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,9 +1,10 @@ from collections.abc import Mapping from typing import Any -from constants import DEFAULT_FILE_NUMBER_LIMITS from graphon.file import FileUploadConfig +from constants import DEFAULT_FILE_NUMBER_LIMITS + class FileUploadConfigManager: @classmethod diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index 13ace32fd60..62e0c31d1ae 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,7 +1,8 @@ import re -from core.app.app_config.entities import RagPipelineVariableEntity from graphon.variables.input_entities import VariableEntity + +from core.app.app_config.entities import RagPipelineVariableEntity from models.workflow import Workflow diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index d69a80e4a93..aa2b65766f8 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -18,6 +18,11 @@ from constants import UUID_NIL if TYPE_CHECKING: from controllers.console.app.workflow import LoopNodeRunPayload +from graphon.graph_engine.layers import 
GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader + from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner @@ -43,10 +48,6 @@ from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.runtime import GraphRuntimeState -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index d21fce144eb..a884a1c7f9b 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -3,6 +3,12 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import Variable from sqlalchemy import select from sqlalchemy.orm import Session @@ -37,12 +43,6 @@ from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span -from graphon.enums import WorkflowType -from graphon.graph_engine.command_channels.redis_channel import RedisChannel -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader -from graphon.variables.variables import Variable from models import Workflow from models.model import App, Conversation, Message, MessageAnnotation from models.workflow import ConversationVariable diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 3577ae139bf..5203de225cc 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -9,6 +9,12 @@ from datetime import datetime from threading import Thread from typing import Any, Union +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session @@ -71,12 +77,6 @@ from core.repositories.human_input_repository import HumanInputFormRepositoryImp from 
core.workflow.file_reference import resolve_file_record_id from core.workflow.system_variables import build_system_variables from extensions.ext_database import db -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes import BuiltinNodeTypes -from graphon.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 1a44cc235ea..bb258af4c16 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, In from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser from services.conversation_service import ConversationService diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 09ddce327e5..a20d3f3c38f 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,6 +1,9 @@ import logging from typing import cast +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from core.agent.cot_chat_agent_runner import CotChatAgentRunner @@ -16,9 +19,6 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError from extensions.ext_database import db -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 5c9ba4567a5..66390116d46 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -3,10 +3,11 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping from typing import Any, Union +from graphon.model_runtime.errors.invoke import InvokeError + from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.errors.error 
import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index 8e8ccf2b903..7eccd59d17c 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -2,6 +2,9 @@ from collections.abc import Generator, Mapping, Sequence
 from contextlib import AbstractContextManager, nullcontext
 from typing import TYPE_CHECKING, Any, Union, final
 
+from graphon.enums import NodeType
+from graphon.file import File, FileUploadConfig
+from graphon.variables.input_entities import VariableEntityType
 from sqlalchemy.orm import Session
 
 from core.app.apps.draft_variable_saver import (
@@ -13,9 +16,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope
 from extensions.ext_database import db
 from factories import file_factory
-from graphon.enums import NodeType
-from graphon.file import File, FileUploadConfig
-from graphon.variables.input_entities import VariableEntityType
 from libs.orjson import orjson_dumps
 from models import Account, EndUser
 from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py
index d1771452c5f..20bf81aeecf 100644
--- a/api/core/app/apps/base_app_queue_manager.py
+++ b/api/core/app/apps/base_app_queue_manager.py
@@ -7,6 +7,7 @@ from enum import IntEnum, auto
 from typing import Any
 
 from cachetools import TTLCache, cachedmethod
+from graphon.runtime import GraphRuntimeState
 from redis.exceptions import RedisError
 from sqlalchemy.orm import DeclarativeMeta
 
@@ -21,7 +22,6 @@ from core.app.entities.queue_entities import (
     WorkflowQueueMessage,
 )
 from extensions.ext_redis import redis_client
-from graphon.runtime import GraphRuntimeState
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index 4a4c8b535d3..4aebc0cb30e 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -5,6 +5,17 @@ from collections.abc import Generator, Mapping, Sequence
 from mimetypes import guess_extension
 from typing import TYPE_CHECKING, Any, Union
 
+from graphon.file import FileTransferMethod, FileType
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from graphon.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    TextPromptMessageContent,
+)
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey
+from graphon.model_runtime.errors.invoke import InvokeBadRequestError
+
 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import (
@@ -30,21 +41,11 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
 from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
 from core.tools.tool_file_manager import ToolFileManager
 from extensions.ext_database import db
-from graphon.file.enums import FileTransferMethod, FileType
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from graphon.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    ImagePromptMessageContent,
-    PromptMessage,
-    TextPromptMessageContent,
-)
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey
-from graphon.model_runtime.errors.invoke import InvokeBadRequestError
 from models.enums import CreatorUserRole, MessageFileBelongsTo
 from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
 
 if TYPE_CHECKING:
-    from graphon.file.models import File
+    from graphon.file import File
 
 _logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py
index db3a98c7acd..b675a87382c 100644
--- a/api/core/app/apps/chat/app_generator.py
+++ b/api/core/app/apps/chat/app_generator.py
@@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping
 from typing import Any, Literal, Union, overload
 
 from flask import Flask, copy_current_request_context, current_app
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from pydantic import ValidationError
 
 from configs import dify_config
@@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeF
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from models import Account
 from models.model import App, EndUser
 from services.conversation_service import ConversationService
diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py
index 077c5239f39..050f763e958 100644
--- a/api/core/app/apps/chat/app_runner.py
+++ b/api/core/app/apps/chat/app_runner.py
@@ -1,6 +1,8 @@
 import logging
 from typing import cast
 
+from graphon.file import File
+from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 from sqlalchemy import select
 
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -16,8 +18,6 @@ from core.model_manager import ModelInstance
 from core.moderation.base import ModerationError
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from extensions.ext_database import db
-from graphon.file import File
-from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 from models.model import App, Conversation, Message
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py
index 2a90fbdad0e..ab277857fe5 100644
--- a/api/core/app/apps/common/graph_runtime_state_support.py
+++ b/api/core/app/apps/common/graph_runtime_state_support.py
@@ -4,9 +4,10 @@ from typing import TYPE_CHECKING
 
-from core.workflow.system_variables import SystemVariableKey, get_system_text
 from graphon.runtime import GraphRuntimeState
 
+from core.workflow.system_variables import SystemVariableKey, get_system_text
+
 if TYPE_CHECKING:
     from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py
index e4aa2ff6506..a5155316163 100644
--- a/api/core/app/apps/common/workflow_response_converter.py
+++ b/api/core/app/apps/common/workflow_response_converter.py
@@ -6,6 +6,19 @@ from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, NewType, TypedDict, Union
 
+from graphon.entities import WorkflowStartReason
+from graphon.entities.pause_reason import HumanInputRequired
+from graphon.enums import (
+    BuiltinNodeTypes,
+    WorkflowExecutionStatus,
+    WorkflowNodeExecutionMetadataKey,
+    WorkflowNodeExecutionStatus,
+)
+from graphon.file import FILE_MODEL_IDENTITY, File
+from graphon.runtime import GraphRuntimeState
+from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
+from graphon.variables.variables import Variable
+from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
@@ -55,19 +68,6 @@ from core.workflow.human_input_forms import load_form_tokens_by_form_id
 from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
-from graphon.entities.pause_reason import HumanInputRequired
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.enums import (
-    BuiltinNodeTypes,
-    WorkflowExecutionStatus,
-    WorkflowNodeExecutionMetadataKey,
-    WorkflowNodeExecutionStatus,
-)
-from graphon.file import FILE_MODEL_IDENTITY, File
-from graphon.runtime import GraphRuntimeState
-from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
-from graphon.variables.variables import Variable
-from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from libs.datetime_utils import naive_utc_now
 from models import Account, EndUser
 from models.human_input import HumanInputForm
diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
index c418fe97597..a62c5b80b51 100644
--- a/api/core/app/apps/completion/app_generator.py
+++ b/api/core/app/apps/completion/app_generator.py
@@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping
 from typing import Any, Literal, Union, overload
 
 from flask import Flask, copy_current_request_context, current_app
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from pydantic import ValidationError
 from sqlalchemy import select
 
@@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, I
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from models import Account, App, EndUser, Message
 from services.errors.app import MoreLikeThisDisabledError
 from services.errors.message import MessageNotExistsError
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index 6bb1ecdcb19..b216f7cf7b1 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -1,6 +1,8 @@
 import logging
 from typing import cast
 
+from graphon.file import File
+from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 from sqlalchemy import select
 
 from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -14,8 +16,6 @@ from core.model_manager import ModelInstance
 from core.moderation.base import ModerationError
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from extensions.ext_database import db
-from graphon.file import File
-from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 from models.model import App, Message
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index 48457b53265..fa242003a25 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -10,6 +10,8 @@ from collections.abc import Generator, Mapping
 from typing import Any, Literal, Union, cast, overload
 
 from flask import Flask, current_app
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
+from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from pydantic import ValidationError
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
@@ -41,8 +43,6 @@ from core.repositories.factory import (
     WorkflowNodeExecutionRepository,
 )
 from extensions.ext_database import db
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
-from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from libs.flask_utils import preserve_flask_contexts
 from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py
index 44d2450f743..4c188dac68d 100644
--- a/api/core/app/apps/pipeline/pipeline_runner.py
+++ b/api/core/app/apps/pipeline/pipeline_runner.py
@@ -2,6 +2,14 @@ import logging
 import time
 from typing import cast
 
+from graphon.entities import GraphInitParams
+from graphon.enums import WorkflowType
+from graphon.graph import Graph
+from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.variable_loader import VariableLoader
+from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
+
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@@ -18,13 +26,6 @@ from core.workflow.system_variables import build_bootstrap_variables, build_syst
 from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
-from graphon.entities.graph_init_params import GraphInitParams
-from graphon.enums import WorkflowType
-from graphon.graph import Graph
-from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.variable_loader import VariableLoader
-from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
 from models.dataset import Document, Pipeline
 from models.model import EndUser
 from models.workflow import Workflow
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 8ad6893a159..9618ab35c62 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -8,6 +8,10 @@ from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Literal, Union, overload
 
 from flask import Flask, current_app
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
+from graphon.runtime import GraphRuntimeState
+from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from pydantic import ValidationError
 from sqlalchemy import select
 from sqlalchemy.orm import sessionmaker
@@ -34,10 +38,6 @@ from core.repositories import DifyCoreRepositoryFactory
 from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from factories import file_factory
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
-from graphon.runtime import GraphRuntimeState
-from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from libs.flask_utils import preserve_flask_contexts
 from models.account import Account
 from models.enums import WorkflowRunTriggeredFrom
diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py
index c02c0b16e94..2cb8088971a 100644
--- a/api/core/app/apps/workflow/app_runner.py
+++ b/api/core/app/apps/workflow/app_runner.py
@@ -3,6 +3,12 @@ import time
 from collections.abc import Sequence
 from typing import cast
 
+from graphon.enums import WorkflowType
+from graphon.graph_engine.command_channels import RedisChannel
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.variable_loader import VariableLoader
+
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@@ -15,11 +21,6 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_redis import redis_client
 from extensions.otel import WorkflowAppRunnerHandler, trace_span
-from graphon.enums import WorkflowType
-from graphon.graph_engine.command_channels.redis_channel import RedisChannel
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.variable_loader import VariableLoader
 from libs.datetime_utils import naive_utc_now
 from models.workflow import Workflow
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
index e0c5b44ee48..49af169e88a 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -4,6 +4,9 @@ from collections.abc import Callable, Generator
 from contextlib import contextmanager
 from typing import Union
 
+from graphon.entities import WorkflowStartReason
+from graphon.enums import WorkflowExecutionStatus
+from graphon.runtime import GraphRuntimeState
 from sqlalchemy.orm import Session
 
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -58,9 +61,6 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.workflow.system_variables import build_system_variables
 from extensions.ext_database import db
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.enums import WorkflowExecutionStatus
-from graphon.runtime import GraphRuntimeState
 from models import Account
 from models.enums import CreatorUserRole
 from models.model import EndUser
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index d7d3bd27de4..f68c8e60b4f 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -3,6 +3,40 @@ import time
 from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
+from graphon.entities import GraphInitParams
+from graphon.entities.graph_config import NodeConfigDictAdapter
+from graphon.entities.pause_reason import HumanInputRequired
+from graphon.graph import Graph
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import (
+    GraphEngineEvent,
+    GraphRunAbortedEvent,
+    GraphRunFailedEvent,
+    GraphRunPartialSucceededEvent,
+    GraphRunPausedEvent,
+    GraphRunStartedEvent,
+    GraphRunSucceededEvent,
+    NodeRunAgentLogEvent,
+    NodeRunExceptionEvent,
+    NodeRunFailedEvent,
+    NodeRunHumanInputFormFilledEvent,
+    NodeRunHumanInputFormTimeoutEvent,
+    NodeRunIterationFailedEvent,
+    NodeRunIterationNextEvent,
+    NodeRunIterationStartedEvent,
+    NodeRunIterationSucceededEvent,
+    NodeRunLoopFailedEvent,
+    NodeRunLoopNextEvent,
+    NodeRunLoopStartedEvent,
+    NodeRunLoopSucceededEvent,
+    NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
+    NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+)
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
 from pydantic import ValidationError
 
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -44,40 +78,6 @@ from core.workflow.system_variables import (
 from core.workflow.variable_pool_initializer import add_variables_to_pool
 from core.workflow.workflow_entry import WorkflowEntry
 from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run
-from graphon.entities import GraphInitParams
-from graphon.entities.graph_config import NodeConfigDictAdapter
-from graphon.entities.pause_reason import HumanInputRequired
-from graphon.graph import Graph
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events import (
-    GraphEngineEvent,
-    GraphRunFailedEvent,
-    GraphRunPartialSucceededEvent,
-    GraphRunPausedEvent,
-    GraphRunStartedEvent,
-    GraphRunSucceededEvent,
-    NodeRunAgentLogEvent,
-    NodeRunExceptionEvent,
-    NodeRunFailedEvent,
-    NodeRunHumanInputFormFilledEvent,
-    NodeRunHumanInputFormTimeoutEvent,
-    NodeRunIterationFailedEvent,
-    NodeRunIterationNextEvent,
-    NodeRunIterationStartedEvent,
-    NodeRunIterationSucceededEvent,
-    NodeRunLoopFailedEvent,
-    NodeRunLoopNextEvent,
-    NodeRunLoopStartedEvent,
-    NodeRunLoopSucceededEvent,
-    NodeRunRetrieverResourceEvent,
-    NodeRunRetryEvent,
-    NodeRunStartedEvent,
-    NodeRunStreamChunkEvent,
-    NodeRunSucceededEvent,
-)
-from graphon.graph_events.graph import GraphRunAbortedEvent
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
 from models.workflow import Workflow
 from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py
index d8d851c505e..0cdbb5f50a1 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -2,13 +2,13 @@ from collections.abc import Mapping, Sequence
 from enum import StrEnum
 from typing import TYPE_CHECKING, Any, Optional
 
+from graphon.file import File, FileUploadConfig
+from graphon.model_runtime.entities.model_entities import AIModelEntity
 from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
 
 from constants import UUID_NIL
 from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
 from core.entities.provider_configuration import ProviderModelBundle
-from graphon.file import File, FileUploadConfig
-from graphon.model_runtime.entities.model_entities import AIModelEntity
 
 if TYPE_CHECKING:
     from core.ops.ops_trace_manager import TraceQueueManager
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index 63857bfff2b..5e56341f892 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -3,14 +3,14 @@ from datetime import datetime
 from enum import StrEnum, auto
 from typing import Any
 
+from graphon.entities import WorkflowStartReason
+from graphon.entities.pause_reason import PauseReason
+from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from pydantic import BaseModel, ConfigDict, Field
 
 from core.app.entities.agent_strategy import AgentStrategyInfo
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from graphon.entities.pause_reason import PauseReason
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 
 
 class QueueEvent(StrEnum):
diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py
index 719027bd239..ba3b2e356f7 100644
--- a/api/core/app/entities/task_entities.py
+++ b/api/core/app/entities/task_entities.py
@@ -2,14 +2,14 @@ from collections.abc import Mapping, Sequence
 from enum import StrEnum
 from typing import Any
 
+from graphon.entities import WorkflowStartReason
+from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from graphon.nodes.human_input.entities import FormInput, UserAction
 from pydantic import BaseModel, ConfigDict, Field
 
 from core.app.entities.agent_strategy import AgentStrategyInfo
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from graphon.nodes.human_input.entities import FormInput, UserAction
 
 
 class AnnotationReplyAccount(BaseModel):
diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py
index d59f5125e3c..d2d2fea4fb8 100644
--- a/api/core/app/features/hosting_moderation/hosting_moderation.py
+++ b/api/core/app/features/hosting_moderation/hosting_moderation.py
@@ -1,8 +1,9 @@
 import logging
 
+from graphon.model_runtime.entities.message_entities import PromptMessage
+
 from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
 from core.helper import moderation
-from graphon.model_runtime.entities.message_entities import PromptMessage
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py
index eeb9abbbfa4..e09869f5f8f 100644
--- a/api/core/app/layers/conversation_variable_persist_layer.py
+++ b/api/core/app/layers/conversation_variable_persist_layer.py
@@ -9,10 +9,11 @@ scope updates that matter to chat applications.
 
 import logging
 
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent
+
 from core.workflow.system_variables import SystemVariableKey, get_system_text
 from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent
 from services.conversation_variable_updater import ConversationVariableUpdater
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py
index 98e2257b1fe..79a54421306 100644
--- a/api/core/app/layers/pause_state_persist_layer.py
+++ b/api/core/app/layers/pause_state_persist_layer.py
@@ -1,15 +1,14 @@
 from dataclasses import dataclass
 from typing import Annotated, Literal, Self, TypeAlias
 
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
 from pydantic import BaseModel, Field
 from sqlalchemy import Engine
 from sqlalchemy.orm import Session, sessionmaker
 
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
 from core.workflow.system_variables import SystemVariableKey, get_system_text
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events.base import GraphEngineEvent
-from graphon.graph_events.graph import GraphRunPausedEvent
 from models.model import AppMode
 from repositories.api_workflow_run_repository import APIWorkflowRunRepository
 from repositories.factory import DifyAPIRepositoryFactory
diff --git a/api/core/app/layers/suspend_layer.py b/api/core/app/layers/suspend_layer.py
index 172306f271a..1a79a9f843e 100644
--- a/api/core/app/layers/suspend_layer.py
+++ b/api/core/app/layers/suspend_layer.py
@@ -1,6 +1,5 @@
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events.base import GraphEngineEvent
-from graphon.graph_events.graph import GraphRunPausedEvent
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
 
 
 class SuspendLayer(GraphEngineLayer):
diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py
index fef12df5040..8c8daf87122 100644
--- a/api/core/app/layers/timeslice_layer.py
+++ b/api/core/app/layers/timeslice_layer.py
@@ -3,10 +3,10 @@ import uuid
 from typing import ClassVar
 
 from apscheduler.schedulers.background import BackgroundScheduler  # type: ignore
-
 from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events.base import GraphEngineEvent
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphEngineEvent
+
 from services.workflow.entities import WorkflowScheduleCFSPlanEntity
 from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py
index 781a0aa3d3e..77c7bec67e6 100644
--- a/api/core/app/layers/trigger_post_layer.py
+++ b/api/core/app/layers/trigger_post_layer.py
@@ -2,13 +2,12 @@ import logging
 from datetime import UTC, datetime
 from typing import Any, ClassVar
 
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
 from pydantic import TypeAdapter
 
 from core.db.session_factory import session_factory
 from core.workflow.system_variables import SystemVariableKey, get_system_text
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events.base import GraphEngineEvent
-from graphon.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
 from models.enums import WorkflowTriggerStatus
 from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
 from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity
diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py
index c49c4eb0acf..278d0cb30b5 100644
--- a/api/core/app/llm/model_access.py
+++ b/api/core/app/llm/model_access.py
@@ -2,15 +2,16 @@ from __future__ import annotations
 
 from typing import Any
 
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.nodes.llm.entities import ModelConfig
+from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
+from graphon.nodes.llm.protocols import CredentialsProvider
+
 from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
 from core.errors.error import ProviderTokenNotInitError
 from core.model_manager import ModelInstance, ModelManager
 from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
 from core.provider_manager import ProviderManager
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.nodes.llm.entities import ModelConfig
-from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
-from graphon.nodes.llm.protocols import CredentialsProvider
 
 
 class DifyCredentialsProvider:
diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py
index 65a3f39d64d..63d22353588 100644
--- a/api/core/app/llm/quota.py
+++ b/api/core/app/llm/quota.py
@@ -1,3 +1,4 @@
+from graphon.model_runtime.entities.llm_entities import LLMUsage
 from sqlalchemy import update
 from sqlalchemy.orm import Session
 
@@ -7,7 +8,6 @@ from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
 from core.errors.error import QuotaExceededError
 from core.model_manager import ModelInstance
 from extensions.ext_database import db
-from graphon.model_runtime.entities.llm_entities import LLMUsage
 from libs.datetime_utils import naive_utc_now
 from models.provider import Provider, ProviderType
 from models.provider_ids import ModelProviderID
diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py
index 9e688589db7..10b9c36d3e2 100644
--- a/api/core/app/task_pipeline/based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py
@@ -1,6 +1,7 @@
 import logging
 import time
 
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
@@ -17,7 +18,6 @@ from core.app.entities.task_entities import (
 )
 from core.errors.error import QuotaExceededError
 from core.moderation.output_moderation import ModerationRule, OutputModeration
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from models.enums import MessageStatus
 from models.model import Message
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index cf9cb6d0513..a410fac5580 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -4,6 +4,13 @@ from collections.abc import Generator
 from threading import Thread
 from typing import Any, Union, cast
 
+from graphon.file import FileTransferMethod
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from graphon.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    TextPromptMessageContent,
+)
+from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
@@ -53,13 +60,6 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from events.message_event import message_was_created
 from extensions.ext_database import db
-from graphon.file.enums import FileTransferMethod
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from graphon.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    TextPromptMessageContent,
-)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from libs.datetime_utils import naive_utc_now
 from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py
index 45f622c469a..b23a33923b3 100644
--- a/api/core/app/task_pipeline/message_file_utils.py
+++ b/api/core/app/task_pipeline/message_file_utils.py
@@ -1,8 +1,9 @@
 from typing import TypedDict
 
-from core.tools.signature import sign_tool_file
+from graphon.file import FileTransferMethod
 from graphon.file import helpers as file_helpers
-from graphon.file.enums import FileTransferMethod
+
+from core.tools.signature import sign_tool_file
 from models.model import MessageFile, UploadFile
 
 MAX_TOOL_FILE_EXTENSION_LENGTH = 10
diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py
index aa5291bad59..8604235ef28 100644
--- a/api/core/app/workflow/file_runtime.py
+++ b/api/core/app/workflow/file_runtime.py
@@ -9,6 +9,10 @@ import urllib.parse
 from collections.abc import Generator
 from typing import TYPE_CHECKING, Literal
 
+from graphon.file import FileTransferMethod
+from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
+from graphon.file.runtime import set_workflow_file_runtime
+
 from configs import dify_config
 from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
 from core.db.session_factory import session_factory
@@ -16,12 +20,9 @@ from core.helper.ssrf_proxy import ssrf_proxy
 from core.tools.signature import sign_tool_file
 from core.workflow.file_reference import parse_file_reference
 from extensions.ext_storage import storage
-from graphon.file.enums import FileTransferMethod
-from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
-from graphon.file.runtime import set_workflow_file_runtime
 
 if TYPE_CHECKING:
-    from graphon.file.models import File
+    from graphon.file import File
 
 
 class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py
index 5666bf11911..48cabaf4d0f 100644
--- a/api/core/app/workflow/layers/llm_quota.py
+++ b/api/core/app/workflow/layers/llm_quota.py
@@ -7,18 +7,17 @@ This layer centralizes model-quota deduction outside node implementations.
 import logging
 from typing import TYPE_CHECKING, cast, final
 
+from graphon.enums import BuiltinNodeTypes
+from graphon.graph_engine.entities.commands import AbortCommand, CommandType
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
+from graphon.nodes.base.node import Node
 from typing_extensions import override
 
 from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
 from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
 from core.errors.error import QuotaExceededError
 from core.model_manager import ModelInstance
-from graphon.enums import BuiltinNodeTypes
-from graphon.graph_engine.entities.commands import AbortCommand, CommandType
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase
-from graphon.graph_events.node import NodeRunSucceededEvent
-from graphon.nodes.base.node import Node
 
 if TYPE_CHECKING:
     from graphon.nodes.llm.node import LLMNode
diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py
index 837bf7ff813..8565c3076cc 100644
--- a/api/core/app/workflow/layers/observability.py
+++ b/api/core/app/workflow/layers/observability.py
@@ -11,6 +11,10 @@ import logging
 from dataclasses import dataclass
 from typing import cast, final
 
+from graphon.enums import BuiltinNodeTypes, NodeType
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphNodeEventBase
+from graphon.nodes.base.node import Node
 from opentelemetry import context as context_api
 from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
 from typing_extensions import override
@@ -24,10 +28,6 @@ from extensions.otel.parser import (
     ToolNodeOTelParser,
 )
 from extensions.otel.runtime import is_instrument_flag_enabled
-from graphon.enums import BuiltinNodeTypes, NodeType
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events import GraphNodeEventBase
-from graphon.nodes.base.node import Node
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py
index e540733de25..ada065a9433 100644
--- a/api/core/app/workflow/layers/persistence.py
+++ b/api/core/app/workflow/layers/persistence.py
@@ -14,13 +14,6 @@ from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Union
 
-from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
-from core.ops.entities.trace_entity import TraceTaskName
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
-from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
-from core.workflow.system_variables import SystemVariableKey
-from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
-from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run
 from graphon.entities import WorkflowExecution, WorkflowNodeExecution
 from graphon.enums import (
     WorkflowExecutionStatus,
@@ -28,7 +21,7 @@ from graphon.enums import (
     WorkflowNodeExecutionStatus,
     WorkflowType,
 )
-from graphon.graph_engine.layers.base import GraphEngineLayer
+from graphon.graph_engine.layers import GraphEngineLayer
 from graphon.graph_events import (
     GraphEngineEvent,
     GraphRunAbortedEvent,
@@ -45,6 +38,14 @@ from graphon.graph_events import (
     NodeRunSucceededEvent,
 )
 from graphon.node_events import NodeRunResult
+
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
+from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
+from core.workflow.system_variables import SystemVariableKey
+from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
+from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run
 from libs.datetime_utils import naive_utc_now
diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py
index 9e3c1872108..3d8a7a54f31 100644
--- a/api/core/base/tts/app_generator_tts_publisher.py
+++ b/api/core/base/tts/app_generator_tts_publisher.py
@@ -6,6 +6,9 @@ import re
 import threading
 from collections.abc import Iterable
 
+from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
+from graphon.model_runtime.entities.model_entities import ModelType
+
 from core.app.entities.queue_entities import (
     MessageQueueMessage,
     QueueAgentMessageEvent,
@@ -15,8 +18,6 @@ from core.app.entities.queue_entities import (
     WorkflowQueueMessage,
 )
 from core.model_manager import ModelInstance, ModelManager
-from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
-from graphon.model_runtime.entities.model_entities import ModelType
 
 
 class AudioTrunk:
diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py
index 8a9875e4d7a..143d1e696bf 100644
--- a/api/core/datasource/datasource_manager.py
+++ b/api/core/datasource/datasource_manager.py
@@ -3,6 +3,9 @@ from collections.abc import Generator
 from threading import Lock
 from typing import Any, cast
 
+from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type
+from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
 from sqlalchemy import select
 
 import contexts
@@ -28,11 +31,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager
 from core.workflow.file_reference import build_file_reference
 from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam
 from factories import file_factory
-from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from graphon.enums import WorkflowNodeExecutionMetadataKey
-from graphon.file import File, get_file_type_by_mime_type
-from graphon.file.enums import FileTransferMethod, FileType
-from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
 from models.model import UploadFile
 from models.tools import ToolFile
 from services.datasource_provider_service import DatasourceProviderService
diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py
index 84dd6537723..14d1af2e8b4 100644
--- a/api/core/datasource/entities/api_entities.py
+++ b/api/core/datasource/entities/api_entities.py
@@ -1,10 +1,10 @@
 from typing import Literal, Optional
 
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, field_validator
 
 from core.datasource.entities.datasource_entities import DatasourceParameter
 from core.tools.entities.common_entities import I18nObject
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 
 
 class DatasourceApiEntity(BaseModel):
diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py
index 089b8b8e595..04f15dee31d 100644
--- a/api/core/datasource/utils/message_transformer.py
+++ b/api/core/datasource/utils/message_transformer.py
@@ -2,10 +2,11 @@ import logging
 from collections.abc import Generator
 from mimetypes import guess_extension, guess_type
 
+from graphon.file import File, FileTransferMethod, FileType
+
 from core.datasource.entities.datasource_entities import DatasourceMessage
 from core.tools.tool_file_manager import ToolFileManager
 from core.workflow.file_reference import parse_file_reference
-from graphon.file import File, FileTransferMethod, FileType
 from models.tools import ToolFile
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py
index 9d970d5db16..72f6590e683 100644
--- a/api/core/entities/execution_extra_content.py
+++ b/api/core/entities/execution_extra_content.py
@@ -3,9 +3,9 @@ from __future__ import annotations
 from collections.abc import Mapping, Sequence
 from typing import Any, TypeAlias
 
+from graphon.nodes.human_input.entities import FormInput, UserAction
 from pydantic import BaseModel, ConfigDict, Field
 
-from graphon.nodes.human_input.entities import FormInput, UserAction
 from models.execution_extra_content import ExecutionContentType
diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py
index bfa4f569155..a440829b46b 100644
--- a/api/core/entities/mcp_provider.py
+++ b/api/core/entities/mcp_provider.py
@@ -6,6 +6,7 @@ from enum import StrEnum
 from typing import TYPE_CHECKING, Any
 from urllib.parse import urlparse
 
+from graphon.file import helpers as file_helpers
 from pydantic import BaseModel
 
 from configs import dify_config
@@ -15,7 +16,6 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
 from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolProviderType
-from graphon.file import helpers as file_helpers
 
 if TYPE_CHECKING:
     from models.tools import MCPToolProvider
diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py
index e99a131500f..84d95c38c6c 100644
--- a/api/core/entities/model_entities.py
+++ b/api/core/entities/model_entities.py
@@ -1,11 +1,10 @@
 from collections.abc import Sequence
 from enum import StrEnum, auto
 
-from pydantic import BaseModel, ConfigDict
-
 from graphon.model_runtime.entities.common_entities import I18nObject
 from graphon.model_runtime.entities.model_entities import ModelType, ProviderModel
 from graphon.model_runtime.entities.provider_entities import ProviderEntity
+from pydantic import BaseModel, ConfigDict
 
 
 class ModelStatus(StrEnum):
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index d90afd3f7bf..8b48aa2660e 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -7,6 +7,16 @@ from collections import defaultdict
 from collections.abc import Iterator, Sequence
 from json import JSONDecodeError
 
+from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from graphon.model_runtime.entities.provider_entities import (
+    ConfigurateMethod,
+    CredentialFormSchema,
+    FormType,
+    ProviderEntity,
+)
+from graphon.model_runtime.model_providers.__base.ai_model import AIModel
+from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+from graphon.model_runtime.runtime import ModelRuntime
 from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
 from sqlalchemy import func, select
 from sqlalchemy.orm import Session
@@ -22,16 +32,6 @@ from core.entities.provider_entities import (
 from core.helper import encrypter
 from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
 from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
-from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
-from graphon.model_runtime.entities.provider_entities import (
-    ConfigurateMethod,
-    CredentialFormSchema,
-    FormType,
-    ProviderEntity,
-)
-from graphon.model_runtime.model_providers.__base.ai_model import AIModel
-from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
-from graphon.model_runtime.runtime import ModelRuntime
 from libs.datetime_utils import naive_utc_now
 from models.engine import db
 from models.enums import CredentialSourceType
diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py
index dffc7f2fc1f..2c8767a32b8 100644
--- a/api/core/entities/provider_entities.py
+++ b/api/core/entities/provider_entities.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 from enum import StrEnum, auto
 from typing import Union
 
+from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, ConfigDict, Field
 
 from core.entities.parameter_entities import (
@@ -12,7 +13,6 @@ from core.entities.parameter_entities import (
     ToolSelectorScope,
 )
 from core.tools.entities.common_entities import I18nObject
-from graphon.model_runtime.entities.model_entities import ModelType
 
 
 class ProviderQuotaType(StrEnum):
diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py
index 951e065b2cb..35bfcfb6a5c 100644
--- a/api/core/helper/code_executor/code_executor.py
+++ b/api/core/helper/code_executor/code_executor.py
@@ -4,6 +4,7 @@ from threading import Lock
 from typing import Any
 
 import httpx
+from graphon.nodes.code.entities import CodeLanguage
 from pydantic import BaseModel
 from yarl import URL
 
@@ -13,7 +14,6 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr
 from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
 from core.helper.code_executor.template_transformer import TemplateTransformer
 from core.helper.http_client_pooling import get_pooled_http_client
-from graphon.nodes.code.entities import CodeLanguage
 
 logger = logging.getLogger(__name__)
 
 code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))
diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py
index dc37a369432..a1e782a094e 100644
--- a/api/core/helper/moderation.py
+++ b/api/core/helper/moderation.py
@@ -2,13 +2,14 @@ import logging
 import secrets
 from typing import cast
 
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.invoke import InvokeBadRequestError
+from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
+
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities import DEFAULT_PLUGIN_ID
 from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
 from extensions.ext_hosting_provider import hosting_configuration
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.invoke import InvokeBadRequestError
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
 from models.provider import ProviderType
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py
index eb762c35086..60f5434bc1e 100644
--- a/api/core/hosting_configuration.py
+++ b/api/core/hosting_configuration.py
@@ -1,10 +1,10 @@
 from flask import Flask
+from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel
 
 from configs import dify_config
 from core.entities import DEFAULT_PLUGIN_ID
 from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel
-from graphon.model_runtime.entities.model_entities import ModelType
 
 
 class HostingQuota(BaseModel):
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 46bf1d69379..3ec17bc9864 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -9,6 +9,7 @@ from collections.abc import Mapping
 from typing import Any
 
 from flask import Flask, current_app
+from graphon.model_runtime.entities.model_entities import ModelType
 from sqlalchemy import select
 from sqlalchemy.orm.exc import ObjectDeletedError
 
@@ -34,7 +35,6 @@ from core.tools.utils.web_reader_tool import get_image_upload_file_ids
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
-from graphon.model_runtime.entities.model_entities import ModelType
 from libs import helper
 from libs.datetime_utils import naive_utc_now
 from models import Account
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 3712374305a..3d94f1a5969 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -5,6 +5,11 @@ from collections.abc import Sequence
 from typing import Protocol, cast
 
 import json_repair
+from graphon.enums import WorkflowNodeExecutionMetadataKey
+from graphon.model_runtime.entities.llm_entities import LLMResult
+from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 
 from core.app.app_config.entities import ModelConfig
 from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
@@ -29,11 +34,6 @@ from core.ops.utils import measure_time
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from extensions.ext_database import db
 from extensions.ext_storage import storage
-from graphon.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
-from graphon.model_runtime.entities.llm_entities import LLMResult
-from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from models import App, Message, WorkflowNodeExecutionModel
 from models.workflow import Workflow
diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py
index 81672ee7aac..a1710f11ace 100644
--- a/api/core/llm_generator/output_parser/structured_output.py
+++ b/api/core/llm_generator/output_parser/structured_output.py
@@ -5,11 +5,6 @@ from enum import StrEnum
 from typing import Any, Literal, cast, overload
 
 import json_repair
-from pydantic import TypeAdapter, ValidationError
-
-from core.llm_generator.output_parser.errors import OutputParserError
-from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
-from core.model_manager import ModelInstance
 from graphon.model_runtime.callbacks.base_callback import Callback
 from graphon.model_runtime.entities.llm_entities import (
     LLMResult,
@@ -26,6 +21,11 @@ from graphon.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
 )
 from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
+from pydantic import TypeAdapter, ValidationError
+
+from core.llm_generator.output_parser.errors import OutputParserError
+from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
+from core.model_manager import ModelInstance
 
 
 class ResponseFormat(StrEnum):
diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py
index 92d23c6dc95..27000c947c1 100644
--- a/api/core/mcp/server/streamable_http.py
+++ b/api/core/mcp/server/streamable_http.py
@@ -3,11 +3,12 @@ import logging
 from collections.abc import Mapping
 from typing import Any, cast
 
+from graphon.variables.input_entities import VariableEntity, VariableEntityType
+
 from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
 from core.mcp import types as mcp_types
-from graphon.variables.input_entities import VariableEntity, VariableEntityType
 from models.model import App, AppMCPServer, AppMode, EndUser
 from services.app_generate_service import AppGenerateService
diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py
index 7b5a7635f1c..7e350441768 100644
--- a/api/core/mcp/utils.py
+++ b/api/core/mcp/utils.py
@@ -4,11 +4,11 @@ from contextlib import AbstractContextManager
 
 import httpx
 import httpx_sse
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from httpx_sse import connect_sse
 
 from configs import dify_config
 from core.mcp.types import ErrorData, JSONRPCError
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 
 HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 658206128d1..09c84538a9a 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -1,14 +1,5 @@
 from collections.abc import Sequence
 
-from sqlalchemy import select
-from sqlalchemy.orm import sessionmaker
-
-from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.file_access import DatabaseFileAccessController
-from core.model_manager import ModelInstance
-from core.prompt.utils.extract_thread_messages import extract_thread_messages
-from extensions.ext_database import db
-from factories import file_factory
 from graphon.file import file_manager
 from graphon.model_runtime.entities import (
     AssistantPromptMessage,
@@ -19,6 +10,15 @@ from graphon.model_runtime.entities import (
     UserPromptMessage,
 )
 from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
+from sqlalchemy import select
+from sqlalchemy.orm import sessionmaker
+
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.app.file_access import DatabaseFileAccessController
+from core.model_manager import ModelInstance
+from core.prompt.utils.extract_thread_messages import extract_thread_messages
+from extensions.ext_database import db
+from factories import file_factory
 from models.model import AppMode, Conversation, Message, MessageFile
 from models.workflow import Workflow
 from repositories.api_workflow_run_repository import APIWorkflowRunRepository
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index f5ff375f651..87d1d7fba60 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -2,14 +2,6 @@ import logging
 from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
 from typing import IO, Any, Literal, Optional, Union, cast, overload
 
-from configs import dify_config
-from core.entities.embedding_type import EmbeddingInputType
-from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
-from core.entities.provider_entities import ModelLoadBalancingConfiguration
-from core.errors.error import ProviderTokenNotInitError
-from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
-from core.provider_manager import ProviderManager
-from extensions.ext_redis import redis_client
 from graphon.model_runtime.callbacks.base_callback import Callback
 from graphon.model_runtime.entities.llm_entities import LLMResult
 from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
@@ -23,6 +15,15 @@ from graphon.model_runtime.model_providers.__base.rerank_model import RerankMode
 from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
 from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
+
+from configs import dify_config
+from core.entities.embedding_type import EmbeddingInputType
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
+from core.entities.provider_entities import ModelLoadBalancingConfiguration
+from core.errors.error import ProviderTokenNotInitError
+from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
+from core.provider_manager import ProviderManager
+from extensions.ext_redis import redis_client
 from models.provider import ProviderType
 from services.enterprise.plugin_manager_service import PluginCredentialType
diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py
index 35d4469bc1a..dd038c77f13 100644
--- a/api/core/moderation/openai_moderation/openai_moderation.py
+++ b/api/core/moderation/openai_moderation/openai_moderation.py
@@ -1,6 +1,7 @@
+from graphon.model_runtime.entities.model_entities import ModelType
+
 from core.model_manager import ModelManager
 from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
-from graphon.model_runtime.entities.model_entities import ModelType
 
 
 class OpenAIModeration(Moderation):
diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py
index 76e81242f4c..70aaf2a07be 100644
--- a/api/core/ops/aliyun_trace/aliyun_trace.py
+++ b/api/core/ops/aliyun_trace/aliyun_trace.py
@@ -1,6 +1,8 @@
 import logging
 from collections.abc import Sequence
 
+from graphon.entities import WorkflowNodeExecution
+from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from opentelemetry.trace import SpanKind
 from sqlalchemy.orm import sessionmaker
 
@@ -58,8 +60,6 @@ from core.ops.entities.trace_entity import (
 )
 from core.repositories import DifyCoreRepositoryFactory
 from extensions.ext_database import db
-from graphon.entities import WorkflowNodeExecution
-from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from models import WorkflowNodeExecutionTriggeredFrom
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/ops/aliyun_trace/utils.py b/api/core/ops/aliyun_trace/utils.py
index 956fc60191f..d8e105d6a32 100644
--- a/api/core/ops/aliyun_trace/utils.py
+++ b/api/core/ops/aliyun_trace/utils.py
@@ -2,6 +2,8 @@ import json
 from collections.abc import Mapping
 from typing import Any
 
+from graphon.entities import WorkflowNodeExecution
+from graphon.enums import WorkflowNodeExecutionStatus
 from opentelemetry.trace import Link, Status, StatusCode
 
 from core.ops.aliyun_trace.entities.semconv import (
@@ -15,8 +17,6 @@ from core.ops.aliyun_trace.entities.semconv import (
 )
 from core.rag.models.document import Document
 from extensions.ext_database import db
-from graphon.entities import WorkflowNodeExecution
-from graphon.enums import WorkflowNodeExecutionStatus
 from models import EndUser
 
 # Constants
diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
index a1ea182f66d..39d97e28828 100644
--- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
+++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
@@ -6,6 +6,7 @@ from datetime import datetime, timedelta
 from typing import Any, Union, cast
 from urllib.parse import urlparse
 
+from graphon.enums import WorkflowNodeExecutionStatus
 from openinference.semconv.trace import (
     MessageAttributes,
     OpenInferenceMimeTypeValues,
@@ -39,7 +40,6 @@ from core.ops.entities.trace_entity import (
 )
 from core.repositories import DifyCoreRepositoryFactory
 from extensions.ext_database import db
-from graphon.enums import WorkflowNodeExecutionStatus
 from models.model import EndUser, MessageFile
 from models.workflow import WorkflowNodeExecutionTriggeredFrom
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py
index 3bf01eb81c6..3644b6b4c26 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/core/ops/langfuse_trace/langfuse_trace.py
@@ -2,6 +2,7 @@ import logging
 import os
 from datetime import datetime, timedelta
 
+from graphon.enums import BuiltinNodeTypes
 from langfuse import Langfuse
 from sqlalchemy.orm import sessionmaker
 
@@ -29,7 +30,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
 from core.ops.utils import filter_none_values
 from core.repositories import DifyCoreRepositoryFactory
 from extensions.ext_database import db
-from graphon.enums import BuiltinNodeTypes
 from models import EndUser, WorkflowNodeExecutionTriggeredFrom
 from models.enums import MessageStatus
diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py
index d960038f154..490c64af84d 100644
--- a/api/core/ops/langsmith_trace/langsmith_trace.py
+++ b/api/core/ops/langsmith_trace/langsmith_trace.py
@@ -4,6 +4,7 @@ import uuid
 from datetime import datetime, timedelta
 from typing import cast
 
+from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from langsmith import Client
 from langsmith.schemas import RunBase
 from sqlalchemy.orm import sessionmaker
 
@@ -29,7 +30,6 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
 from core.ops.utils import filter_none_values, generate_dotted_order
 from core.repositories import DifyCoreRepositoryFactory
 from extensions.ext_database import db
-from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py
index 8bf2e5dc138..946d3cdd479 100644
--- a/api/core/ops/mlflow_trace/mlflow_trace.py
+++ b/api/core/ops/mlflow_trace/mlflow_trace.py
@@ -5,6 +5,7 @@ from datetime import datetime, timedelta
 from typing import Any, cast
 
 import mlflow
+from graphon.enums import BuiltinNodeTypes
 from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType
 from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey
 from mlflow.tracing.fluent import start_span_no_context, update_current_trace
 
@@ -25,7 +26,6 @@ from core.ops.entities.trace_entity import (
     WorkflowTraceInfo,
 )
 from extensions.ext_database import db
-from graphon.enums import BuiltinNodeTypes
 from models import EndUser
 from models.workflow import WorkflowNodeExecutionModel
diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py
index b98cc3ce598..2215bdeb33b 100644
--- a/api/core/ops/opik_trace/opik_trace.py
+++ b/api/core/ops/opik_trace/opik_trace.py
@@ -5,6 +5,7 @@ import uuid
 from datetime import datetime, timedelta
 from typing import cast
 
+from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from opik import Opik, Trace
 from opik.id_helpers import uuid4_to_uuid7
 from sqlalchemy.orm import sessionmaker
 
@@ -24,7 +25,6 @@ from core.ops.entities.trace_entity import (
 )
 from core.repositories import DifyCoreRepositoryFactory
 from extensions.ext_database import db
-from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/core/ops/tencent_trace/span_builder.py
index 4f064581571..f79095d9662 100644
--- a/api/core/ops/tencent_trace/span_builder.py
+++ b/api/core/ops/tencent_trace/span_builder.py
@@ -6,6 +6,8 @@ import json
 import logging
 from datetime import datetime
 
+from graphon.entities import WorkflowNodeExecution
+from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from opentelemetry.trace import Status, StatusCode
 
 from core.ops.entities.trace_entity import (
@@ -41,11 +43,6 @@ from core.ops.tencent_trace.entities.semconv import (
 from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
 from core.ops.tencent_trace.utils import TencentTraceUtils
 from core.rag.models.document import Document
-from graphon.entities.workflow_node_execution import (
-    WorkflowNodeExecution,
-    WorkflowNodeExecutionMetadataKey,
-    WorkflowNodeExecutionStatus,
-)
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py
index 1b1b1025bc4..2bd6db22bf7 100644
--- a/api/core/ops/tencent_trace/tencent_trace.py
+++ b/api/core/ops/tencent_trace/tencent_trace.py
@@ -4,6 +4,10 @@ Tencent APM tracing implementation with separated concerns
 
 import logging
 
+from graphon.entities.workflow_node_execution import (
+    WorkflowNodeExecution,
+)
+from graphon.nodes import BuiltinNodeTypes
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
 
@@ -25,10 +29,6 @@ from core.ops.tencent_trace.span_builder import TencentSpanBuilder
 from core.ops.tencent_trace.utils import TencentTraceUtils
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from extensions.ext_database import db
-from graphon.entities.workflow_node_execution import (
-    WorkflowNodeExecution,
-)
-from graphon.nodes import BuiltinNodeTypes
 from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py
index f79544f1c7b..8d9ba4694d9 100644
--- a/api/core/ops/weave_trace/weave_trace.py
+++ b/api/core/ops/weave_trace/weave_trace.py
@@ -6,6 +6,7 @@ from typing import Any, cast
 
 import wandb
 import weave
+from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from sqlalchemy.orm import sessionmaker
 from weave.trace_server.trace_server_interface import (
     CallEndReq,
@@ -32,7 +33,6 @@ from core.ops.entities.trace_entity import (
 from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
 from core.repositories import DifyCoreRepositoryFactory
 from extensions.ext_database import db
-from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py
index 85625fc87d7..c715b9171c6 100644
--- a/api/core/plugin/backwards_invocation/model.py
+++ b/api/core/plugin/backwards_invocation/model.py
@@ -2,6 +2,20 @@ import tempfile
 from binascii import hexlify, unhexlify
 from collections.abc import Generator
 
+from graphon.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+    LLMResultChunkWithStructuredOutput,
+    LLMResultWithStructuredOutput,
+)
+from graphon.model_runtime.entities.message_entities import (
+    PromptMessage,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from graphon.model_runtime.entities.model_entities import ModelType
+
 from core.app.llm import deduct_llm_quota
 from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
 from core.model_manager import ModelManager
@@ -18,19 +32,6 @@ from 
core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from graphon.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelType from models.account import Tenant diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 248f8ef3e6b..94789974942 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,8 +1,5 @@ -from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from graphon.enums import BuiltinNodeTypes -from graphon.nodes.parameter_extractor.entities import ( - ModelConfig as ParameterExtractorModelConfig, -) +from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig from graphon.nodes.parameter_extractor.entities import ( ParameterConfig, ParameterExtractorNodeData, @@ -11,9 +8,8 @@ from graphon.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) -from graphon.nodes.question_classifier.entities import ( - ModelConfig as QuestionClassifierModelConfig, -) + +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from services.workflow_service import WorkflowService @@ -24,7 +20,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): tenant_id: str, user_id: str, parameters: list[ParameterConfig], - model_config: ParameterExtractorModelConfig, + model_config: LLMModelConfig, instruction: str, query: str, ): @@ -74,7 +70,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): cls, tenant_id: str, user_id: str, - model_config: QuestionClassifierModelConfig, + model_config: LLMModelConfig, classes: list[ClassConfig], instruction: str, query: str, diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index 1bd239a8310..2177e8af908 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -1,10 +1,10 @@ +from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, Field, computed_field, model_validator from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.plugin.entities.plugin import PluginResourceRequirements from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity -from graphon.model_runtime.entities.provider_entities import ProviderEntity class MarketplacePluginDeclaration(BaseModel): diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 6aefc414003..b095b4998d7 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -3,6 +3,7 @@ from collections.abc import Mapping from enum import StrEnum, auto from typing import Any +from graphon.model_runtime.entities.provider_entities import ProviderEntity from packaging.version import InvalidVersion, Version from pydantic import BaseModel, Field, field_validator, model_validator @@ -13,7 +14,6 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.tools.entities.common_entities import 
I18nObject
 from core.tools.entities.tool_entities import ToolProviderEntity
 from core.trigger.entities.entities import TriggerProviderEntity
-from graphon.model_runtime.entities.provider_entities import ProviderEntity
 
 
 class PluginInstallationSource(StrEnum):
diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py
index 864e4b8dd73..94263ec44e6 100644
--- a/api/core/plugin/entities/plugin_daemon.py
+++ b/api/core/plugin/entities/plugin_daemon.py
@@ -6,6 +6,8 @@ from datetime import datetime
 from enum import StrEnum
 from typing import Any, Generic, TypeVar
 
+from graphon.model_runtime.entities.model_entities import AIModelEntity
+from graphon.model_runtime.entities.provider_entities import ProviderEntity
 from pydantic import BaseModel, ConfigDict, Field
 
 from core.agent.plugin_entities import AgentProviderEntityWithPlugin
@@ -16,8 +18,6 @@ from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
 from core.trigger.entities.entities import TriggerProviderEntity
-from graphon.model_runtime.entities.model_entities import AIModelEntity
-from graphon.model_runtime.entities.provider_entities import ProviderEntity
 
 T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
 
diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py
index 704cacae2ab..059f3fa9be1 100644
--- a/api/core/plugin/entities/request.py
+++ b/api/core/plugin/entities/request.py
@@ -4,10 +4,6 @@ from collections.abc import Mapping
 from typing import Any, Literal
 
 from flask import Response
-from pydantic import BaseModel, ConfigDict, Field, field_validator
-
-from core.entities.provider_entities import BasicProviderConfig
-from core.plugin.utils.http_parser import deserialize_response
 from graphon.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -18,18 +14,17 @@ from graphon.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.nodes.parameter_extractor.entities import (
-    ModelConfig as ParameterExtractorModelConfig,
-)
+from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig
 from graphon.nodes.parameter_extractor.entities import (
     ParameterConfig,
 )
 from graphon.nodes.question_classifier.entities import (
     ClassConfig,
 )
-from graphon.nodes.question_classifier.entities import (
-    ModelConfig as QuestionClassifierModelConfig,
-)
+from pydantic import BaseModel, ConfigDict, Field, field_validator
+
+from core.entities.provider_entities import BasicProviderConfig
+from core.plugin.utils.http_parser import deserialize_response
 
 
 class InvokeCredentials(BaseModel):
@@ -176,7 +171,7 @@ class RequestInvokeParameterExtractorNode(BaseModel):
     """
 
     parameters: list[ParameterConfig]
-    model: ParameterExtractorModelConfig
+    model: LLMModelConfig
     instruction: str
     query: str
 
@@ -187,7 +182,7 @@ class RequestInvokeQuestionClassifierNode(BaseModel):
     """
 
     query: str
-    model: QuestionClassifierModelConfig
+    model: LLMModelConfig
     classes: list[ClassConfig]
     instruction: str
 
diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py
index 44047911da3..2d0ab3fcd73 100644
--- a/api/core/plugin/impl/base.py
+++ b/api/core/plugin/impl/base.py
@@ -5,6 +5,14 @@ from collections.abc import Callable, Generator
 from typing import Any, TypeVar, cast
 
 import httpx
+from graphon.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 from pydantic import BaseModel
 from yarl import URL
 
@@ -28,14 +36,6 @@ from core.trigger.errors import (
     TriggerPluginInvokeError,
     TriggerProviderCredentialValidationError,
 )
-from graphon.model_runtime.errors.invoke import (
-    InvokeAuthorizationError,
-    InvokeBadRequestError,
-    InvokeConnectionError,
-    InvokeRateLimitError,
-    InvokeServerUnavailableError,
-)
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 
 plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
 _plugin_daemon_timeout_config = cast(
diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py
index c91fa713744..1e38c24717f 100644
--- a/api/core/plugin/impl/model.py
+++ b/api/core/plugin/impl/model.py
@@ -2,6 +2,13 @@ import binascii
 from collections.abc import Generator, Sequence
 from typing import IO, Any
 
+from graphon.model_runtime.entities.llm_entities import LLMResultChunk
+from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from graphon.model_runtime.entities.model_entities import AIModelEntity
+from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
+from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
+from graphon.model_runtime.utils.encoders import jsonable_encoder
+
 from core.plugin.entities.plugin_daemon import (
     PluginBasicBooleanResponse,
     PluginDaemonInnerError,
@@ -13,12 +20,6 @@ from core.plugin.entities.plugin_daemon import (
     PluginVoicesResponse,
 )
 from core.plugin.impl.base import BasePluginClient
-from graphon.model_runtime.entities.llm_entities import LLMResultChunk
-from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from graphon.model_runtime.entities.model_entities import AIModelEntity
-from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
-from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 
 
 class PluginModelClient(BasePluginClient):
diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py
index e3fba4ef3a6..22c846b6de0 100644
--- a/api/core/plugin/impl/model_runtime.py
+++ b/api/core/plugin/impl/model_runtime.py
@@ -6,6 +6,13 @@ from collections.abc import Generator, Iterable, Sequence
 from threading import Lock
 from typing import IO, Any, Union
 
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
+from graphon.model_runtime.entities.provider_entities import ProviderEntity
+from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
+from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
+from graphon.model_runtime.runtime import ModelRuntime
 from pydantic import ValidationError
 from redis import RedisError
 
@@ -14,13 +21,6 @@ from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
 from core.plugin.impl.asset import PluginAssetManager
 from 
core.plugin.impl.model import PluginModelClient from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.model_runtime.entities.provider_entities import ProviderEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult -from graphon.model_runtime.runtime import ModelRuntime from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py index 35abd2ae8c1..4b29a6fc56b 100644 --- a/api/core/plugin/impl/model_runtime_factory.py +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -2,9 +2,10 @@ from __future__ import annotations from typing import TYPE_CHECKING -from core.plugin.impl.model import PluginModelClient from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model import PluginModelClient + if TYPE_CHECKING: from core.model_manager import ModelManager from core.plugin.impl.model_runtime import PluginModelRuntime diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 322f78ab4eb..90350f84000 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,7 +1,8 @@ from typing import Any +from graphon.file import File + from core.tools.entities.tool_entities import ToolSelector -from graphon.file.models import File def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index de87a096521..19b5e9223a8 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,15 +1,7 @@ from collections.abc import Mapping, Sequence from typing import cast -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from graphon.file import file_manager -from graphon.file.models import File +from graphon.file import File, file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, @@ -21,6 +13,14 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.runtime import VariablePool +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, 
CompletionModelPromptTemplate, MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser + class AdvancedPromptTransform(PromptTransform): """ diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 8f1d51f08a0..9be70199b7d 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -1,10 +1,5 @@ from typing import cast -from core.app.entities.app_invoke_entities import ( - ModelConfigWithCredentialsEntity, -) -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.prompt_transform import PromptTransform from graphon.model_runtime.entities.message_entities import ( PromptMessage, SystemPromptMessage, @@ -12,6 +7,12 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.app.entities.app_invoke_entities import ( + ModelConfigWithCredentialsEntity, +) +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.prompt_transform import PromptTransform + class AgentHistoryPromptTransform(PromptTransform): """ diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 6ff2f44cdc8..4539ae9f11b 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,11 +1,12 @@ from typing import Any +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from graphon.model_runtime.entities.message_entities import PromptMessage -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey class PromptTransform: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index e091215b803..c706353ffeb 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -4,12 +4,6 @@ from collections.abc import Mapping, Sequence from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, cast -from core.app.app_config.entities import PromptTemplateEntity -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import file_manager from graphon.model_runtime.entities.message_entities import ( ImagePromptMessageContent, @@ -19,10 +13,17 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from core.app.app_config.entities import PromptTemplateEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.prompt.prompt_transform import PromptTransform 
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode if TYPE_CHECKING: - from graphon.file.models import File + from graphon.file import File class ModelMode(StrEnum): diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index ba76eb0c4e0..dbda7499255 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,7 +1,6 @@ from collections.abc import Sequence from typing import Any, cast -from core.prompt.simple_prompt_transform import ModelMode from graphon.model_runtime.entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -12,6 +11,8 @@ from graphon.model_runtime.entities import ( TextPromptMessageContent, ) +from core.prompt.simple_prompt_transform import ModelMode + class PromptMessageUtil: @staticmethod diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 79fd78fe807..30933239f65 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -7,6 +7,14 @@ from collections.abc import Sequence from json import JSONDecodeError from typing import TYPE_CHECKING, Any, cast +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -33,14 +41,6 @@ from core.helper.position_helper import is_filtered from extensions import ext_hosting_provider from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ( LoadBalancingModelConfig, Provider, diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 2c816535597..b872ea8a8fb 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,3 +1,5 @@ +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from typing_extensions import TypedDict from core.model_manager import ModelInstance, ModelManager @@ -8,8 +10,6 @@ from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_type import RerankMode -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class RerankingModelDict(TypedDict): diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 1e4aa242874..cc6ec12c750 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, NotRequired from flask import 
Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from sqlalchemy.orm import Session, load_only from typing_extensions import TypedDict @@ -24,7 +25,6 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ( ChildChunk, Dataset, diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index a77458706aa..5a8d3a2f3ff 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -4,6 +4,7 @@ import time from abc import ABC, abstractmethod from typing import Any +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from configs import dify_config @@ -17,7 +18,6 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Whitelist from models.model import UploadFile diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 369159767e7..e5b794f80d6 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -3,13 +3,13 @@ from __future__ import annotations from collections.abc import Sequence from typing import Any +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import func, select from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index b12a0ae2d69..3bdad007121 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -4,6 +4,8 @@ import pickle from typing import Any, cast import numpy as np +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from sqlalchemy.exc import IntegrityError from configs import dify_config @@ -12,8 +14,6 @@ from core.model_manager import ModelInstance from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from libs import helper from models.dataset import Embedding diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 9f36b7a2255..5c10ffbf2dd 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -8,6 +8,17 @@ from 
typing import Any, cast logger = logging.getLogger(__name__) +from graphon.file import File, FileTransferMethod, FileType, file_manager +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType + from core.app.file_access import DatabaseFileAccessController from core.app.llm import deduct_llm_quota from core.entities.knowledge_entities import PreviewDetail @@ -31,16 +42,6 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols from core.workflow.file_reference import build_file_reference from extensions.ext_database import db from factories.file_factory import build_from_mapping -from graphon.file import File, FileTransferMethod, FileType, file_manager -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentUnionTypes, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from libs import helper from models import UploadFile from models.account import Account diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 4ebf0959042..087736d0b0a 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,9 +2,8 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any -from pydantic import BaseModel, Field - from graphon.file import File +from pydantic import BaseModel, Field class ChildDocument(BaseModel): diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 6c6b077cc2d..211a9f5c5cd 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,5 +1,8 @@ import base64 +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.rerank_entities import RerankResult + from core.model_manager import ModelInstance, ModelManager from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.query_type import QueryType @@ -7,8 +10,6 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner from extensions.ext_database import db from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.rerank_entities import RerankResult from models.model import UploadFile diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index d0732b269af..49123e13d05 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -2,6 +2,7 @@ import math from collections import Counter import numpy as np +from graphon.model_runtime.entities.model_entities import ModelType from core.model_manager import ModelManager from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler @@ -11,7 +12,6 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base 
import BaseRerankRunner -from graphon.model_runtime.entities.model_entities import ModelType class WeightRerankRunner(BaseRerankRunner): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 49b91707ec0..1abea6639e8 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -9,6 +9,11 @@ from collections.abc import Generator, Mapping from typing import Any, Union, cast from flask import Flask, current_app +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import and_, func, literal, or_, select from sqlalchemy.orm import Session @@ -66,11 +71,6 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import ( ) from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.file import File, FileTransferMethod, FileType -from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from libs.helper import parse_uuid_str_or_none from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index e617a9660eb..dce7b6226ce 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -1,9 +1,10 @@ from typing import Union +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage class FunctionCallMultiDatasetRouter: diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 83e58fe0f93..dd280cdf6a7 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -1,6 +1,10 @@ from collections.abc import Generator, Sequence from typing import Union +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.llm import deduct_llm_quota from core.model_manager import 
ModelInstance, ModelManager @@ -8,9 +12,6 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from graphon.model_runtime.entities.model_entities import ModelType PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 2c27ac3cf64..e6aec4a3af9 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -6,6 +6,8 @@ import codecs import re from typing import Any +from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer + from core.model_manager import ModelInstance from core.rag.splitter.text_splitter import ( TS, @@ -15,7 +17,6 @@ from core.rag.splitter.text_splitter import ( Set, Union, ) -from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index d0164b76dc1..465f43da739 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -8,11 +8,11 @@ providing improved performance by offloading database operations to background w import logging from typing import Union +from graphon.entities import WorkflowExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository -from graphon.entities.workflow_execution import WorkflowExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 52361cf6dcd..22ef44b3dc4 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -9,6 +9,7 @@ import logging from collections.abc import Sequence from typing import Union +from graphon.entities import WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -16,7 +17,6 @@ from core.repositories.factory import ( OrderConfig, WorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import WorkflowNodeExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index dafdbf641a7..ed6d44f4340 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -9,11 +9,11 @@ from collections.abc import Sequence from dataclasses import dataclass 
from typing import Literal, Protocol, Union +from graphon.entities import WorkflowExecution, WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config -from graphon.entities import WorkflowExecution, WorkflowNodeExecution from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 02625e242f9..72d93941498 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -4,6 +4,8 @@ from collections.abc import Mapping, Sequence from datetime import datetime from typing import Any, Protocol +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import select from sqlalchemy.orm import Session, selectinload @@ -17,8 +19,6 @@ from core.workflow.human_input_compat import ( InteractiveSurfaceDeliveryMethod, is_human_input_webapp_enabled, ) -from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.account import Account, TenantAccountJoin diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 1ee5d4ae77b..85d20b675d2 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -6,13 +6,13 @@ import json import logging from typing import Union +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 749ab44a14f..a72bfa378bc 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -10,6 +10,10 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, TypeVar, Union import psycopg2.errors +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import UnaryExpression, asc, desc, select from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError @@ -19,10 +23,6 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att from configs import dify_config from core.repositories.factory import 
OrderConfig, WorkflowNodeExecutionRepository
 from extensions.ext_storage import storage
-from graphon.entities import WorkflowNodeExecution
-from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-from graphon.model_runtime.utils.encoders import jsonable_encoder
-from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from libs.helper import extract_tenant_id
 from libs.uuid_utils import uuidv7
 from models import (
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py
index 40bf2e98c2b..e5390743036 100644
--- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py
+++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py
@@ -2,14 +2,15 @@ import io
 from collections.abc import Generator
 from typing import Any
 
+from graphon.file import FileType
+from graphon.file.file_manager import download
+from graphon.model_runtime.entities.model_entities import ModelType
+
 from core.model_manager import ModelManager
 from core.plugin.entities.parameters import PluginParameterOption
 from core.tools.builtin_tool.tool import BuiltinTool
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
-from graphon.file.enums import FileType
-from graphon.file.file_manager import download
-from graphon.model_runtime.entities.model_entities import ModelType
 from services.model_provider_service import ModelProviderService
 
 
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py
index ac3820f1aba..f49c669fe09 100644
--- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py
+++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py
@@ -2,12 +2,13 @@ import io
 from collections.abc import Generator
 from typing import Any
 
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
+
 from core.model_manager import ModelManager
 from core.plugin.entities.parameters import PluginParameterOption
 from core.tools.builtin_tool.tool import BuiltinTool
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from services.model_provider_service import ModelProviderService
 
 
diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py
index d41503e1e61..14af63a962e 100644
--- a/api/core/tools/builtin_tool/tool.py
+++ b/api/core/tools/builtin_tool/tool.py
@@ -1,11 +1,12 @@
 from __future__ import annotations
 
+from graphon.model_runtime.entities.llm_entities import LLMResult
+from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
+
 from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.utils.model_invocation_utils import ModelInvocationUtils
-from graphon.model_runtime.entities.llm_entities import LLMResult
-from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
 
 _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
 and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py
index 168e5f4493a..0a2c37c5632 100644
--- a/api/core/tools/custom_tool/tool.py
+++ b/api/core/tools/custom_tool/tool.py
@@ -6,6 +6,7 @@ from typing import Any, Union
 from urllib.parse import urlencode
 
 import httpx
+from graphon.file.file_manager import download
 
 from core.helper import ssrf_proxy
 from core.tools.__base.tool import Tool
@@ -13,7 +14,6 @@ from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
 from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
-from graphon.file.file_manager import download
 
 API_TOOL_DEFAULT_TIMEOUT = (
     int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py
index 08640befb49..d5d3d1b1d95 100644
--- a/api/core/tools/entities/api_entities.py
+++ b/api/core/tools/entities/api_entities.py
@@ -2,6 +2,7 @@ from collections.abc import Mapping
 from datetime import datetime
 from typing import Any, Literal
 
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, field_validator
 
 from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
@@ -9,7 +10,6 @@ from core.plugin.entities.plugin_daemon import CredentialType
 from core.tools.__base.tool import ToolParameter
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolProviderType
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 
 
 class ToolApiEntity(BaseModel):
diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py
index 00fc8a82827..f6d09472b3c 100644
--- a/api/core/tools/mcp_tool/tool.py
+++ b/api/core/tools/mcp_tool/tool.py
@@ -6,6 +6,8 @@ import logging
 from collections.abc import Generator, Mapping
 from typing import Any, cast
 
+from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
+
 from core.mcp.auth_client import MCPClientWithAuthRetry
 from core.mcp.error import MCPConnectionError
 from core.mcp.types import (
@@ -21,7 +23,6 @@ from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
 from core.tools.errors import ToolInvokeError
-from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
 
 logger = logging.getLogger(__name__)
 
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index 1fd259f3bb3..685d687d8c4 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -7,6 +7,7 @@ from datetime import UTC, datetime
 from mimetypes import guess_type
 from typing import Any, Union, cast
 
+from graphon.file import FileTransferMethod, FileType
 from yarl import URL
 
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -32,8 +33,6 @@ from core.tools.errors import (
 from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value
 from core.tools.workflow_as_tool.tool import WorkflowTool
 from extensions.ext_database import db
-from graphon.file import FileType
-from graphon.file.models import FileTransferMethod
 from models.enums import CreatorUserRole, MessageFileBelongsTo
 from models.model import Message, MessageFile
 
diff --git a/api/core/tools/tool_file_manager.py
b/api/core/tools/tool_file_manager.py index 2ec292602ca..7ac29cf0698 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -10,13 +10,13 @@ from typing import Union from uuid import uuid4 import httpx +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from configs import dify_config from core.db.session_factory import session_factory from core.helper import ssrf_proxy from core.workflow.file_reference import build_file_reference from extensions.ext_storage import storage -from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from models.model import MessageFile from models.tools import ToolFile diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 4a10c7e23e5..584bae39b9d 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -8,6 +8,7 @@ from threading import Lock from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast import sqlalchemy as sa +from graphon.runtime import VariablePool from sqlalchemy import select from sqlalchemy.orm import Session from yarl import URL @@ -25,7 +26,6 @@ from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from extensions.ext_database import db -from graphon.runtime.variable_pool import VariablePool from models.provider_ids import ToolProviderID from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService @@ -33,6 +33,8 @@ from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: pass +from graphon.model_runtime.utils.encoders import jsonable_encoder + from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source @@ -56,7 +58,6 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index dad5133a7a6..6a77fda7ef3 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,6 +1,7 @@ import threading from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import select @@ -15,7 +16,6 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Document, 
DocumentSegment default_retrieval_model: DefaultRetrievalModelDict = { diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 5cf46b25640..bb5b3ba76e9 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -8,11 +8,11 @@ from uuid import UUID import numpy as np import pytz +from graphon.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference -from graphon.file import File, FileTransferMethod, FileType from libs.login import current_user from models import Account diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 9e1d41cb39c..8d6f83dc07c 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -8,9 +8,6 @@ import json from decimal import Decimal from typing import cast -from core.model_manager import ModelManager -from core.tools.entities.tool_entities import ToolProviderType -from extensions.ext_database import db from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType @@ -23,6 +20,10 @@ from graphon.model_runtime.errors.invoke import ( ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from graphon.model_runtime.utils.encoders import jsonable_encoder + +from core.model_manager import ModelManager +from core.tools.entities.tool_entities import ToolProviderType +from extensions.ext_database import db from models.tools import ToolModelInvoke diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 1e4f3ed2a7f..c4b7d574493 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -1,12 +1,13 @@ from collections.abc import Mapping, Sequence from typing import Any -from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.entities import OutputVariableEntity from graphon.variables.input_entities import VariableEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration +from core.tools.errors import WorkflowToolHumanInputNotSupportedError + class WorkflowToolConfigurationUtils: @classmethod diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 716368c1917..f48b24be30e 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections.abc import Mapping +from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import Field from sqlalchemy.orm import Session @@ -23,7 +24,6 @@ from core.tools.entities.tool_entities import ( from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from graphon.variables.input_entities import VariableEntity, 
VariableEntityType
 from models.account import Account
 from models.model import App, AppMode
 from models.tools import WorkflowToolProvider
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 495fcd48b33..a3fb4eda928 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -5,6 +5,8 @@ import logging
 from collections.abc import Generator, Mapping, Sequence
 from typing import Any, cast
 
+from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
+from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
 from sqlalchemy import select
 
 from core.app.file_access import DatabaseFileAccessController
@@ -20,8 +22,6 @@ from core.tools.entities.tool_entities import (
 from core.tools.errors import ToolInvokeError
 from core.workflow.file_reference import resolve_file_record_id
 from factories.file_factory import build_from_mapping
-from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
-from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
 from models import Account, Tenant
 from models.model import App, EndUser
 from models.utils.file_input_compat import build_file_from_stored_mapping
diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py
index 24c1271488c..61d1cd85402 100644
--- a/api/core/trigger/debug/event_selectors.py
+++ b/api/core/trigger/debug/event_selectors.py
@@ -8,6 +8,7 @@ from collections.abc import Mapping
 from datetime import datetime
 from typing import Any
 
+from graphon.entities.graph_config import NodeConfigDict
 from pydantic import BaseModel
 
 from core.plugin.entities.request import TriggerInvokeEventResponse
@@ -27,7 +28,6 @@ from core.trigger.debug.events import (
 from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
 from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig
 from extensions.ext_redis import redis_client
-from graphon.entities.graph_config import NodeConfigDict
 from libs.datetime_utils import ensure_naive_utc, naive_utc_now
 from libs.schedule_utils import calculate_next_run_at
 from models.model import App
diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_compat.py
index 75a0a0c2029..c95516a240b 100644
--- a/api/core/workflow/human_input_compat.py
+++ b/api/core/workflow/human_input_compat.py
@@ -14,13 +14,12 @@ from typing import Annotated, Any, ClassVar, Literal
 
 import bleach
 import markdown
-from markdown.extensions.tables import TableExtension
-from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
-
 from graphon.enums import BuiltinNodeTypes
 from graphon.nodes.base.variable_template_parser import VariableTemplateParser
 from graphon.runtime import VariablePool
 from graphon.variables.consts import SELECTORS_LENGTH
+from markdown.extensions.tables import TableExtension
+from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
 
 
 class DeliveryMethodType(enum.StrEnum):
diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py
index 028e38fbee8..8cc21d2cd96 100644
--- a/api/core/workflow/node_factory.py
+++ b/api/core/workflow/node_factory.py
@@ -4,6 +4,22 @@ from collections.abc import Callable, Iterator, Mapping, MutableMapping
 from functools import lru_cache
 from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
 
+from graphon.entities.base_node_data import BaseNodeData
+from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
+from graphon.enums import BuiltinNodeTypes, NodeType
+from graphon.file.file_manager import file_manager
+from graphon.graph.graph import NodeFactory
+from graphon.model_runtime.memory import PromptMessageMemory
+from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.nodes.base.node import Node
+from graphon.nodes.code.code_node import WorkflowCodeExecutor
+from graphon.nodes.code.entities import CodeLanguage
+from graphon.nodes.code.limits import CodeNodeLimits
+from graphon.nodes.document_extractor import UnstructuredApiConfig
+from graphon.nodes.http_request import build_http_request_config
+from graphon.nodes.llm.entities import LLMNodeData
+from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
+from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 from typing_extensions import override
@@ -40,22 +56,6 @@ from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
 from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector
 from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer
 from extensions.ext_database import db
-from graphon.entities.base_node_data import BaseNodeData
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
-from graphon.enums import BuiltinNodeTypes, NodeType
-from graphon.file.file_manager import file_manager
-from graphon.graph.graph import NodeFactory
-from graphon.model_runtime.memory import PromptMessageMemory
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from graphon.nodes.base.node import Node
-from graphon.nodes.code.code_node import WorkflowCodeExecutor
-from graphon.nodes.code.entities import CodeLanguage
-from graphon.nodes.code.limits import CodeNodeLimits
-from graphon.nodes.document_extractor import UnstructuredApiConfig
-from graphon.nodes.http_request import build_http_request_config
-from graphon.nodes.llm.entities import LLMNodeData
-from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
-from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
 from models.model import Conversation
 
 if TYPE_CHECKING:
diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py
index 2e632e56f02..19cb3a7b0ab 100644
--- a/api/core/workflow/node_runtime.py
+++ b/api/core/workflow/node_runtime.py
@@ -4,32 +4,6 @@ from collections.abc import Callable, Generator, Mapping, Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, cast
 
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
-from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
-from core.app.file_access import DatabaseFileAccessController
-from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
-from core.llm_generator.output_parser.errors import OutputParserError
-from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
-from core.model_manager import ModelInstance
-from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
-from core.plugin.impl.plugin import PluginInstaller
-from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.repositories.human_input_repository import 
( - FormCreateParams, - HumanInputFormRepository, - HumanInputFormRepositoryImpl, -) -from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType -from core.tools.errors import ToolInvokeError -from core.tools.tool_engine import ToolEngine -from core.tools.tool_file_manager import ToolFileManager -from core.tools.tool_manager import ToolManager -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.file_reference import build_file_reference -from extensions.ext_database import db -from factories import file_factory from graphon.file import FileTransferMethod, FileType from graphon.model_runtime.entities import LLMMode from graphon.model_runtime.entities.llm_entities import ( @@ -60,6 +34,32 @@ from graphon.nodes.tool_runtime_entities import ( ToolRuntimeMessage, ToolRuntimeParameter, ) +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.file_access import DatabaseFileAccessController +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output +from core.model_manager import ModelInstance +from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError +from core.plugin.impl.plugin import PluginInstaller +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormRepository, + HumanInputFormRepositoryImpl, +) +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.tool_file_manager import ToolFileManager +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.file_reference import build_file_reference +from extensions.ext_database import db +from factories import file_factory from models.dataset import SegmentAttachmentBinding from models.model import UploadFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService @@ -76,12 +76,13 @@ from .human_input_compat import ( from .system_variables import SystemVariableKey, get_system_text if TYPE_CHECKING: - from core.tools.__base.tool import Tool - from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage from graphon.file import File from graphon.nodes.llm.file_saver import LLMFileSaver from graphon.nodes.tool.entities import ToolNodeData + from core.tools.__base.tool import Tool + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage + _file_access_controller = DatabaseFileAccessController() diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 7b000101b0f..bfd5536e4a7 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -3,14 +3,15 @@ from __future__ import annotations from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.workflow.system_variables import SystemVariableKey, 
get_system_text from graphon.entities.graph_config import NodeConfigDict from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.workflow.system_variables import SystemVariableKey, get_system_text + from .entities import AgentNodeData from .exceptions import ( AgentInvocationError, diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 51452c29a3f..c52aad150bb 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,12 +1,12 @@ from enum import IntEnum, StrEnum, auto from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py index f44681377dc..db74590ed76 100644 --- a/api/core/workflow/nodes/agent/message_transformer.py +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -3,14 +3,6 @@ from __future__ import annotations from collections.abc import Generator, Mapping from typing import Any, cast -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.app.file_access import DatabaseFileAccessController -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from extensions.ext_database import db -from factories import file_factory from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata @@ -23,6 +15,14 @@ from graphon.node_events import ( StreamCompletedEvent, ) from graphon.variables.segments import ArrayFileSegment +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.file_access import DatabaseFileAccessController +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from extensions.ext_database import db +from factories import file_factory from models import ToolFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index a872774c98c..be50edbc4d4 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -4,6 +4,8 @@ import json from collections.abc import Sequence from typing import Any, cast +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.runtime import VariablePool from packaging.version import Version from pydantic import ValidationError from sqlalchemy import 
select @@ -19,8 +21,6 @@ from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolP from core.tools.tool_manager import ToolManager from core.workflow.system_variables import SystemVariableKey, get_system_text from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.runtime import VariablePool from models.model import Conversation from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 38f39b3f940..d9247b25932 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,18 +1,23 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( + BuiltinNodeTypes, + NodeExecutionType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.node_events import NodeRunResult, StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser + from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError from core.workflow.file_reference import resolve_file_record_id from core.workflow.system_variables import SystemVariableKey, get_system_segment -from graphon.entities.graph_config import NodeConfigDict -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionMetadataKey -from graphon.node_events import NodeRunResult, StreamCompletedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.base.variable_template_parser import VariableTemplateParser from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 28966f2392c..cad32f8d5bd 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,10 +1,9 @@ from typing import Any, Literal, Union -from pydantic import BaseModel, field_validator -from pydantic_core.core_schema import ValidationInfo - from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType +from pydantic import BaseModel, field_validator +from pydantic_core.core_schema import ValidationInfo class DatasourceEntity(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 11339bb1226..cba6c12dca0 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,12 +1,12 @@ from typing import Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import 
RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index b465a2d8ffc..bb72fe38816 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,17 +2,17 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template + from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text -from graphon.entities.graph_config import NodeConfigDict -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import NodeExecutionType -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template from .entities import KnowledgeIndexNodeData from .exc import ( diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 3f7cc364d30..b1fa8593efe 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,11 +1,10 @@ from collections.abc import Sequence from typing import Literal -from pydantic import BaseModel, Field - from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.llm.entities import ModelConfig, VisionConfig +from pydantic import BaseModel, Field class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 117f426adec..13624b27b37 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,11 +8,6 @@ import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal -from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.file_reference import parse_file_reference from graphon.entities import GraphInitParams from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( @@ -32,6 +27,12 @@ from graphon.variables import ( ) from graphon.variables.segments import ArrayObjectSegment +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from 
core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.file_reference import parse_file_reference + from .entities import ( Condition, KnowledgeRetrievalNodeData, @@ -44,7 +45,7 @@ from .exc import ( from .retrieval import KnowledgeRetrievalRequest, Source if TYPE_CHECKING: - from graphon.file.models import File + from graphon.file import File from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py index ea45dcf5c20..39e2008a2ca 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -1,10 +1,10 @@ from typing import Any, Literal, Protocol +from graphon.model_runtime.entities import LLMUsage +from graphon.nodes.llm.entities import ModelConfig from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from graphon.model_runtime.entities import LLMUsage -from graphon.nodes.llm.entities import ModelConfig from .entities import MetadataFilteringCondition diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py index 23ed2cd4088..bf5be2379af 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -1,12 +1,12 @@ from collections.abc import Mapping from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from .exc import TriggerEventParameterError diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index a2c952a8991..e50de11bb90 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,13 +1,13 @@ from collections.abc import Mapping from typing import Any -from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID + from .entities import TriggerEventNodeData diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py index 207c1e7253e..f14ca893c9e 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -1,10 +1,10 @@ from typing import Literal, Union +from 
graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel, Field from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index dd80617dfcc..a9753ab387d 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,11 +1,11 @@ from collections.abc import Mapping +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node + from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import NodeExecutionType -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node from .entities import TriggerScheduleNodeData diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py index 3125fe17e61..4d5ad72154b 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -1,12 +1,12 @@ from collections.abc import Sequence from enum import StrEnum -from pydantic import BaseModel, Field, field_validator - -from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from graphon.entities.base_node_data import BaseNodeData from graphon.enums import NodeType from graphon.variables.types import SegmentType +from pydantic import BaseModel, Field, field_validator + +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE _WEBHOOK_HEADER_ALLOWED_TYPES = frozenset( { diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 6858d6dc359..ebaac939345 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -2,12 +2,7 @@ import logging from collections.abc import Mapping from typing import Any -from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from factories.variable_factory import build_segment_with_type -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import NodeExecutionType +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus from graphon.file import FileTransferMethod from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node @@ -15,6 +10,11 @@ from graphon.nodes.protocols import FileReferenceFactoryProtocol from graphon.variables.types import SegmentType from graphon.variables.variables import FileVariable +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment_with_type + from .entities import ContentType, WebhookData logger = logging.getLogger(__name__) diff --git 
a/api/core/workflow/template_rendering.py b/api/core/workflow/template_rendering.py index b4ffb37549d..d51cfadd098 100644 --- a/api/core/workflow/template_rendering.py +++ b/api/core/workflow/template_rendering.py @@ -3,10 +3,11 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor from graphon.nodes.code.entities import CodeLanguage from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor + class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): """Sandbox-backed Jinja2 renderer for workflow-owned node composition.""" diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 7429c95c7c6..2346a95d6a8 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -3,6 +3,20 @@ import time from collections.abc import Generator, Mapping, Sequence from typing import Any +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import CommandChannel, InMemoryChannel +from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool + from configs import dify_config from context import capture_current_context from core.app.apps.exc import GenerateTaskStoppedError @@ -21,20 +35,6 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add from core.workflow.variable_prefixes import ENVIRONMENT_VARIABLE_NODE_ID from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.errors import WorkflowNodeRunFailedError -from graphon.file.models import File -from graphon.graph import Graph -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer -from graphon.graph_engine.protocols.command_channel import CommandChannel -from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.base.node import Node -from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow logger = logging.getLogger(__name__) diff --git a/api/enterprise/telemetry/draft_trace.py b/api/enterprise/telemetry/draft_trace.py index dff558988c1..5a8d0ee6f49 100644 --- a/api/enterprise/telemetry/draft_trace.py +++ b/api/enterprise/telemetry/draft_trace.py @@ -3,9 +3,10 @@ from __future__ import 
annotations from collections.abc import Mapping from typing import Any +from graphon.enums import WorkflowNodeExecutionMetadataKey + from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName from core.telemetry import emit as telemetry_emit -from graphon.enums import WorkflowNodeExecutionMetadataKey from models.workflow import WorkflowNodeExecutionModel diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index ba9758175fa..7bd8e88231a 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,11 +1,12 @@ import logging +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.tool.entities import ToolEntity + from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_draft_workflow_was_synced -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.tool.entities import ToolEntity logger = logging.getLogger(__name__) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 6769b94cde0..86b5b2bbf05 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,11 +1,11 @@ from typing import cast +from graphon.nodes import BuiltinNodeTypes from sqlalchemy import delete, select from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db -from graphon.nodes import BuiltinNodeTypes from models.dataset import AppDatasetJoin from models.workflow import Workflow diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 120febecfbf..651f8ed8989 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,13 +5,12 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import sentry_sdk + from graphon.model_runtime.errors.invoke import InvokeRateLimitError from langfuse import parse_error from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException - from graphon.model_runtime.errors.invoke import InvokeRateLimitError - def before_send(event, hint): if "exc_info" in hint: _, exc_value, _ = hint["exc_info"] diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 64ff0f06745..db599c5d495 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import 
AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value -from graphon.enums import WorkflowNodeExecutionStatus from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 5208f8f37ea..3c83ab4f84e 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -20,12 +20,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, cast +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string -from graphon.enums import WorkflowExecutionStatus from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index ea4a2b3dd17..f71b2fa1df9 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -4,14 +4,14 @@ import os import time from typing import Union +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore -from graphon.entities import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index 976b5db8e30..b7254366817 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -13,6 +13,10 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, Union +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -21,10 +25,6 @@ from core.repositories.factory import OrderConfig, 
WorkflowNodeExecutionReposito from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier -from graphon.entities import WorkflowNodeExecution -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index eefcaa126e6..23d324f9ead 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -10,17 +10,17 @@ Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when import json from typing import Any, Protocol +from graphon.enums import BuiltinNodeTypes +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel from configs import dify_config from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes -from graphon.enums import BuiltinNodeTypes -from graphon.file.models import File -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment def should_include_content() -> bool: diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index ec3c78a12d4..335c5cc29e2 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -6,12 +6,12 @@ import logging from collections.abc import Mapping from typing import Any +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index 56672d1fd45..6df5f62c155 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -6,13 +6,13 @@ import logging from collections.abc import Sequence from typing import Any +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index 75ddbba4480..b9fdd9e1caa 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -2,14 +2,14 @@ Parser for tool nodes that captures tool-specific metadata. 
""" -from opentelemetry.trace import Span - -from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps -from extensions.otel.semconv.gen_ai import ToolAttributes from graphon.enums import WorkflowNodeExecutionMetadataKey from graphon.graph_events import GraphNodeEventBase from graphon.nodes.base.node import Node from graphon.nodes.tool.entities import ToolNodeData +from opentelemetry.trace import Span + +from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps +from extensions.otel.semconv.gen_ai import ToolAttributes class ToolNodeOTelParser: diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py index bc87510d439..7516d18c8e0 100644 --- a/api/factories/file_factory/builders.py +++ b/api/factories/file_factory/builders.py @@ -7,13 +7,12 @@ import uuid from collections.abc import Mapping, Sequence from typing import Any +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type from sqlalchemy import select from core.app.file_access import FileAccessControllerProtocol from core.workflow.file_reference import build_file_reference from extensions.ext_database import db -from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers -from graphon.file.file_factory import standardize_file_type from models import ToolFile, UploadFile from .common import resolve_mapping_file_id diff --git a/api/factories/file_factory/message_files.py b/api/factories/file_factory/message_files.py index 4b3d5142386..5582b85c956 100644 --- a/api/factories/file_factory/message_files.py +++ b/api/factories/file_factory/message_files.py @@ -4,8 +4,9 @@ from __future__ import annotations from collections.abc import Sequence -from core.app.file_access import FileAccessControllerProtocol from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig + +from core.app.file_access import FileAccessControllerProtocol from models import MessageFile from .builders import build_from_mapping diff --git a/api/factories/file_factory/storage_keys.py b/api/factories/file_factory/storage_keys.py index dba4c84407f..db3a7f30159 100644 --- a/api/factories/file_factory/storage_keys.py +++ b/api/factories/file_factory/storage_keys.py @@ -5,12 +5,12 @@ from __future__ import annotations import uuid from collections.abc import Mapping, Sequence +from graphon.file import File, FileTransferMethod from sqlalchemy import select from sqlalchemy.orm import Session from core.app.file_access import FileAccessControllerProtocol from core.workflow.file_reference import build_file_reference, parse_file_reference -from graphon.file import File, FileTransferMethod from models import ToolFile, UploadFile diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index fd7acb14d3a..57205b5739f 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -8,11 +8,6 @@ shared conversion functions for legacy callers and tests. 
from collections.abc import Mapping, Sequence from typing import Any, cast -from configs import dify_config -from core.workflow.variable_prefixes import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, -) from graphon.variables.exc import VariableError from graphon.variables.factory import ( TypeMismatchError, @@ -36,6 +31,12 @@ from graphon.variables.variables import ( VariableBase, ) +from configs import dify_config +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) + __all__ = [ "TypeMismatchError", "UnsupportedSegmentTypeError", diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 801949747e1..30d02aeedc2 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -3,9 +3,8 @@ from __future__ import annotations from datetime import datetime from typing import Any, TypeAlias -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator - from graphon.file import File +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator JSONValue: TypeAlias = Any diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 4e201e66e62..b8daa5af303 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -3,9 +3,8 @@ from __future__ import annotations from datetime import datetime from flask_restx import fields -from pydantic import BaseModel, ConfigDict, computed_field, field_validator - from graphon.file import helpers as file_helpers +from pydantic import BaseModel, ConfigDict, computed_field, field_validator simple_account_fields = { "id": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 86c4f285cd2..d982c31aeeb 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -4,11 +4,11 @@ from datetime import datetime from typing import TypeAlias from uuid import uuid4 +from graphon.file import File from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from fields.conversation_fields import AgentThought, JSONValue, MessageFile -from graphon.file import File JSONValueType: TypeAlias = JSONValue diff --git a/api/fields/raws.py b/api/fields/raws.py index ee6f53b360c..4c65cdab7af 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,5 +1,4 @@ from flask_restx import fields - from graphon.file import File diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index f9b5e989364..b0b6cc0b483 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,8 +1,8 @@ from flask_restx import fields +from graphon.variables import SecretVariable, SegmentType, VariableBase from core.helper import encrypter from fields.member_fields import simple_account_fields -from graphon.variables import SecretVariable, SegmentType, VariableBase from libs.helper import TimestampField from ._value_type_serializer import serialize_value_type diff --git a/api/graphon/README.md b/api/graphon/README.md deleted file mode 100644 index 725f122cd84..00000000000 --- a/api/graphon/README.md +++ /dev/null @@ -1,135 +0,0 @@ -# Workflow - -## Project Overview - -This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. 
The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control.
-
-## Architecture
-
-### Core Components
-
-The graph engine follows a layered architecture with strict dependency rules:
-
-1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution
-
-   - **Manager** - External control interface for stop/pause/resume commands
-   - **Worker** - Node execution runtime
-   - **Command Processing** - Handles control commands (abort, pause, resume)
-   - **Event Management** - Event propagation and layer notifications
-   - **Graph Traversal** - Edge processing and skip propagation
-   - **Response Coordinator** - Path tracking and session management
-   - **Layers** - Pluggable middleware (debug logging, execution limits)
-   - **Command Channels** - Communication channels (InMemory, Redis)
-
-1. **Graph** (`graph/`) - Graph structure and runtime state
-
-   - **Graph Template** - Workflow definition
-   - **Edge** - Node connections with conditions
-   - **Runtime State Protocol** - State management interface
-
-1. **Nodes** (`nodes/`) - Node implementations
-
-   - **Base** - Abstract node classes and variable parsing
-   - **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc.
-
-1. **Events** (`node_events/`) - Event system
-
-   - **Base** - Event protocols
-   - **Node Events** - Node lifecycle events
-
-1. **Entities** (`entities/`) - Domain models
-
-   - **Variable Pool** - Variable storage
-   - **Graph Init Params** - Initialization configuration
-
-## Key Design Patterns
-
-### Command Channel Pattern
-
-External workflow control via Redis or in-memory channels:
-
-```python
-# Send stop command to running workflow
-channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
-channel.send_command(AbortCommand(reason="User requested"))
-```
-
-### Layer System
-
-Extensible middleware for cross-cutting concerns:
-
-```python
-engine = GraphEngine(graph)
-engine.layer(DebugLoggingLayer(level="INFO"))
-engine.layer(ExecutionLimitsLayer(max_nodes=100))
-```
-
-`engine.layer()` binds the read-only runtime state before execution, so layer hooks
-can assume `graph_runtime_state` is available.
-
-### Event-Driven Architecture
-
-All node executions emit events for monitoring and integration:
-
-- `NodeRunStartedEvent` - Node execution begins
-- `NodeRunSucceededEvent` - Node completes successfully
-- `NodeRunFailedEvent` - Node encounters error
-- `GraphRunStartedEvent/GraphRunCompletedEvent` - Workflow lifecycle
-
-### Variable Pool
-
-Centralized variable storage with namespace isolation:
-
-```python
-# Variables scoped by node_id
-pool.add(["node1", "output"], value)
-result = pool.get(["node1", "output"])
-```
-
-## Import Architecture Rules
-
-The codebase enforces strict layering via import-linter:
-
-1. **Workflow Layers** (top to bottom):
-
-   - graph_engine → graph_events → graph → nodes → node_events → entities
-
-1. **Graph Engine Internal Layers**:
-
-   - orchestration → command_processing → event_management → graph_traversal → domain
-
-1. **Domain Isolation**:
-
-   - Domain models cannot import from infrastructure layers
-
-1. **Command Channel Independence**:
-
-   - InMemory and Redis channels must remain independent
-
-## Common Tasks
-
-### Adding a New Node Type
-
-1. Create node class in `nodes//`
-1. Inherit from `BaseNode` or appropriate base class
-1. Implement `_run()` method
-1. Ensure the node module is importable under `nodes//`
-1. 
Add tests in `tests/unit_tests/graphon/nodes/` - -### Implementing a Custom Layer - -1. Create class inheriting from `Layer` base -1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()` -1. Add to engine via `engine.layer()` - -### Debugging Workflow Execution - -Enable debug logging layer: - -```python -debug_layer = DebugLoggingLayer( - level="DEBUG", - include_inputs=True, - include_outputs=True -) -``` diff --git a/api/graphon/__init__.py b/api/graphon/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/graphon/entities/__init__.py b/api/graphon/entities/__init__.py deleted file mode 100644 index ef7789c49c8..00000000000 --- a/api/graphon/entities/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .graph_init_params import GraphInitParams -from .workflow_execution import WorkflowExecution -from .workflow_node_execution import WorkflowNodeExecution -from .workflow_start_reason import WorkflowStartReason - -__all__ = [ - "GraphInitParams", - "WorkflowExecution", - "WorkflowNodeExecution", - "WorkflowStartReason", -] diff --git a/api/graphon/entities/base_node_data.py b/api/graphon/entities/base_node_data.py deleted file mode 100644 index e8267043a9c..00000000000 --- a/api/graphon/entities/base_node_data.py +++ /dev/null @@ -1,178 +0,0 @@ -from __future__ import annotations - -import json -from abc import ABC -from builtins import type as type_ -from enum import StrEnum -from typing import Any, Union - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from graphon.entities.exc import DefaultValueTypeError -from graphon.enums import ErrorStrategy, NodeType - -# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. -_NumberType = Union[int, float] - - -class RetryConfig(BaseModel): - """node retry config""" - - max_retries: int = 0 # max retry times - retry_interval: int = 0 # retry interval in milliseconds - retry_enabled: bool = False # whether retry is enabled - - @property - def retry_interval_seconds(self) -> float: - return self.retry_interval / 1000 - - -class DefaultValueType(StrEnum): - STRING = "string" - NUMBER = "number" - OBJECT = "object" - ARRAY_NUMBER = "array[number]" - ARRAY_STRING = "array[string]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILES = "array[file]" - - -class DefaultValue(BaseModel): - value: Any = None - type: DefaultValueType - key: str - - @staticmethod - def _parse_json(value: str): - """Unified JSON parsing handler""" - try: - return json.loads(value) - except json.JSONDecodeError: - raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") - - @staticmethod - def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: - """Unified array type validation""" - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) - - @staticmethod - def _convert_number(value: str) -> float: - """Unified number conversion handler""" - try: - return float(value) - except ValueError: - raise DefaultValueTypeError(f"Cannot convert to number: {value}") - - @model_validator(mode="after") - def validate_value_type(self) -> DefaultValue: - # Type validation configuration - type_validators: dict[DefaultValueType, dict[str, Any]] = { - DefaultValueType.STRING: { - "type": str, - "converter": lambda x: x, - }, - DefaultValueType.NUMBER: { - "type": _NumberType, - "converter": self._convert_number, - }, - DefaultValueType.OBJECT: { - "type": dict, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_NUMBER: 
{ - "type": list, - "element_type": _NumberType, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_STRING: { - "type": list, - "element_type": str, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_OBJECT: { - "type": list, - "element_type": dict, - "converter": self._parse_json, - }, - } - - validator: dict[str, Any] = type_validators.get(self.type, {}) - if not validator: - if self.type == DefaultValueType.ARRAY_FILES: - # Handle files type - return self - raise DefaultValueTypeError(f"Unsupported type: {self.type}") - - # Handle string input cases - if isinstance(self.value, str) and self.type != DefaultValueType.STRING: - self.value = validator["converter"](self.value) - - # Validate base type - if not isinstance(self.value, validator["type"]): - raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") - - # Validate array element types - if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): - raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") - - return self - - -class BaseNodeData(ABC, BaseModel): - # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where - # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known. - # `type` therefore accepts downstream string node kinds; unknown node implementations - # are rejected later when the node factory resolves the node registry. - # At that boundary, node-specific fields are still "extra" relative to this shared DTO, - # and persisted templates/workflows also carry undeclared compatibility keys such as - # `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive - # here until graph parsing becomes discriminated by node type or those legacy payloads - # are normalized. - model_config = ConfigDict(extra="allow") - - type: NodeType - title: str = "" - desc: str | None = None - version: str = "1" - error_strategy: ErrorStrategy | None = None - default_value: list[DefaultValue] | None = None - retry_config: RetryConfig = Field(default_factory=RetryConfig) - - @property - def default_value_dict(self) -> dict[str, Any]: - if self.default_value: - return {item.key: item.value for item in self.default_value} - return {} - - def __getitem__(self, key: str) -> Any: - """ - Dict-style access without calling model_dump() on every lookup. - Prefer using model fields and Pydantic's extra storage. - """ - # First, check declared model fields - if key in self.__class__.model_fields: - return getattr(self, key) - - # Then, check undeclared compatibility fields stored in Pydantic's extra dict. - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras[key] - - raise KeyError(key) - - def get(self, key: str, default: Any = None) -> Any: - """ - Dict-style .get() without calling model_dump() on every lookup. 
- """ - if key in self.__class__.model_fields: - return getattr(self, key) - - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras.get(key, default) - - return default diff --git a/api/graphon/entities/exc.py b/api/graphon/entities/exc.py deleted file mode 100644 index aeecf406403..00000000000 --- a/api/graphon/entities/exc.py +++ /dev/null @@ -1,10 +0,0 @@ -class BaseNodeError(ValueError): - """Base class for node errors.""" - - pass - - -class DefaultValueTypeError(BaseNodeError): - """Raised when the default value type is invalid.""" - - pass diff --git a/api/graphon/entities/graph_config.py b/api/graphon/entities/graph_config.py deleted file mode 100644 index 392241c6317..00000000000 --- a/api/graphon/entities/graph_config.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -import sys - -from pydantic import TypeAdapter, with_config - -from graphon.entities.base_node_data import BaseNodeData - -if sys.version_info >= (3, 12): - from typing import TypedDict -else: - from typing_extensions import TypedDict - - -@with_config(extra="allow") -class NodeConfigDict(TypedDict): - id: str - # This is the permissive raw graph boundary. Node factories re-validate `data` - # with the concrete `NodeData` subtype after resolving the node implementation. - data: BaseNodeData - - -NodeConfigDictAdapter = TypeAdapter(NodeConfigDict) diff --git a/api/graphon/entities/graph_init_params.py b/api/graphon/entities/graph_init_params.py deleted file mode 100644 index f785d58a528..00000000000 --- a/api/graphon/entities/graph_init_params.py +++ /dev/null @@ -1,24 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import BaseModel, Field - -DIFY_RUN_CONTEXT_KEY = "_dify" - - -class GraphInitParams(BaseModel): - """GraphInitParams encapsulates the configurations and contextual information - that remain constant throughout a single execution of the graph engine. - - A single execution is defined as follows: as long as the execution has not reached - its conclusion, it is considered one execution. For instance, if a workflow is suspended - and later resumed, it is still regarded as a single execution, not two. - - For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`. 
- """ - - # init params - workflow_id: str = Field(..., description="workflow id") - graph_config: Mapping[str, Any] = Field(..., description="graph config") - run_context: Mapping[str, Any] = Field(..., description="runtime context") - call_depth: int = Field(..., description="call depth") diff --git a/api/graphon/entities/pause_reason.py b/api/graphon/entities/pause_reason.py deleted file mode 100644 index ba2973fd450..00000000000 --- a/api/graphon/entities/pause_reason.py +++ /dev/null @@ -1,42 +0,0 @@ -from collections.abc import Mapping -from enum import StrEnum, auto -from typing import Annotated, Any, Literal, TypeAlias - -from pydantic import BaseModel, Field - -from graphon.nodes.human_input.entities import FormInput, UserAction - - -class PauseReasonType(StrEnum): - HUMAN_INPUT_REQUIRED = auto() - SCHEDULED_PAUSE = auto() - - -class HumanInputRequired(BaseModel): - TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED - form_id: str - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - actions: list[UserAction] = Field(default_factory=list) - node_id: str - node_title: str - - # The `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from - # `output_variable_name` to their resolved values. - # - # For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its - # selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable - # `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The - # `resolved_default_values` is `{"name": "John"}`. - # - # Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`. - resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - - -class SchedulingPause(BaseModel): - TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE - - message: str - - -PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")] diff --git a/api/graphon/entities/workflow_execution.py b/api/graphon/entities/workflow_execution.py deleted file mode 100644 index b8de7eed1a6..00000000000 --- a/api/graphon/entities/workflow_execution.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -Domain entities for workflow execution. - -Models describe graph runtime state and avoid infrastructure-specific details. -""" - -from __future__ import annotations - -from collections.abc import Mapping -from datetime import UTC, datetime -from typing import Any - -from pydantic import BaseModel, Field - -from graphon.enums import WorkflowExecutionStatus, WorkflowType - - -class WorkflowExecution(BaseModel): - """ - Domain model for a workflow execution within the graph runtime. - """ - - id_: str = Field(...) - workflow_id: str = Field(...) - workflow_version: str = Field(...) - workflow_type: WorkflowType = Field(...) - graph: Mapping[str, Any] = Field(...) - - inputs: Mapping[str, Any] = Field(...) - outputs: Mapping[str, Any] | None = None - - status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING - error_message: str = Field(default="") - total_tokens: int = Field(default=0) - total_steps: int = Field(default=0) - exceptions_count: int = Field(default=0) - - started_at: datetime = Field(...) - finished_at: datetime | None = None - - @property - def elapsed_time(self) -> float: - """ - Calculate elapsed time in seconds. 
-        If the workflow is not finished, the current time is used.
-        """
-        end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None)
-        return (end_time - self.started_at).total_seconds()
-
-    @classmethod
-    def new(
-        cls,
-        *,
-        id_: str,
-        workflow_id: str,
-        workflow_type: WorkflowType,
-        workflow_version: str,
-        graph: Mapping[str, Any],
-        inputs: Mapping[str, Any],
-        started_at: datetime,
-    ) -> WorkflowExecution:
-        return WorkflowExecution(
-            id_=id_,
-            workflow_id=workflow_id,
-            workflow_type=workflow_type,
-            workflow_version=workflow_version,
-            graph=graph,
-            inputs=inputs,
-            status=WorkflowExecutionStatus.RUNNING,
-            started_at=started_at,
-        )
diff --git a/api/graphon/entities/workflow_node_execution.py b/api/graphon/entities/workflow_node_execution.py
deleted file mode 100644
index 5458572e7e1..00000000000
--- a/api/graphon/entities/workflow_node_execution.py
+++ /dev/null
@@ -1,141 +0,0 @@
-"""
-Domain entities for workflow node execution.
-
-These models capture node-level execution state for the graph runtime without
-describing storage or application-layer concerns.
-"""
-
-from collections.abc import Mapping
-from datetime import datetime
-from typing import Any
-
-from pydantic import BaseModel, Field, PrivateAttr
-
-from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-
-
-class WorkflowNodeExecution(BaseModel):
-    """
-    Domain model for workflow node execution.
-
-    This model represents the graph-level record of a node execution and
-    contains only execution state relevant to the runtime.
-    """
-
-    # --------- Core identification fields ---------
-
-    # Unique identifier for this execution record, used when persisting to storage.
-    # Value is a UUID string (e.g., '09b3e04c-f9ae-404c-ad82-290b8d7bd382').
-    id: str
-
-    # Optional secondary ID for cross-referencing purposes.
-    #
-    # NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`.
-    # While `node_execution_id` may sometimes be a UUID string, this is not guaranteed.
-    # In most scenarios, `id` should be used as the primary identifier.
-    node_execution_id: str | None = None
-    workflow_id: str  # ID of the workflow this node belongs to
-    workflow_execution_id: str | None = None  # ID of the workflow execution (null for single-step debugging)
-    # --------- Core identification fields end ---------
-
-    # Execution positioning and flow
-    index: int  # Sequence number for ordering in trace visualization
-    predecessor_node_id: str | None = None  # ID of the node that executed before this one
-    node_id: str  # ID of the node being executed
-    node_type: NodeType  # Type of node (e.g., start, llm, downstream response node)
-    title: str  # Display title of the node
-
-    # Execution data
-    # The `inputs` and `outputs` fields hold the full content
-    inputs: Mapping[str, Any] | None = None  # Input variables used by this node
-    process_data: Mapping[str, Any] | None = None  # Intermediate processing data
-    outputs: Mapping[str, Any] | None = None  # Output variables produced by this node
-
-    # Execution state
-    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING  # Current execution status
-    error: str | None = None  # Error message if execution failed
-    elapsed_time: float = Field(default=0.0)  # Time taken for execution in seconds
-
-    # Additional metadata
-    metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None  # Execution metadata (tokens, cost, etc.)
- - # Timing information - created_at: datetime # When execution started - finished_at: datetime | None = None # When execution completed - - _truncated_inputs: Mapping[str, Any] | None = PrivateAttr(None) - _truncated_outputs: Mapping[str, Any] | None = PrivateAttr(None) - _truncated_process_data: Mapping[str, Any] | None = PrivateAttr(None) - - def get_truncated_inputs(self) -> Mapping[str, Any] | None: - return self._truncated_inputs - - def get_truncated_outputs(self) -> Mapping[str, Any] | None: - return self._truncated_outputs - - def get_truncated_process_data(self) -> Mapping[str, Any] | None: - return self._truncated_process_data - - def set_truncated_inputs(self, truncated_inputs: Mapping[str, Any] | None): - self._truncated_inputs = truncated_inputs - - def set_truncated_outputs(self, truncated_outputs: Mapping[str, Any] | None): - self._truncated_outputs = truncated_outputs - - def set_truncated_process_data(self, truncated_process_data: Mapping[str, Any] | None): - self._truncated_process_data = truncated_process_data - - def get_response_inputs(self) -> Mapping[str, Any] | None: - inputs = self.get_truncated_inputs() - if inputs: - return inputs - return self.inputs - - @property - def inputs_truncated(self): - return self._truncated_inputs is not None - - @property - def outputs_truncated(self): - return self._truncated_outputs is not None - - @property - def process_data_truncated(self): - return self._truncated_process_data is not None - - def get_response_outputs(self) -> Mapping[str, Any] | None: - outputs = self.get_truncated_outputs() - if outputs is not None: - return outputs - return self.outputs - - def get_response_process_data(self) -> Mapping[str, Any] | None: - process_data = self.get_truncated_process_data() - if process_data is not None: - return process_data - return self.process_data - - def update_from_mapping( - self, - inputs: Mapping[str, Any] | None = None, - process_data: Mapping[str, Any] | None = None, - outputs: Mapping[str, Any] | None = None, - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None, - ): - """ - Update the model from mappings. - - Args: - inputs: The inputs to update - process_data: The process data to update - outputs: The outputs to update - metadata: The metadata to update - """ - if inputs is not None: - self.inputs = dict(inputs) - if process_data is not None: - self.process_data = dict(process_data) - if outputs is not None: - self.outputs = dict(outputs) - if metadata is not None: - self.metadata = dict(metadata) diff --git a/api/graphon/entities/workflow_start_reason.py b/api/graphon/entities/workflow_start_reason.py deleted file mode 100644 index df0f75383b0..00000000000 --- a/api/graphon/entities/workflow_start_reason.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import StrEnum - - -class WorkflowStartReason(StrEnum): - """Reason for workflow start events across graph/queue/SSE layers.""" - - INITIAL = "initial" # First start of a workflow run. - RESUMPTION = "resumption" # Start triggered after resuming a paused run. diff --git a/api/graphon/enums.py b/api/graphon/enums.py deleted file mode 100644 index bbc973abe5b..00000000000 --- a/api/graphon/enums.py +++ /dev/null @@ -1,262 +0,0 @@ -from enum import StrEnum -from typing import ClassVar, TypeAlias - - -class NodeState(StrEnum): - """State of a node or edge during workflow execution.""" - - UNKNOWN = "unknown" - TAKEN = "taken" - SKIPPED = "skipped" - - -NodeType: TypeAlias = str - - -class BuiltinNodeTypes: - """Built-in node type string constants. 
-
-    `node_type` values are plain strings throughout the graph runtime. This namespace
-    only exposes the built-in values shipped by `graphon`; downstream packages can
-    use additional strings without extending this class.
-    """
-
-    START: ClassVar[NodeType] = "start"
-    END: ClassVar[NodeType] = "end"
-    ANSWER: ClassVar[NodeType] = "answer"
-    LLM: ClassVar[NodeType] = "llm"
-    KNOWLEDGE_RETRIEVAL: ClassVar[NodeType] = "knowledge-retrieval"
-    IF_ELSE: ClassVar[NodeType] = "if-else"
-    CODE: ClassVar[NodeType] = "code"
-    TEMPLATE_TRANSFORM: ClassVar[NodeType] = "template-transform"
-    QUESTION_CLASSIFIER: ClassVar[NodeType] = "question-classifier"
-    HTTP_REQUEST: ClassVar[NodeType] = "http-request"
-    TOOL: ClassVar[NodeType] = "tool"
-    DATASOURCE: ClassVar[NodeType] = "datasource"
-    VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-aggregator"
-    LEGACY_VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-assigner"
-    LOOP: ClassVar[NodeType] = "loop"
-    LOOP_START: ClassVar[NodeType] = "loop-start"
-    LOOP_END: ClassVar[NodeType] = "loop-end"
-    ITERATION: ClassVar[NodeType] = "iteration"
-    ITERATION_START: ClassVar[NodeType] = "iteration-start"
-    PARAMETER_EXTRACTOR: ClassVar[NodeType] = "parameter-extractor"
-    VARIABLE_ASSIGNER: ClassVar[NodeType] = "assigner"
-    DOCUMENT_EXTRACTOR: ClassVar[NodeType] = "document-extractor"
-    LIST_OPERATOR: ClassVar[NodeType] = "list-operator"
-    AGENT: ClassVar[NodeType] = "agent"
-    HUMAN_INPUT: ClassVar[NodeType] = "human-input"
-
-
-BUILT_IN_NODE_TYPES: tuple[NodeType, ...] = (
-    BuiltinNodeTypes.START,
-    BuiltinNodeTypes.END,
-    BuiltinNodeTypes.ANSWER,
-    BuiltinNodeTypes.LLM,
-    BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL,
-    BuiltinNodeTypes.IF_ELSE,
-    BuiltinNodeTypes.CODE,
-    BuiltinNodeTypes.TEMPLATE_TRANSFORM,
-    BuiltinNodeTypes.QUESTION_CLASSIFIER,
-    BuiltinNodeTypes.HTTP_REQUEST,
-    BuiltinNodeTypes.TOOL,
-    BuiltinNodeTypes.DATASOURCE,
-    BuiltinNodeTypes.VARIABLE_AGGREGATOR,
-    BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR,
-    BuiltinNodeTypes.LOOP,
-    BuiltinNodeTypes.LOOP_START,
-    BuiltinNodeTypes.LOOP_END,
-    BuiltinNodeTypes.ITERATION,
-    BuiltinNodeTypes.ITERATION_START,
-    BuiltinNodeTypes.PARAMETER_EXTRACTOR,
-    BuiltinNodeTypes.VARIABLE_ASSIGNER,
-    BuiltinNodeTypes.DOCUMENT_EXTRACTOR,
-    BuiltinNodeTypes.LIST_OPERATOR,
-    BuiltinNodeTypes.AGENT,
-    BuiltinNodeTypes.HUMAN_INPUT,
-)
-
-
-class NodeExecutionType(StrEnum):
-    """Node execution type classification."""
-
-    EXECUTABLE = "executable"  # Regular nodes that execute and produce outputs
-    RESPONSE = "response"  # Response nodes that stream outputs (Answer, End)
-    BRANCH = "branch"  # Nodes that can choose different branches (if-else, question-classifier)
-    CONTAINER = "container"  # Container nodes that manage subgraphs (iteration, loop, graph)
-    ROOT = "root"  # Nodes that can serve as execution entry points
-
-
-class ErrorStrategy(StrEnum):
-    FAIL_BRANCH = "fail-branch"
-    DEFAULT_VALUE = "default-value"
-
-
-class FailBranchSourceHandle(StrEnum):
-    FAILED = "fail-branch"
-    SUCCESS = "success-branch"
-
-
-class WorkflowType(StrEnum):
-    """
-    Workflow Type Enum for domain layer
-    """
-
-    WORKFLOW = "workflow"
-    CHAT = "chat"
-    RAG_PIPELINE = "rag-pipeline"
-
-
-class WorkflowExecutionStatus(StrEnum):
-    # State diagram for the workflow status, expressed as a Mermaid state
-    # diagram ([*] marks the start and end pseudo-states):
-    #
-    # ---
-    # title: State diagram for Workflow run state
-    # ---
-    # stateDiagram-v2
-    #     scheduled: Scheduled
-    #     running: Running
-    #     succeeded: Succeeded
-    #     failed: Failed
-    #     partial_succeeded: Partial Succeeded
-    #     paused: Paused
-    #     stopped: Stopped
-    #
-    #     [*] --> scheduled
-    #     scheduled --> running: Start Execution
-    #     running --> paused: Human input required
-    #     paused --> running: Human input added
-    #     paused --> stopped: User stops execution
-    #     running --> succeeded: Execution finishes without any error
-    #     running --> failed: Execution finishes with errors
-    #     running --> stopped: User stops execution
-    #     running --> partial_succeeded: Some errors occurred and were handled during execution
-    #
-    #     scheduled --> stopped: User stops execution
-    #
-    #     succeeded --> [*]
-    #     failed --> [*]
-    #     partial_succeeded --> [*]
-    #     stopped --> [*]
-
-    # `SCHEDULED` means that the workflow is scheduled to run, but has not
-    # started running yet (possibly due to worker saturation).
-    #
-    # This enum value is currently unused.
-    SCHEDULED = "scheduled"
-
-    # `RUNNING` means the workflow is executing.
-    RUNNING = "running"
-
-    # `SUCCEEDED` means the execution of the workflow succeeded without any error.
-    SUCCEEDED = "succeeded"
-
-    # `FAILED` means the execution of the workflow failed with errors.
-    FAILED = "failed"
-
-    # `STOPPED` means the execution of the workflow was stopped, either manually
-    # by the user, or automatically by the Dify application (e.g., the moderation
-    # mechanism).
-    STOPPED = "stopped"
-
-    # `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
-    # execution, but they were successfully handled (e.g., by using an error
-    # strategy such as "fail branch" or "default value").
-    PARTIAL_SUCCEEDED = "partial-succeeded"
-
-    # `PAUSED` indicates that the workflow execution is temporarily paused
-    # (e.g., awaiting human input) and is expected to resume later.
-    PAUSED = "paused"
-
-    def is_ended(self) -> bool:
-        return self in _END_STATE
-
-    @classmethod
-    def ended_values(cls) -> list[str]:
-        return [status.value for status in _END_STATE]
-
-
-_END_STATE = frozenset(
    [
        WorkflowExecutionStatus.SUCCEEDED,
        WorkflowExecutionStatus.FAILED,
        WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
        WorkflowExecutionStatus.STOPPED,
    ]
)
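As a quick sanity check on the terminal-state helpers being deleted above, here is a minimal usage sketch. It assumes the `graphon.enums` module as it existed before this removal; nothing in it is new API.

```python
# Sketch only: exercises is_ended()/ended_values() as defined in the
# deleted enums module above.
from graphon.enums import WorkflowExecutionStatus

def can_resume(status: WorkflowExecutionStatus) -> bool:
    # Only paused runs may resume; the four end states are final.
    return status == WorkflowExecutionStatus.PAUSED

assert WorkflowExecutionStatus.STOPPED.is_ended()
assert not WorkflowExecutionStatus.PAUSED.is_ended()
# _END_STATE is a frozenset, so compare ended_values() order-insensitively.
assert set(WorkflowExecutionStatus.ended_values()) == {
    "succeeded",
    "failed",
    "partial-succeeded",
    "stopped",
}
```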
- """ - - TOTAL_TOKENS = "total_tokens" - TOTAL_PRICE = "total_price" - CURRENCY = "currency" - TOOL_INFO = "tool_info" - AGENT_LOG = "agent_log" - ITERATION_ID = "iteration_id" - ITERATION_INDEX = "iteration_index" - LOOP_ID = "loop_id" - LOOP_INDEX = "loop_index" - PARALLEL_ID = "parallel_id" - PARALLEL_START_NODE_ID = "parallel_start_node_id" - PARENT_PARALLEL_ID = "parent_parallel_id" - PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" - PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" - ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs - LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs - ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field - LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output - DATASOURCE_INFO = "datasource_info" - TRIGGER_INFO = "trigger_info" - COMPLETED_REASON = "completed_reason" # completed reason for loop node - - -class WorkflowNodeExecutionStatus(StrEnum): - PENDING = "pending" # Node is scheduled but not yet executing - RUNNING = "running" - SUCCEEDED = "succeeded" - FAILED = "failed" - EXCEPTION = "exception" - STOPPED = "stopped" - PAUSED = "paused" - - # Legacy statuses - kept for backward compatibility - RETRY = "retry" # Legacy: replaced by retry mechanism in error handling diff --git a/api/graphon/errors.py b/api/graphon/errors.py deleted file mode 100644 index 7eb007524d1..00000000000 --- a/api/graphon/errors.py +++ /dev/null @@ -1,16 +0,0 @@ -from graphon.nodes.base.node import Node - - -class WorkflowNodeRunFailedError(Exception): - def __init__(self, node: Node, err_msg: str): - self._node = node - self._error = err_msg - super().__init__(f"Node {node.title} run failed: {err_msg}") - - @property - def node(self) -> Node: - return self._node - - @property - def error(self) -> str: - return self._error diff --git a/api/graphon/file/__init__.py b/api/graphon/file/__init__.py deleted file mode 100644 index 4908ae9795b..00000000000 --- a/api/graphon/file/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .constants import FILE_MODEL_IDENTITY -from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType -from .file_factory import get_file_type_by_mime_type, standardize_file_type -from .models import ( - File, - FileUploadConfig, - ImageConfig, -) - -__all__ = [ - "FILE_MODEL_IDENTITY", - "ArrayFileAttribute", - "File", - "FileAttribute", - "FileBelongsTo", - "FileTransferMethod", - "FileType", - "FileUploadConfig", - "ImageConfig", - "get_file_type_by_mime_type", - "standardize_file_type", -] diff --git a/api/graphon/file/constants.py b/api/graphon/file/constants.py deleted file mode 100644 index 56b95b5f0d6..00000000000 --- a/api/graphon/file/constants.py +++ /dev/null @@ -1,48 +0,0 @@ -from collections.abc import Iterable -from typing import Any - -# TODO(QuantumGhost): Refactor variable type identification. Instead of directly -# comparing `dify_model_identity` with constants throughout the codebase, extract -# this logic into a dedicated function. This would encapsulate the implementation -# details of how different variable types are identified. 
-FILE_MODEL_IDENTITY = "__dify__file__" -DEFAULT_MIME_TYPE = "application/octet-stream" -DEFAULT_EXTENSION = ".bin" - - -def _with_case_variants(extensions: Iterable[str]) -> frozenset[str]: - normalized = {extension.lower() for extension in extensions} - return frozenset(normalized | {extension.upper() for extension in normalized}) - - -IMAGE_EXTENSIONS = _with_case_variants({"jpg", "jpeg", "png", "webp", "gif", "svg"}) -VIDEO_EXTENSIONS = _with_case_variants({"mp4", "mov", "mpeg", "webm"}) -AUDIO_EXTENSIONS = _with_case_variants({"mp3", "m4a", "wav", "amr", "mpga"}) -DOCUMENT_EXTENSIONS = _with_case_variants( - { - "txt", - "markdown", - "md", - "mdx", - "pdf", - "html", - "htm", - "xlsx", - "xls", - "vtt", - "properties", - "doc", - "docx", - "csv", - "eml", - "msg", - "ppt", - "pptx", - "xml", - "epub", - } -) - - -def maybe_file_object(o: Any) -> bool: - return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY diff --git a/api/graphon/file/enums.py b/api/graphon/file/enums.py deleted file mode 100644 index 170eb4fc233..00000000000 --- a/api/graphon/file/enums.py +++ /dev/null @@ -1,57 +0,0 @@ -from enum import StrEnum - - -class FileType(StrEnum): - IMAGE = "image" - DOCUMENT = "document" - AUDIO = "audio" - VIDEO = "video" - CUSTOM = "custom" - - @staticmethod - def value_of(value): - for member in FileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileTransferMethod(StrEnum): - REMOTE_URL = "remote_url" - LOCAL_FILE = "local_file" - TOOL_FILE = "tool_file" - DATASOURCE_FILE = "datasource_file" - - @staticmethod - def value_of(value): - for member in FileTransferMethod: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileBelongsTo(StrEnum): - USER = "user" - ASSISTANT = "assistant" - - @staticmethod - def value_of(value): - for member in FileBelongsTo: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileAttribute(StrEnum): - TYPE = "type" - SIZE = "size" - NAME = "name" - MIME_TYPE = "mime_type" - TRANSFER_METHOD = "transfer_method" - URL = "url" - EXTENSION = "extension" - RELATED_ID = "related_id" - - -class ArrayFileAttribute(StrEnum): - LENGTH = "length" diff --git a/api/graphon/file/file_factory.py b/api/graphon/file/file_factory.py deleted file mode 100644 index 3d20b9377d7..00000000000 --- a/api/graphon/file/file_factory.py +++ /dev/null @@ -1,39 +0,0 @@ -from .constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS -from .enums import FileType - - -def standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType: - """ - Infer the actual file type from extension and mime type. 
- """ - guessed_type = None - if extension: - guessed_type = _get_file_type_by_extension(extension) - if guessed_type is None and mime_type: - guessed_type = get_file_type_by_mime_type(mime_type) - return guessed_type or FileType.CUSTOM - - -def _get_file_type_by_extension(extension: str) -> FileType | None: - normalized_extension = extension.lstrip(".") - if normalized_extension in IMAGE_EXTENSIONS: - return FileType.IMAGE - if normalized_extension in VIDEO_EXTENSIONS: - return FileType.VIDEO - if normalized_extension in AUDIO_EXTENSIONS: - return FileType.AUDIO - if normalized_extension in DOCUMENT_EXTENSIONS: - return FileType.DOCUMENT - return None - - -def get_file_type_by_mime_type(mime_type: str) -> FileType: - if "image" in mime_type: - return FileType.IMAGE - if "video" in mime_type: - return FileType.VIDEO - if "audio" in mime_type: - return FileType.AUDIO - if "text" in mime_type or "pdf" in mime_type: - return FileType.DOCUMENT - return FileType.CUSTOM diff --git a/api/graphon/file/file_manager.py b/api/graphon/file/file_manager.py deleted file mode 100644 index d7e4d472e78..00000000000 --- a/api/graphon/file/file_manager.py +++ /dev/null @@ -1,129 +0,0 @@ -from __future__ import annotations - -import base64 -from collections.abc import Mapping - -from graphon.model_runtime.entities import ( - AudioPromptMessageContent, - DocumentPromptMessageContent, - ImagePromptMessageContent, - TextPromptMessageContent, - VideoPromptMessageContent, -) -from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes - -from .enums import FileAttribute -from .models import File, FileTransferMethod, FileType -from .runtime import get_workflow_file_runtime - - -def get_attr(*, file: File, attr: FileAttribute): - match attr: - case FileAttribute.TYPE: - return file.type.value - case FileAttribute.SIZE: - return file.size - case FileAttribute.NAME: - return file.filename - case FileAttribute.MIME_TYPE: - return file.mime_type - case FileAttribute.TRANSFER_METHOD: - return file.transfer_method.value - case FileAttribute.URL: - return _to_url(file) - case FileAttribute.EXTENSION: - return file.extension - case FileAttribute.RELATED_ID: - return file.related_id - - -def to_prompt_message_content( - f: File, - /, - *, - image_detail_config: ImagePromptMessageContent.DETAIL | None = None, -) -> PromptMessageContentUnionTypes: - """Convert a file to prompt message content.""" - if f.extension is None: - raise ValueError("Missing file extension") - if f.mime_type is None: - raise ValueError("Missing file mime_type") - - prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = { - FileType.IMAGE: ImagePromptMessageContent, - FileType.AUDIO: AudioPromptMessageContent, - FileType.VIDEO: VideoPromptMessageContent, - FileType.DOCUMENT: DocumentPromptMessageContent, - } - - if f.type not in prompt_class_map: - return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]") - - send_format = get_workflow_file_runtime().multimodal_send_format - params = { - "base64_data": _get_encoded_string(f) if send_format == "base64" else "", - "url": _to_url(f) if send_format == "url" else "", - "format": f.extension.removeprefix("."), - "mime_type": f.mime_type, - "filename": f.filename or "", - } - if f.type == FileType.IMAGE: - params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - - return prompt_class_map[f.type].model_validate(params) - - -def download(f: File, /) -> bytes: - if f.transfer_method in ( - 
FileTransferMethod.TOOL_FILE, - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.DATASOURCE_FILE, - ): - return _download_file_content(f) - elif f.transfer_method == FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True) - response.raise_for_status() - return response.content - raise ValueError(f"unsupported transfer method: {f.transfer_method}") - - -def _download_file_content(file: File, /) -> bytes: - """Download and return a file from storage as bytes.""" - return get_workflow_file_runtime().load_file_bytes(file=file) - - -def _get_encoded_string(f: File, /) -> str: - match f.transfer_method: - case FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True) - response.raise_for_status() - data = response.content - case FileTransferMethod.LOCAL_FILE: - data = _download_file_content(f) - case FileTransferMethod.TOOL_FILE: - data = _download_file_content(f) - case FileTransferMethod.DATASOURCE_FILE: - data = _download_file_content(f) - - return base64.b64encode(data).decode("utf-8") - - -def _to_url(f: File, /): - url = f.generate_url() - if url is None: - raise ValueError(f"Unsupported transfer method: {f.transfer_method}") - return url - - -class FileManager: - """Adapter exposing file manager helpers behind FileManagerProtocol.""" - - def download(self, f: File, /) -> bytes: - return download(f) - - -file_manager = FileManager() diff --git a/api/graphon/file/helpers.py b/api/graphon/file/helpers.py deleted file mode 100644 index dade761227e..00000000000 --- a/api/graphon/file/helpers.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .runtime import get_workflow_file_runtime - -if TYPE_CHECKING: - from .models import File - - -def resolve_file_url(file: File, /, *, for_external: bool = True) -> str | None: - return get_workflow_file_runtime().resolve_file_url(file=file, for_external=for_external) - - -def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str: - return get_workflow_file_runtime().resolve_upload_file_url( - upload_file_id=upload_file_id, - as_attachment=as_attachment, - for_external=for_external, - ) - - -def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str: - return get_workflow_file_runtime().resolve_tool_file_url( - tool_file_id=tool_file_id, - extension=extension, - for_external=for_external, - ) - - -def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - return get_workflow_file_runtime().verify_preview_signature( - preview_kind="image", - file_id=upload_file_id, - timestamp=timestamp, - nonce=nonce, - sign=sign, - ) - - -def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - return get_workflow_file_runtime().verify_preview_signature( - preview_kind="file", - file_id=upload_file_id, - timestamp=timestamp, - nonce=nonce, - sign=sign, - ) diff --git a/api/graphon/file/models.py b/api/graphon/file/models.py deleted file mode 100644 index ccd75843712..00000000000 --- a/api/graphon/file/models.py +++ /dev/null @@ -1,215 +0,0 @@ -from __future__ import annotations - -import base64 -import json -from collections.abc import Mapping, Sequence -from typing import 
Any - -from pydantic import BaseModel, Field, model_validator - -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent - -from . import helpers -from .constants import FILE_MODEL_IDENTITY -from .enums import FileTransferMethod, FileType - -_FILE_REFERENCE_PREFIX = "dify-file-ref:" - - -def sign_tool_file(*, tool_file_id: str, extension: str, for_external: bool = True) -> str: - """Compatibility shim for tests and legacy callers patching ``models.sign_tool_file``.""" - return helpers.get_signed_tool_file_url( - tool_file_id=tool_file_id, - extension=extension, - for_external=for_external, - ) - - -class ImageConfig(BaseModel): - """ - NOTE: This part of validation is deprecated, but still used in app features "Image Upload". - """ - - number_limits: int = 0 - transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - detail: ImagePromptMessageContent.DETAIL | None = None - - -class FileUploadConfig(BaseModel): - """ - File Upload Entity. - """ - - image_config: ImageConfig | None = None - allowed_file_types: Sequence[FileType] = Field(default_factory=list) - allowed_file_extensions: Sequence[str] = Field(default_factory=list) - allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - number_limits: int = 0 - - -def _parse_reference(reference: str | None) -> tuple[str | None, str | None]: - """Best-effort parser for record references and historical storage-key payloads.""" - if not reference: - return None, None - - if not reference.startswith(_FILE_REFERENCE_PREFIX): - return reference, None - - encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX) - try: - payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode())) - except (ValueError, json.JSONDecodeError): - return reference, None - - record_id = payload.get("record_id") - if not isinstance(record_id, str) or not record_id: - return reference, None - - storage_key = payload.get("storage_key") - if not isinstance(storage_key, str): - storage_key = None - - return record_id, storage_key - - -class File(BaseModel): - """Graph-owned file reference. - - The graph layer deliberately keeps only the metadata required to route, - serialize, and render files. Application ownership concerns such as - tenant/user/conversation identity stay in the workflow/storage layer. - """ - - # NOTE: dify_model_identity is a special identifier used to distinguish between - # new and old data formats during serialization and deserialization. - dify_model_identity: str = FILE_MODEL_IDENTITY - - id: str | None = None # message file id - type: FileType - transfer_method: FileTransferMethod - # If `transfer_method` is `FileTransferMethod.remote_url`, the - # `remote_url` attribute must not be `None`. - remote_url: str | None = None # remote url - # Opaque workflow-layer reference for files resolved outside ``graphon``. - # New payloads only carry the backing record id; historical payloads may - # still include storage_key and must remain readable. 
- reference: str | None = None - filename: str | None = None - extension: str | None = Field(default=None, description="File extension, should contain dot") - mime_type: str | None = None - size: int = -1 - _storage_key: str - - def __init__( - self, - *, - id: str | None = None, - tenant_id: str | None = None, - type: FileType, - transfer_method: FileTransferMethod, - remote_url: str | None = None, - reference: str | None = None, - related_id: str | None = None, - filename: str | None = None, - extension: str | None = None, - mime_type: str | None = None, - size: int = -1, - storage_key: str | None = None, - dify_model_identity: str | None = FILE_MODEL_IDENTITY, - url: str | None = None, - # Legacy compatibility fields - explicitly accept known extra fields - tool_file_id: str | None = None, - upload_file_id: str | None = None, - datasource_file_id: str | None = None, - ): - legacy_record_id = related_id or tool_file_id or upload_file_id or datasource_file_id - normalized_reference = reference - if normalized_reference is None and legacy_record_id is not None: - normalized_reference = str(legacy_record_id) - _, parsed_storage_key = _parse_reference(normalized_reference) - - super().__init__( - id=id, - type=type, - transfer_method=transfer_method, - remote_url=remote_url, - reference=normalized_reference, - filename=filename, - extension=extension, - mime_type=mime_type, - size=size, - dify_model_identity=dify_model_identity, - url=url, - ) - # Accept legacy constructor fields without promoting them back into the graph model. - _ = tenant_id - self._storage_key = storage_key or parsed_storage_key or "" - - def to_dict(self) -> Mapping[str, str | int | None]: - data = self.model_dump(mode="json") - return { - **data, - "related_id": self.related_id, - "url": self.generate_url(), - } - - @property - def markdown(self) -> str: - url = self.generate_url() - if self.type == FileType.IMAGE: - text = f"![{self.filename or ''}]({url})" - else: - text = f"[{self.filename or url}]({url})" - - return text - - def generate_url(self, for_external: bool = True) -> str | None: - return helpers.resolve_file_url(self, for_external=for_external) - - def to_plugin_parameter(self) -> dict[str, Any]: - return { - "dify_model_identity": FILE_MODEL_IDENTITY, - "mime_type": self.mime_type, - "filename": self.filename, - "extension": self.extension, - "size": self.size, - "type": self.type, - "url": self.generate_url(for_external=False), - } - - @model_validator(mode="after") - def validate_after(self) -> File: - match self.transfer_method: - case FileTransferMethod.REMOTE_URL: - if not self.remote_url: - raise ValueError("Missing file url") - if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): - raise ValueError("Invalid file url") - case FileTransferMethod.LOCAL_FILE: - if not self.reference: - raise ValueError("Missing file reference") - case FileTransferMethod.TOOL_FILE: - if not self.reference: - raise ValueError("Missing file reference") - case FileTransferMethod.DATASOURCE_FILE: - if not self.reference: - raise ValueError("Missing file reference") - return self - - @property - def related_id(self) -> str | None: - record_id, _ = _parse_reference(self.reference) - return record_id - - @related_id.setter - def related_id(self, value: str | None) -> None: - self.reference = value - - @property - def storage_key(self) -> str: - _, storage_key = _parse_reference(self.reference) - return storage_key or self._storage_key - - @storage_key.setter - def storage_key(self, value: str) -> 
None: - self._storage_key = value diff --git a/api/graphon/file/protocols.py b/api/graphon/file/protocols.py deleted file mode 100644 index 0acabe35e52..00000000000 --- a/api/graphon/file/protocols.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator -from typing import TYPE_CHECKING, Literal, Protocol - -if TYPE_CHECKING: - from .models import File - - -class HttpResponseProtocol(Protocol): - """Subset of response behavior needed by workflow file helpers.""" - - @property - def content(self) -> bytes: ... - - def raise_for_status(self) -> object: ... - - -class WorkflowFileRuntimeProtocol(Protocol): - """Runtime dependencies required by ``graphon.file``. - - Implementations are expected to be provided by integration layers (for example, - ``core.app.workflow.file_runtime``) so the workflow package avoids importing - application infrastructure modules directly. - """ - - @property - def multimodal_send_format(self) -> str: ... - - def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ... - - def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ... - - def load_file_bytes(self, *, file: File) -> bytes: ... - - def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: ... - - def resolve_upload_file_url( - self, - *, - upload_file_id: str, - as_attachment: bool = False, - for_external: bool = True, - ) -> str: ... - - def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ... - - def verify_preview_signature( - self, - *, - preview_kind: Literal["image", "file"], - file_id: str, - timestamp: str, - nonce: str, - sign: str, - ) -> bool: ... diff --git a/api/graphon/file/runtime.py b/api/graphon/file/runtime.py deleted file mode 100644 index 1c5d1c3ca47..00000000000 --- a/api/graphon/file/runtime.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator -from typing import TYPE_CHECKING, Literal, NoReturn - -from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol - -if TYPE_CHECKING: - from .models import File - - -class WorkflowFileRuntimeNotConfiguredError(RuntimeError): - """Raised when workflow file runtime dependencies were not configured.""" - - -class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol): - def _raise(self) -> NoReturn: - raise WorkflowFileRuntimeNotConfiguredError( - "workflow file runtime is not configured, call set_workflow_file_runtime(...) 
first" - ) - - @property - def multimodal_send_format(self) -> str: - self._raise() - - def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: - self._raise() - - def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: - self._raise() - - def load_file_bytes(self, *, file: File) -> bytes: - self._raise() - - def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: - self._raise() - - def resolve_upload_file_url( - self, - *, - upload_file_id: str, - as_attachment: bool = False, - for_external: bool = True, - ) -> str: - self._raise() - - def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: - self._raise() - - def verify_preview_signature( - self, - *, - preview_kind: Literal["image", "file"], - file_id: str, - timestamp: str, - nonce: str, - sign: str, - ) -> bool: - self._raise() - - -_runtime: WorkflowFileRuntimeProtocol = _UnconfiguredWorkflowFileRuntime() - - -def set_workflow_file_runtime(runtime: WorkflowFileRuntimeProtocol) -> None: - global _runtime - _runtime = runtime - - -def get_workflow_file_runtime() -> WorkflowFileRuntimeProtocol: - return _runtime diff --git a/api/graphon/file/tool_file_parser.py b/api/graphon/file/tool_file_parser.py deleted file mode 100644 index 2d7a3d43df4..00000000000 --- a/api/graphon/file/tool_file_parser.py +++ /dev/null @@ -1,9 +0,0 @@ -from collections.abc import Callable -from typing import Any - -_tool_file_manager_factory: Callable[[], Any] | None = None - - -def set_tool_file_manager_factory(factory: Callable[[], Any]): - global _tool_file_manager_factory - _tool_file_manager_factory = factory diff --git a/api/graphon/graph/__init__.py b/api/graphon/graph/__init__.py deleted file mode 100644 index 4830ea83d3d..00000000000 --- a/api/graphon/graph/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .edge import Edge -from .graph import Graph, GraphBuilder, NodeFactory -from .graph_template import GraphTemplate - -__all__ = [ - "Edge", - "Graph", - "GraphBuilder", - "GraphTemplate", - "NodeFactory", -] diff --git a/api/graphon/graph/edge.py b/api/graphon/graph/edge.py deleted file mode 100644 index 1f8a2884e3b..00000000000 --- a/api/graphon/graph/edge.py +++ /dev/null @@ -1,15 +0,0 @@ -import uuid -from dataclasses import dataclass, field - -from graphon.enums import NodeState - - -@dataclass -class Edge: - """Edge connecting two nodes in a workflow graph.""" - - id: str = field(default_factory=lambda: str(uuid.uuid4())) - tail: str = "" # tail node id (source) - head: str = "" # head node id (target) - source_handle: str = "source" # source handle for conditional branching - state: NodeState = field(default=NodeState.UNKNOWN) # edge execution state diff --git a/api/graphon/graph/graph.py b/api/graphon/graph/graph.py deleted file mode 100644 index 0f4cd8925fc..00000000000 --- a/api/graphon/graph/graph.py +++ /dev/null @@ -1,438 +0,0 @@ -from __future__ import annotations - -import logging -from collections import defaultdict -from collections.abc import Mapping, Sequence -from typing import Protocol, cast, final - -from pydantic import TypeAdapter - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState -from graphon.nodes.base.node import Node - -from .edge import Edge -from .validation import get_graph_validator - -logger = logging.getLogger(__name__) - -_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict]) - - -class 
NodeFactory(Protocol): - """ - Protocol for creating Node instances from node data dictionaries. - - This protocol decouples the Graph class from specific node mapping implementations, - allowing for different node creation strategies while maintaining type safety. - """ - - def create_node(self, node_config: NodeConfigDict) -> Node: - """ - Create a Node instance from node configuration data. - - :param node_config: node configuration dictionary containing type and other data - :return: initialized Node instance - :raises ValueError: if node type is unknown or no implementation exists for the resolved version - :raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation - """ - ... - - -@final -class Graph: - """Graph representation with nodes and edges for workflow execution.""" - - def __init__( - self, - *, - nodes: dict[str, Node] | None = None, - edges: dict[str, Edge] | None = None, - in_edges: dict[str, list[str]] | None = None, - out_edges: dict[str, list[str]] | None = None, - root_node: Node, - ): - """ - Initialize Graph instance. - - :param nodes: graph nodes mapping (node id: node object) - :param edges: graph edges mapping (edge id: edge object) - :param in_edges: incoming edges mapping (node id: list of edge ids) - :param out_edges: outgoing edges mapping (node id: list of edge ids) - :param root_node: root node object - """ - self.nodes = nodes or {} - self.edges = edges or {} - self.in_edges = in_edges or {} - self.out_edges = out_edges or {} - self.root_node = root_node - - @classmethod - def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]: - """ - Parse node configurations and build a mapping of node IDs to configs. - - :param node_configs: list of node configuration dictionaries - :return: mapping of node ID to node config - """ - node_configs_map: dict[str, NodeConfigDict] = {} - - for node_config in node_configs: - node_configs_map[node_config["id"]] = node_config - - return node_configs_map - - @classmethod - def _build_edges( - cls, edge_configs: list[dict[str, object]] - ) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]: - """ - Build edge objects and mappings from edge configurations. - - :param edge_configs: list of edge configurations - :return: tuple of (edges dict, in_edges dict, out_edges dict) - """ - edges: dict[str, Edge] = {} - in_edges: dict[str, list[str]] = defaultdict(list) - out_edges: dict[str, list[str]] = defaultdict(list) - - edge_counter = 0 - for edge_config in edge_configs: - source = edge_config.get("source") - target = edge_config.get("target") - - if not isinstance(source, str) or not isinstance(target, str): - continue - - # Create edge - edge_id = f"edge_{edge_counter}" - edge_counter += 1 - - source_handle = edge_config.get("sourceHandle", "source") - if not isinstance(source_handle, str): - continue - - edge = Edge( - id=edge_id, - tail=source, - head=target, - source_handle=source_handle, - ) - - edges[edge_id] = edge - out_edges[source].append(edge_id) - in_edges[target].append(edge_id) - - return edges, dict(in_edges), dict(out_edges) - - @classmethod - def _create_node_instances( - cls, - node_configs_map: dict[str, NodeConfigDict], - node_factory: NodeFactory, - ) -> dict[str, Node]: - """ - Create node instances from configurations using the node factory. 
- - :param node_configs_map: mapping of node ID to node config - :param node_factory: factory for creating node instances - :return: mapping of node ID to node instance - """ - nodes: dict[str, Node] = {} - - for node_id, node_config in node_configs_map.items(): - try: - node_instance = node_factory.create_node(node_config) - except Exception: - logger.exception("Failed to create node instance for node_id %s", node_id) - raise - nodes[node_id] = node_instance - - return nodes - - @classmethod - def new(cls) -> GraphBuilder: - """Create a fluent builder for assembling a graph programmatically.""" - - return GraphBuilder(graph_cls=cls) - - @staticmethod - def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]: - """ - Remove editor-only nodes before `NodeConfigDict` validation. - - Persisted note widgets use a top-level `type == "custom-note"` but leave - `data.type` empty because they are never executable graph nodes. Filter - them while configs are still raw dicts so Pydantic does not validate - their placeholder payloads against `BaseNodeData.type: NodeType`. - """ - filtered_node_configs: list[dict[str, object]] = [] - for node_config in node_configs: - if node_config.get("type", "") == "custom-note": - continue - filtered_node_configs.append(dict(node_config)) - return filtered_node_configs - - @classmethod - def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None: - """ - Promote nodes configured with FAIL_BRANCH error strategy to branch execution type. - - :param nodes: mapping of node ID to node instance - """ - for node in nodes.values(): - if node.error_strategy == ErrorStrategy.FAIL_BRANCH: - node.execution_type = NodeExecutionType.BRANCH - - @classmethod - def _mark_inactive_root_branches( - cls, - nodes: dict[str, Node], - edges: dict[str, Edge], - in_edges: dict[str, list[str]], - out_edges: dict[str, list[str]], - active_root_id: str, - ) -> None: - """ - Mark nodes and edges from inactive root branches as skipped. - - Algorithm: - 1. Mark inactive root nodes as skipped - 2. For skipped nodes, mark all their outgoing edges as skipped - 3. 
For each edge marked as skipped, check its target node: - - If ALL incoming edges are skipped, mark the node as skipped - - Otherwise, leave the node state unchanged - - :param nodes: mapping of node ID to node instance - :param edges: mapping of edge ID to edge instance - :param in_edges: mapping of node ID to incoming edge IDs - :param out_edges: mapping of node ID to outgoing edge IDs - :param active_root_id: ID of the active root node - """ - # Find all top-level root nodes (nodes with ROOT execution type and no incoming edges) - top_level_roots: list[str] = [ - node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT - ] - - # If there's only one root or the active root is not a top-level root, no marking needed - if len(top_level_roots) <= 1 or active_root_id not in top_level_roots: - return - - # Mark inactive root nodes as skipped - inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id] - for root_id in inactive_roots: - if root_id in nodes: - nodes[root_id].state = NodeState.SKIPPED - - # Recursively mark downstream nodes and edges - def mark_downstream(node_id: str) -> None: - """Recursively mark downstream nodes and edges as skipped.""" - if nodes[node_id].state != NodeState.SKIPPED: - return - # If this node is skipped, mark all its outgoing edges as skipped - out_edge_ids = out_edges.get(node_id, []) - for edge_id in out_edge_ids: - edge = edges[edge_id] - edge.state = NodeState.SKIPPED - - # Check the target node of this edge - target_node = nodes[edge.head] - in_edge_ids = in_edges.get(target_node.id, []) - in_edge_states = [edges[eid].state for eid in in_edge_ids] - - # If all incoming edges are skipped, mark the node as skipped - if all(state == NodeState.SKIPPED for state in in_edge_states): - target_node.state = NodeState.SKIPPED - # Recursively process downstream nodes - mark_downstream(target_node.id) - - # Process each inactive root and its downstream nodes - for root_id in inactive_roots: - mark_downstream(root_id) - - @classmethod - def init( - cls, - *, - graph_config: Mapping[str, object], - node_factory: NodeFactory, - root_node_id: str, - skip_validation: bool = False, - ) -> Graph: - """ - Initialize a graph with an explicit execution entry point. 
- - :param graph_config: graph config containing nodes and edges - :param node_factory: factory for creating node instances from config data - :param root_node_id: active root node id - :return: graph instance - """ - # Parse configs - edge_configs = graph_config.get("edges", []) - node_configs = graph_config.get("nodes", []) - - edge_configs = cast(list[dict[str, object]], edge_configs) - node_configs = cast(list[dict[str, object]], node_configs) - node_configs = cls._filter_canvas_only_nodes(node_configs) - node_configs = _ListNodeConfigDict.validate_python(node_configs) - - if not node_configs: - raise ValueError("Graph must have at least one node") - - # Parse node configurations - node_configs_map = cls._parse_node_configs(node_configs) - - if root_node_id not in node_configs_map: - raise ValueError(f"Root node id {root_node_id} not found in the graph") - - # Build edges - edges, in_edges, out_edges = cls._build_edges(edge_configs) - - # Create node instances - nodes = cls._create_node_instances(node_configs_map, node_factory) - - # Promote fail-branch nodes to branch execution type at graph level - cls._promote_fail_branch_nodes(nodes) - - # Get root node instance - root_node = nodes[root_node_id] - - # Mark inactive root branches as skipped - cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id) - - # Create and return the graph - graph = cls( - nodes=nodes, - edges=edges, - in_edges=in_edges, - out_edges=out_edges, - root_node=root_node, - ) - - if not skip_validation: - # Validate the graph structure using built-in validators - get_graph_validator().validate(graph) - - return graph - - @property - def node_ids(self) -> list[str]: - """ - Get list of node IDs (compatibility property for existing code) - - :return: list of node IDs - """ - return list(self.nodes.keys()) - - def get_outgoing_edges(self, node_id: str) -> list[Edge]: - """ - Get all outgoing edges from a node (V2 method) - - :param node_id: node id - :return: list of outgoing edges - """ - edge_ids = self.out_edges.get(node_id, []) - return [self.edges[eid] for eid in edge_ids if eid in self.edges] - - def get_incoming_edges(self, node_id: str) -> list[Edge]: - """ - Get all incoming edges to a node (V2 method) - - :param node_id: node id - :return: list of incoming edges - """ - edge_ids = self.in_edges.get(node_id, []) - return [self.edges[eid] for eid in edge_ids if eid in self.edges] - - -@final -class GraphBuilder: - """Fluent helper for constructing simple graphs, primarily for tests.""" - - def __init__(self, *, graph_cls: type[Graph]): - self._graph_cls = graph_cls - self._nodes: list[Node] = [] - self._nodes_by_id: dict[str, Node] = {} - self._edges: list[Edge] = [] - self._edge_counter = 0 - - def add_root(self, node: Node) -> GraphBuilder: - """Register the root node. 
Must be called exactly once.""" - - if self._nodes: - raise ValueError("Root node has already been added") - self._register_node(node) - self._nodes.append(node) - return self - - def add_node( - self, - node: Node, - *, - from_node_id: str | None = None, - source_handle: str = "source", - ) -> GraphBuilder: - """Append a node and connect it from the specified predecessor.""" - - if not self._nodes: - raise ValueError("Root node must be added before adding other nodes") - - predecessor_id = from_node_id or self._nodes[-1].id - if predecessor_id not in self._nodes_by_id: - raise ValueError(f"Predecessor node '{predecessor_id}' not found") - - predecessor = self._nodes_by_id[predecessor_id] - self._register_node(node) - self._nodes.append(node) - - edge_id = f"edge_{self._edge_counter}" - self._edge_counter += 1 - edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle) - self._edges.append(edge) - - return self - - def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder: - """Connect two existing nodes without adding a new node.""" - - if tail not in self._nodes_by_id: - raise ValueError(f"Tail node '{tail}' not found") - if head not in self._nodes_by_id: - raise ValueError(f"Head node '{head}' not found") - - edge_id = f"edge_{self._edge_counter}" - self._edge_counter += 1 - edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle) - self._edges.append(edge) - - return self - - def build(self) -> Graph: - """Materialize the graph instance from the accumulated nodes and edges.""" - - if not self._nodes: - raise ValueError("Cannot build an empty graph") - - nodes = {node.id: node for node in self._nodes} - edges = {edge.id: edge for edge in self._edges} - in_edges: dict[str, list[str]] = defaultdict(list) - out_edges: dict[str, list[str]] = defaultdict(list) - - for edge in self._edges: - out_edges[edge.tail].append(edge.id) - in_edges[edge.head].append(edge.id) - - return self._graph_cls( - nodes=nodes, - edges=edges, - in_edges=dict(in_edges), - out_edges=dict(out_edges), - root_node=self._nodes[0], - ) - - def _register_node(self, node: Node) -> None: - if not node.id: - raise ValueError("Node must have a non-empty id") - if node.id in self._nodes_by_id: - raise ValueError(f"Duplicate node id detected: {node.id}") - self._nodes_by_id[node.id] = node diff --git a/api/graphon/graph/graph_template.py b/api/graphon/graph/graph_template.py deleted file mode 100644 index 34e2dc19e60..00000000000 --- a/api/graphon/graph/graph_template.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Any - -from pydantic import BaseModel, Field - - -class GraphTemplate(BaseModel): - """ - Graph Template for container nodes and subgraph expansion - - According to GraphEngine V2 spec, GraphTemplate contains: - - nodes: mapping of node definitions - - edges: mapping of edge definitions - - root_ids: list of root node IDs - - output_selectors: list of output selectors for the template - """ - - nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping") - edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping") - root_ids: list[str] = Field(default_factory=list, description="root node IDs") - output_selectors: list[str] = Field(default_factory=list, description="output selectors") diff --git a/api/graphon/graph/validation.py b/api/graphon/graph/validation.py deleted file mode 100644 index 04b501fd331..00000000000 --- 
a/api/graphon/graph/validation.py +++ /dev/null @@ -1,125 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from dataclasses import dataclass -from typing import TYPE_CHECKING, Protocol - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeType - -if TYPE_CHECKING: - from .graph import Graph - - -@dataclass(frozen=True, slots=True) -class GraphValidationIssue: - """Immutable value object describing a single validation issue.""" - - code: str - message: str - node_id: str | None = None - - -class GraphValidationError(ValueError): - """Raised when graph validation fails.""" - - def __init__(self, issues: Sequence[GraphValidationIssue]) -> None: - if not issues: - raise ValueError("GraphValidationError requires at least one issue.") - self.issues: tuple[GraphValidationIssue, ...] = tuple(issues) - message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues) - super().__init__(message) - - -class GraphValidationRule(Protocol): - """Protocol that individual validation rules must satisfy.""" - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - """Validate the provided graph and return any discovered issues.""" - ... - - -@dataclass(frozen=True, slots=True) -class _EdgeEndpointValidator: - """Ensures all edges reference existing nodes.""" - - missing_node_code: str = "MISSING_NODE" - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - issues: list[GraphValidationIssue] = [] - for edge in graph.edges.values(): - if edge.tail not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.missing_node_code, - message=f"Edge {edge.id} references unknown source node '{edge.tail}'.", - node_id=edge.tail, - ) - ) - if edge.head not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.missing_node_code, - message=f"Edge {edge.id} references unknown target node '{edge.head}'.", - node_id=edge.head, - ) - ) - return issues - - -@dataclass(frozen=True, slots=True) -class _RootNodeValidator: - """Validates root node invariants.""" - - invalid_root_code: str = "INVALID_ROOT" - container_entry_types: tuple[NodeType, ...] = (BuiltinNodeTypes.ITERATION_START, BuiltinNodeTypes.LOOP_START) - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - root_node = graph.root_node - issues: list[GraphValidationIssue] = [] - if root_node.id not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.invalid_root_code, - message=f"Root node '{root_node.id}' is missing from the node registry.", - node_id=root_node.id, - ) - ) - return issues - - node_type = root_node.node_type - if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types: - issues.append( - GraphValidationIssue( - code=self.invalid_root_code, - message=f"Root node '{root_node.id}' must declare execution type 'root'.", - node_id=root_node.id, - ) - ) - return issues - - -@dataclass(frozen=True, slots=True) -class GraphValidator: - """Coordinates execution of graph validation rules.""" - - rules: tuple[GraphValidationRule, ...] - - def validate(self, graph: Graph) -> None: - """Validate the graph against all configured rules.""" - issues: list[GraphValidationIssue] = [] - for rule in self.rules: - issues.extend(rule.validate(graph)) - - if issues: - raise GraphValidationError(issues) - - -_DEFAULT_RULES: tuple[GraphValidationRule, ...] 
= (
-    _EdgeEndpointValidator(),
-    _RootNodeValidator(),
-)
-
-
-def get_graph_validator() -> GraphValidator:
-    """Construct the validator composed of default rules."""
-    return GraphValidator(_DEFAULT_RULES)
diff --git a/api/graphon/graph_engine/__init__.py b/api/graphon/graph_engine/__init__.py
deleted file mode 100644
index 0e1c7dd60a7..00000000000
--- a/api/graphon/graph_engine/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .config import GraphEngineConfig
-from .graph_engine import GraphEngine
-
-__all__ = ["GraphEngine", "GraphEngineConfig"]
diff --git a/api/graphon/graph_engine/_engine_utils.py b/api/graphon/graph_engine/_engine_utils.py
deleted file mode 100644
index 28898268fea..00000000000
--- a/api/graphon/graph_engine/_engine_utils.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import time
-
-
-def get_timestamp() -> float:
-    """Retrieve a timestamp as a floating-point number representing the number of seconds
-    since the Unix epoch.
-
-    This function is primarily used to measure the execution time of the workflow engine.
-    Since workflow execution may be paused and resumed on a different machine,
-    `time.perf_counter` cannot be used as it is inconsistent across machines.
-
-    To address this, the function uses the wall clock as the time source.
-    However, it assumes that the clocks of all servers are properly synchronized.
-    """
-    return round(time.time())
diff --git a/api/graphon/graph_engine/command_channels/README.md b/api/graphon/graph_engine/command_channels/README.md
deleted file mode 100644
index e35e12054ae..00000000000
--- a/api/graphon/graph_engine/command_channels/README.md
+++ /dev/null
@@ -1,33 +0,0 @@
-# Command Channels
-
-Channel implementations for external workflow control.
-
-## Components
-
-### InMemoryChannel
-
-Thread-safe in-memory queue for single-process deployments.
-
-- `fetch_commands()` - Get pending commands
-- `send_command()` - Add command to queue
-
-### RedisChannel
-
-Redis-based queue for distributed deployments.
-
-- `fetch_commands()` - Get commands with JSON deserialization
-- `send_command()` - Store commands with TTL
-
-## Usage
-
-```python
-# Local execution
-channel = InMemoryChannel()
-channel.send_command(AbortCommand(reason="stop workflow-123"))
-
-# Distributed execution
-redis_channel = RedisChannel(
-    redis_client=redis_client,
-    channel_key="workflow:123:commands"
-)
-```
diff --git a/api/graphon/graph_engine/command_channels/__init__.py b/api/graphon/graph_engine/command_channels/__init__.py
deleted file mode 100644
index 863e6032d60..00000000000
--- a/api/graphon/graph_engine/command_channels/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""Command channel implementations for GraphEngine."""
-
-from .in_memory_channel import InMemoryChannel
-from .redis_channel import RedisChannel
-
-__all__ = ["InMemoryChannel", "RedisChannel"]
diff --git a/api/graphon/graph_engine/command_channels/in_memory_channel.py b/api/graphon/graph_engine/command_channels/in_memory_channel.py
deleted file mode 100644
index bdaf2367967..00000000000
--- a/api/graphon/graph_engine/command_channels/in_memory_channel.py
+++ /dev/null
@@ -1,53 +0,0 @@
-"""
-In-memory implementation of CommandChannel for local/testing scenarios.
-
-This implementation uses a thread-safe queue for command communication
-within a single process. Each instance handles commands for one workflow execution.
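-
-Illustrative round-trip (a sketch; `command` is any GraphEngineCommand):
-
-    channel = InMemoryChannel()
-    channel.send_command(command)       # enqueue from any thread
-    pending = channel.fetch_commands()  # drains everything queued so far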
-""" - -from queue import Queue -from typing import final - -from ..entities.commands import GraphEngineCommand - - -@final -class InMemoryChannel: - """ - In-memory command channel implementation using a thread-safe queue. - - Each instance is dedicated to a single GraphEngine/workflow execution. - Suitable for local development, testing, and single-instance deployments. - """ - - def __init__(self) -> None: - """Initialize the in-memory channel with a single queue.""" - self._queue: Queue[GraphEngineCommand] = Queue() - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch all pending commands from the queue. - - Returns: - List of pending commands (drains the queue) - """ - commands: list[GraphEngineCommand] = [] - - # Drain all available commands from the queue - while not self._queue.empty(): - try: - command = self._queue.get_nowait() - commands.append(command) - except Exception: - break - - return commands - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to this channel's queue. - - Args: - command: The command to send - """ - self._queue.put(command) diff --git a/api/graphon/graph_engine/command_channels/redis_channel.py b/api/graphon/graph_engine/command_channels/redis_channel.py deleted file mode 100644 index 77cf884c67a..00000000000 --- a/api/graphon/graph_engine/command_channels/redis_channel.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -Redis-based implementation of CommandChannel for distributed scenarios. - -This implementation uses Redis lists for command queuing, supporting -multi-instance deployments and cross-server communication. -Each instance uses a unique key for its command queue. -""" - -import json -from contextlib import AbstractContextManager -from typing import Any, Protocol, final - -from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand - - -class RedisPipelineProtocol(Protocol): - """Minimal Redis pipeline contract used by the command channel.""" - - def lrange(self, name: str, start: int, end: int) -> Any: ... - def delete(self, *names: str) -> Any: ... - def execute(self) -> list[Any]: ... - def rpush(self, name: str, *values: str) -> Any: ... - def expire(self, name: str, time: int) -> Any: ... - def set(self, name: str, value: str, ex: int | None = None) -> Any: ... - def get(self, name: str) -> Any: ... - - -class RedisClientProtocol(Protocol): - """Redis client contract required by the command channel.""" - - def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ... - - -@final -class RedisChannel: - """ - Redis-based command channel implementation for distributed systems. - - Each instance uses a unique Redis key for its command queue. - Commands are JSON-serialized for transport. - """ - - def __init__( - self, - redis_client: RedisClientProtocol, - channel_key: str, - command_ttl: int = 3600, - ) -> None: - """ - Initialize the Redis channel. - - Args: - redis_client: Redis client instance - channel_key: Unique key for this channel's command queue - command_ttl: TTL for command keys in seconds (default: 3600) - """ - self._redis = redis_client - self._key = channel_key - self._command_ttl = command_ttl - self._pending_key = f"{channel_key}:pending" - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch all pending commands from Redis. 
- - Returns: - List of pending commands (drains the Redis list) - """ - if not self._has_pending_commands(): - return [] - - commands: list[GraphEngineCommand] = [] - - # Use pipeline for atomic operations - with self._redis.pipeline() as pipe: - # Get all commands and clear the list atomically - pipe.lrange(self._key, 0, -1) - pipe.delete(self._key) - results = pipe.execute() - - # Parse commands from JSON - if results[0]: - for command_json in results[0]: - try: - command_data = json.loads(command_json) - command = self._deserialize_command(command_data) - if command: - commands.append(command) - except (json.JSONDecodeError, ValueError): - # Skip invalid commands - continue - - return commands - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to Redis. - - Args: - command: The command to send - """ - command_json = json.dumps(command.model_dump()) - - # Push to list and set expiry - with self._redis.pipeline() as pipe: - pipe.rpush(self._key, command_json) - pipe.expire(self._key, self._command_ttl) - pipe.set(self._pending_key, "1", ex=self._command_ttl) - pipe.execute() - - def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None: - """ - Deserialize a command from dictionary data. - - Args: - data: Command data dictionary - - Returns: - Deserialized command or None if invalid - """ - command_type_value = data.get("command_type") - if not isinstance(command_type_value, str): - return None - - try: - command_type = CommandType(command_type_value) - - if command_type == CommandType.ABORT: - return AbortCommand.model_validate(data) - if command_type == CommandType.PAUSE: - return PauseCommand.model_validate(data) - if command_type == CommandType.UPDATE_VARIABLES: - return UpdateVariablesCommand.model_validate(data) - - # For other command types, use base class - return GraphEngineCommand.model_validate(data) - - except (ValueError, TypeError): - return None - - def _has_pending_commands(self) -> bool: - """ - Check and consume the pending marker to avoid unnecessary list reads. - - Returns: - True if commands should be fetched from Redis. - """ - with self._redis.pipeline() as pipe: - pipe.get(self._pending_key) - pipe.delete(self._pending_key) - pending_value, _ = pipe.execute() - - return pending_value is not None diff --git a/api/graphon/graph_engine/command_processing/__init__.py b/api/graphon/graph_engine/command_processing/__init__.py deleted file mode 100644 index 7b4f0dfff79..00000000000 --- a/api/graphon/graph_engine/command_processing/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Command processing subsystem for graph engine. - -This package handles external commands sent to the engine -during execution. 
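-
-A minimal wiring sketch (the channel and execution objects are assumed to be
-created elsewhere):
-
-    processor = CommandProcessor(command_channel=channel, graph_execution=execution)
-    processor.register_handler(AbortCommand, AbortCommandHandler())
-    processor.process_commands()  # polls the channel and dispatches by type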
-""" - -from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler -from .command_processor import CommandProcessor - -__all__ = [ - "AbortCommandHandler", - "CommandProcessor", - "PauseCommandHandler", - "UpdateVariablesCommandHandler", -] diff --git a/api/graphon/graph_engine/command_processing/command_handlers.py b/api/graphon/graph_engine/command_processing/command_handlers.py deleted file mode 100644 index ad92fd1abb0..00000000000 --- a/api/graphon/graph_engine/command_processing/command_handlers.py +++ /dev/null @@ -1,56 +0,0 @@ -import logging -from typing import final - -from typing_extensions import override - -from graphon.entities.pause_reason import SchedulingPause -from graphon.runtime import VariablePool - -from ..domain.graph_execution import GraphExecution -from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand -from .command_processor import CommandHandler - -logger = logging.getLogger(__name__) - - -@final -class AbortCommandHandler(CommandHandler): - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, AbortCommand) - logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason) - execution.abort(command.reason or "User requested abort") - - -@final -class PauseCommandHandler(CommandHandler): - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, PauseCommand) - logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason) - # Convert string reason to PauseReason if needed - reason = command.reason - pause_reason = SchedulingPause(message=reason) - execution.pause(pause_reason) - - -@final -class UpdateVariablesCommandHandler(CommandHandler): - def __init__(self, variable_pool: VariablePool) -> None: - self._variable_pool = variable_pool - - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, UpdateVariablesCommand) - for update in command.updates: - try: - variable = update.value - self._variable_pool.add(variable.selector, variable) - logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id) - except ValueError as exc: - logger.warning( - "Skipping invalid variable selector %s for workflow %s: %s", - getattr(update.value, "selector", None), - execution.workflow_id, - exc, - ) diff --git a/api/graphon/graph_engine/command_processing/command_processor.py b/api/graphon/graph_engine/command_processing/command_processor.py deleted file mode 100644 index 942c2d77a5a..00000000000 --- a/api/graphon/graph_engine/command_processing/command_processor.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Main command processor for handling external commands. -""" - -import logging -from typing import Protocol, final - -from ..domain.graph_execution import GraphExecution -from ..entities.commands import GraphEngineCommand -from ..protocols.command_channel import CommandChannel - -logger = logging.getLogger(__name__) - - -class CommandHandler(Protocol): - """Protocol for command handlers.""" - - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ... - - -@final -class CommandProcessor: - """ - Processes external commands sent to the engine. - - This polls the command channel and dispatches commands to - appropriate handlers. 
- """ - - def __init__( - self, - command_channel: CommandChannel, - graph_execution: GraphExecution, - ) -> None: - """ - Initialize the command processor. - - Args: - command_channel: Channel for receiving commands - graph_execution: Graph execution aggregate - """ - self._command_channel = command_channel - self._graph_execution = graph_execution - self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {} - - def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None: - """ - Register a handler for a command type. - - Args: - command_type: Type of command to handle - handler: Handler for the command - """ - self._handlers[command_type] = handler - - def process_commands(self) -> None: - """Check for and process any pending commands.""" - try: - commands = self._command_channel.fetch_commands() - for command in commands: - self._handle_command(command) - except Exception as e: - logger.warning("Error processing commands: %s", e) - - def _handle_command(self, command: GraphEngineCommand) -> None: - """ - Handle a single command. - - Args: - command: The command to handle - """ - handler = self._handlers.get(type(command)) - if handler: - try: - handler.handle(command, self._graph_execution) - except Exception: - logger.exception("Error handling command %s", command.__class__.__name__) - else: - logger.warning("No handler registered for command: %s", command.__class__.__name__) diff --git a/api/graphon/graph_engine/config.py b/api/graphon/graph_engine/config.py deleted file mode 100644 index d56a69cee03..00000000000 --- a/api/graphon/graph_engine/config.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -GraphEngine configuration models. -""" - -from pydantic import BaseModel, ConfigDict - - -class GraphEngineConfig(BaseModel): - """Configuration for GraphEngine worker pool scaling.""" - - model_config = ConfigDict(frozen=True) - - min_workers: int = 1 - max_workers: int = 5 - scale_up_threshold: int = 3 - scale_down_idle_time: float = 5.0 diff --git a/api/graphon/graph_engine/domain/__init__.py b/api/graphon/graph_engine/domain/__init__.py deleted file mode 100644 index 9e9afe4c219..00000000000 --- a/api/graphon/graph_engine/domain/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Domain models for graph engine. - -This package contains the core domain entities, value objects, and aggregates -that represent the business concepts of workflow graph execution. 
-""" - -from .graph_execution import GraphExecution -from .node_execution import NodeExecution - -__all__ = [ - "GraphExecution", - "NodeExecution", -] diff --git a/api/graphon/graph_engine/domain/graph_execution.py b/api/graphon/graph_engine/domain/graph_execution.py deleted file mode 100644 index 9c0c7d16240..00000000000 --- a/api/graphon/graph_engine/domain/graph_execution.py +++ /dev/null @@ -1,242 +0,0 @@ -"""GraphExecution aggregate root managing the overall graph execution state.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from importlib import import_module -from typing import Literal - -from pydantic import BaseModel, Field - -from graphon.entities.pause_reason import PauseReason -from graphon.enums import NodeState -from graphon.runtime.graph_runtime_state import GraphExecutionProtocol - -from .node_execution import NodeExecution - - -class GraphExecutionErrorState(BaseModel): - """Serializable representation of an execution error.""" - - module: str = Field(description="Module containing the exception class") - qualname: str = Field(description="Qualified name of the exception class") - message: str | None = Field(default=None, description="Exception message string") - - -class NodeExecutionState(BaseModel): - """Serializable representation of a node execution entity.""" - - node_id: str - state: NodeState = Field(default=NodeState.UNKNOWN) - retry_count: int = Field(default=0) - execution_id: str | None = Field(default=None) - error: str | None = Field(default=None) - - -class GraphExecutionState(BaseModel): - """Pydantic model describing serialized GraphExecution state.""" - - type: Literal["GraphExecution"] = Field(default="GraphExecution") - version: str = Field(default="1.0") - workflow_id: str - started: bool = Field(default=False) - completed: bool = Field(default=False) - aborted: bool = Field(default=False) - paused: bool = Field(default=False) - pause_reasons: list[PauseReason] = Field(default_factory=list) - error: GraphExecutionErrorState | None = Field(default=None) - exceptions_count: int = Field(default=0) - node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState]) - - -def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None: - """Convert an exception into its serializable representation.""" - - if error is None: - return None - - return GraphExecutionErrorState( - module=error.__class__.__module__, - qualname=error.__class__.__qualname__, - message=str(error), - ) - - -def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]: - """Locate an exception class from its module and qualified name.""" - - module = import_module(module_name) - attr: object = module - for part in qualname.split("."): - attr = getattr(attr, part) - - if isinstance(attr, type) and issubclass(attr, Exception): - return attr - - raise TypeError(f"{qualname} in {module_name} is not an Exception subclass") - - -def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None: - """Reconstruct an exception instance from serialized data.""" - - if state is None: - return None - - try: - exception_class = _resolve_exception_class(state.module, state.qualname) - if state.message is None: - return exception_class() - return exception_class(state.message) - except Exception: - # Fallback to RuntimeError when reconstruction fails - if state.message is None: - return RuntimeError(state.qualname) - return RuntimeError(state.message) - - -@dataclass -class 
GraphExecution: - """ - Aggregate root for graph execution. - - This manages the overall execution state of a workflow graph, - coordinating between multiple node executions. - """ - - workflow_id: str - started: bool = False - completed: bool = False - aborted: bool = False - paused: bool = False - pause_reasons: list[PauseReason] = field(default_factory=list) - error: Exception | None = None - node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution]) - exceptions_count: int = 0 - - def start(self) -> None: - """Mark the graph execution as started.""" - if self.started: - raise RuntimeError("Graph execution already started") - self.started = True - - def complete(self) -> None: - """Mark the graph execution as completed.""" - if not self.started: - raise RuntimeError("Cannot complete execution that hasn't started") - if self.completed: - raise RuntimeError("Graph execution already completed") - self.completed = True - - def abort(self, reason: str) -> None: - """Abort the graph execution.""" - self.aborted = True - self.error = RuntimeError(f"Aborted: {reason}") - - def pause(self, reason: PauseReason) -> None: - """Pause the graph execution without marking it complete.""" - if self.completed: - raise RuntimeError("Cannot pause execution that has completed") - if self.aborted: - raise RuntimeError("Cannot pause execution that has been aborted") - self.paused = True - self.pause_reasons.append(reason) - - def fail(self, error: Exception) -> None: - """Mark the graph execution as failed.""" - self.error = error - self.completed = True - - def get_or_create_node_execution(self, node_id: str) -> NodeExecution: - """Get or create a node execution entity.""" - if node_id not in self.node_executions: - self.node_executions[node_id] = NodeExecution(node_id=node_id) - return self.node_executions[node_id] - - @property - def is_running(self) -> bool: - """Check if the execution is currently running.""" - return self.started and not self.completed and not self.aborted and not self.paused - - @property - def is_paused(self) -> bool: - """Check if the execution is currently paused.""" - return self.paused - - @property - def has_error(self) -> bool: - """Check if the execution has encountered an error.""" - return self.error is not None - - @property - def error_message(self) -> str | None: - """Get the error message if an error exists.""" - if not self.error: - return None - return str(self.error) - - def dumps(self) -> str: - """Serialize the aggregate state into a JSON string.""" - - node_states = [ - NodeExecutionState( - node_id=node_id, - state=node_execution.state, - retry_count=node_execution.retry_count, - execution_id=node_execution.execution_id, - error=node_execution.error, - ) - for node_id, node_execution in sorted(self.node_executions.items()) - ] - - state = GraphExecutionState( - workflow_id=self.workflow_id, - started=self.started, - completed=self.completed, - aborted=self.aborted, - paused=self.paused, - pause_reasons=self.pause_reasons, - error=_serialize_error(self.error), - exceptions_count=self.exceptions_count, - node_executions=node_states, - ) - - return state.model_dump_json() - - def loads(self, data: str) -> None: - """Restore aggregate state from a serialized JSON string.""" - - state = GraphExecutionState.model_validate_json(data) - - if state.type != "GraphExecution": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported serialized version: {state.version}") - 
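-        # A snapshot may only be restored into the aggregate it was taken
-        # from, so identity is verified before any mutable state is replaced.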
- if self.workflow_id != state.workflow_id: - raise ValueError("Serialized workflow_id does not match aggregate identity") - - self.started = state.started - self.completed = state.completed - self.aborted = state.aborted - self.paused = state.paused - self.pause_reasons = state.pause_reasons - self.error = _deserialize_error(state.error) - self.exceptions_count = state.exceptions_count - self.node_executions = { - item.node_id: NodeExecution( - node_id=item.node_id, - state=item.state, - retry_count=item.retry_count, - execution_id=item.execution_id, - error=item.error, - ) - for item in state.node_executions - } - - def record_node_failure(self) -> None: - """Increment the count of node failures encountered during execution.""" - self.exceptions_count += 1 - - -_: GraphExecutionProtocol = GraphExecution(workflow_id="") diff --git a/api/graphon/graph_engine/domain/node_execution.py b/api/graphon/graph_engine/domain/node_execution.py deleted file mode 100644 index dafd6ccd8a4..00000000000 --- a/api/graphon/graph_engine/domain/node_execution.py +++ /dev/null @@ -1,45 +0,0 @@ -""" -NodeExecution entity representing a node's execution state. -""" - -from dataclasses import dataclass - -from graphon.enums import NodeState - - -@dataclass -class NodeExecution: - """ - Entity representing the execution state of a single node. - - This is a mutable entity that tracks the runtime state of a node - during graph execution. - """ - - node_id: str - state: NodeState = NodeState.UNKNOWN - retry_count: int = 0 - execution_id: str | None = None - error: str | None = None - - def mark_started(self, execution_id: str) -> None: - """Mark the node as started with an execution ID.""" - self.state = NodeState.TAKEN - self.execution_id = execution_id - - def mark_taken(self) -> None: - """Mark the node as successfully completed.""" - self.state = NodeState.TAKEN - self.error = None - - def mark_failed(self, error: str) -> None: - """Mark the node as failed with an error.""" - self.error = error - - def mark_skipped(self) -> None: - """Mark the node as skipped.""" - self.state = NodeState.SKIPPED - - def increment_retry(self) -> None: - """Increment the retry count for this node.""" - self.retry_count += 1 diff --git a/api/graphon/graph_engine/entities/__init__.py b/api/graphon/graph_engine/entities/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/graphon/graph_engine/entities/commands.py b/api/graphon/graph_engine/entities/commands.py deleted file mode 100644 index 25ebc804b6d..00000000000 --- a/api/graphon/graph_engine/entities/commands.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -GraphEngine command entities for external control. - -This module defines command types that can be sent to a running GraphEngine -instance to control its execution flow. 
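-
-Commands are plain Pydantic models, so transport is straightforward
-(a sketch, not a prescribed wire format):
-
-    cmd = AbortCommand(reason="user clicked stop")
-    payload = cmd.model_dump()                    # dict, ready for JSON
-    restored = AbortCommand.model_validate(payload)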
-""" - -from collections.abc import Sequence -from enum import StrEnum, auto -from typing import Any - -from pydantic import BaseModel, Field - -from graphon.variables.variables import Variable - - -class CommandType(StrEnum): - """Types of commands that can be sent to GraphEngine.""" - - ABORT = auto() - PAUSE = auto() - UPDATE_VARIABLES = auto() - - -class GraphEngineCommand(BaseModel): - """Base class for all GraphEngine commands.""" - - command_type: CommandType = Field(..., description="Type of command") - payload: dict[str, Any] | None = Field(default=None, description="Optional command payload") - - -class AbortCommand(GraphEngineCommand): - """Command to abort a running workflow execution.""" - - command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command") - reason: str | None = Field(default=None, description="Optional reason for abort") - - -class PauseCommand(GraphEngineCommand): - """Command to pause a running workflow execution.""" - - command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command") - reason: str = Field(default="unknown reason", description="reason for pause") - - -class VariableUpdate(BaseModel): - """Represents a single variable update instruction.""" - - value: Variable = Field(description="New variable value") - - -class UpdateVariablesCommand(GraphEngineCommand): - """Command to update a group of variables in the variable pool.""" - - command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command") - updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates") diff --git a/api/graphon/graph_engine/error_handler.py b/api/graphon/graph_engine/error_handler.py deleted file mode 100644 index 43ce8bb502a..00000000000 --- a/api/graphon/graph_engine/error_handler.py +++ /dev/null @@ -1,213 +0,0 @@ -""" -Main error handler that coordinates error strategies. -""" - -import logging -import time -from typing import TYPE_CHECKING, final - -from graphon.enums import ( - ErrorStrategy as ErrorStrategyEnum, -) -from graphon.enums import ( - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.graph import Graph -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunRetryEvent, -) -from graphon.node_events import NodeRunResult - -if TYPE_CHECKING: - from .domain import GraphExecution - -logger = logging.getLogger(__name__) - - -@final -class ErrorHandler: - """ - Coordinates error handling strategies for node failures. - - This acts as a facade for the various error strategies, - selecting and applying the appropriate strategy based on - node configuration. - """ - - def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None: - """ - Initialize the error handler. - - Args: - graph: The workflow graph - graph_execution: The graph execution state - """ - self._graph = graph - self._graph_execution = graph_execution - - def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None: - """ - Handle a node failure event. - - Selects and applies the appropriate error strategy based on - the node's configuration. 
- - Args: - event: The node failure event - - Returns: - Optional new event to process, or None to abort - """ - node = self._graph.nodes[event.node_id] - # Get retry count from NodeExecution - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - retry_count = node_execution.retry_count - - # First check if retry is configured and not exhausted - if node.retry and retry_count < node.retry_config.max_retries: - result = self._handle_retry(event, retry_count) - if result: - # Retry count will be incremented when NodeRunRetryEvent is handled - return result - - # Apply configured error strategy - strategy = node.error_strategy - - match strategy: - case None: - return self._handle_abort(event) - case ErrorStrategyEnum.FAIL_BRANCH: - return self._handle_fail_branch(event) - case ErrorStrategyEnum.DEFAULT_VALUE: - return self._handle_default_value(event) - - def _handle_abort(self, event: NodeRunFailedEvent): - """ - Handle error by aborting execution. - - This is the default strategy when no other strategy is specified. - It stops the entire graph execution when a node fails. - - Args: - event: The failure event - - Returns: - None - signals abortion - """ - logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error) - # Return None to signal that execution should stop - - def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int): - """ - Handle error by retrying the node. - - This strategy re-attempts node execution up to a configured - maximum number of retries with configurable intervals. - - Args: - event: The failure event - retry_count: Current retry attempt count - - Returns: - NodeRunRetryEvent if retry should occur, None otherwise - """ - node = self._graph.nodes[event.node_id] - - # Check if we've exceeded max retries - if not node.retry or retry_count >= node.retry_config.max_retries: - return None - - # Wait for retry interval - time.sleep(node.retry_config.retry_interval_seconds) - - # Create retry event - return NodeRunRetryEvent( - id=event.id, - node_title=node.title, - node_id=event.node_id, - node_type=event.node_type, - node_run_result=event.node_run_result, - start_at=event.start_at, - error=event.error, - retry_index=retry_count + 1, - ) - - def _handle_fail_branch(self, event: NodeRunFailedEvent): - """ - Handle error by taking the fail branch. - - This strategy converts failures to exceptions and routes execution - through a designated fail-branch edge. - - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent to continue via fail branch - """ - outputs = { - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - finished_at=event.finished_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - edge_source_handle="fail-branch", - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH, - }, - ), - error=event.error, - ) - - def _handle_default_value(self, event: NodeRunFailedEvent): - """ - Handle error by using default values. - - This strategy allows nodes to fail gracefully by providing - predefined default output values. 
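-
-        Illustrative result, assuming `default_value_dict == {"text": ""}`:
-        the exception event then carries outputs
-        `{"text": "", "error_message": <error>, "error_type": <type>}`.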
- - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent with default values - """ - node = self._graph.nodes[event.node_id] - - outputs = { - **node.default_value_dict, - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - finished_at=event.finished_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE, - }, - ), - error=event.error, - ) diff --git a/api/graphon/graph_engine/event_management/__init__.py b/api/graphon/graph_engine/event_management/__init__.py deleted file mode 100644 index f6c3c0f753f..00000000000 --- a/api/graphon/graph_engine/event_management/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Event management subsystem for graph engine. - -This package handles event routing, collection, and emission for -workflow graph execution events. -""" - -from .event_handlers import EventHandler -from .event_manager import EventManager - -__all__ = [ - "EventHandler", - "EventManager", -] diff --git a/api/graphon/graph_engine/event_management/event_handlers.py b/api/graphon/graph_engine/event_management/event_handlers.py deleted file mode 100644 index 184148280db..00000000000 --- a/api/graphon/graph_engine/event_management/event_handlers.py +++ /dev/null @@ -1,367 +0,0 @@ -""" -Event handler implementations for different event types. -""" - -import logging -from collections.abc import Mapping -from functools import singledispatchmethod -from typing import TYPE_CHECKING, final - -from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState -from graphon.graph import Graph -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunAgentLogEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState - -from ..domain.graph_execution import GraphExecution -from ..response_coordinator import ResponseStreamCoordinator - -if TYPE_CHECKING: - from ..error_handler import ErrorHandler - from ..graph_state_manager import GraphStateManager - from ..graph_traversal import EdgeProcessor - from .event_manager import EventManager - -logger = logging.getLogger(__name__) - - -@final -class EventHandler: - """ - Registry of event handlers for different event types. - - This centralizes the business logic for handling specific events, - keeping it separate from the routing and collection infrastructure. 
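-
-    The engine's dispatcher is expected to forward every queued node event
-    here; a sketch of the call site:
-
-        event_handler.dispatch(event)  # routed via singledispatchmethod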
- """ - - def __init__( - self, - graph: Graph, - graph_runtime_state: GraphRuntimeState, - graph_execution: GraphExecution, - response_coordinator: ResponseStreamCoordinator, - event_collector: "EventManager", - edge_processor: "EdgeProcessor", - state_manager: "GraphStateManager", - error_handler: "ErrorHandler", - ) -> None: - """ - Initialize the event handler registry. - - Args: - graph: The workflow graph - graph_runtime_state: Runtime state with variable pool - graph_execution: Graph execution aggregate - response_coordinator: Response stream coordinator - event_collector: Event manager for collecting events - edge_processor: Edge processor for edge traversal - state_manager: Unified state manager - error_handler: Error handler - """ - self._graph = graph - self._graph_runtime_state = graph_runtime_state - self._graph_execution = graph_execution - self._response_coordinator = response_coordinator - self._event_collector = event_collector - self._edge_processor = edge_processor - self._state_manager = state_manager - self._error_handler = error_handler - - def dispatch(self, event: GraphNodeEventBase) -> None: - """ - Handle any node event by dispatching to the appropriate handler. - - Args: - event: The event to handle - """ - if isinstance(event, NodeRunVariableUpdatedEvent): - self._dispatch(event) - return - - # Events in loops or iterations are always collected - if event.in_loop_id or event.in_iteration_id: - self._event_collector.collect(event) - return - return self._dispatch(event) - - @singledispatchmethod - def _dispatch(self, event: GraphNodeEventBase) -> None: - self._event_collector.collect(event) - logger.warning("Unhandled event type: %s", type(event).__name__) - - @_dispatch.register(NodeRunIterationStartedEvent) - @_dispatch.register(NodeRunIterationNextEvent) - @_dispatch.register(NodeRunIterationSucceededEvent) - @_dispatch.register(NodeRunIterationFailedEvent) - @_dispatch.register(NodeRunLoopStartedEvent) - @_dispatch.register(NodeRunLoopNextEvent) - @_dispatch.register(NodeRunLoopSucceededEvent) - @_dispatch.register(NodeRunLoopFailedEvent) - @_dispatch.register(NodeRunAgentLogEvent) - @_dispatch.register(NodeRunRetrieverResourceEvent) - def _(self, event: GraphNodeEventBase) -> None: - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunStartedEvent) -> None: - """ - Handle node started event. - - Args: - event: The node started event - """ - # Track execution in domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - is_initial_attempt = node_execution.retry_count == 0 - node_execution.mark_started(event.id) - self._graph_runtime_state.increment_node_run_steps() - - # Track in response coordinator for stream ordering - self._response_coordinator.track_node_execution(event.node_id, event.id) - - # Collect the event only for the first attempt; retries remain silent - if is_initial_attempt: - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunStreamChunkEvent) -> None: - """ - Handle stream chunk event with full processing. 
- - Args: - event: The stream chunk event - """ - # Process with response coordinator - streaming_events = list(self._response_coordinator.intercept_event(event)) - - # Collect all events - for stream_event in streaming_events: - self._event_collector.collect(stream_event) - - @_dispatch.register - def _(self, event: NodeRunVariableUpdatedEvent) -> None: - """ - Apply a node-requested variable mutation before downstream observers run. - - The event is collected like other node events so parent/container engines can - forward the updated payload to outer layers, including persistence listeners. - """ - self._graph_runtime_state.variable_pool.add(event.variable.selector, event.variable) - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunSucceededEvent) -> None: - """ - Handle node success by coordinating subsystems. - - This method coordinates between different subsystems to process - node completion, handle edges, and trigger downstream execution. - - Args: - event: The node succeeded event - """ - # Update domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_taken() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - # Store outputs in variable pool - self._store_node_outputs(event.node_id, event.node_run_result.outputs) - - # Forward to response coordinator and emit streaming events - streaming_events = self._response_coordinator.intercept_event(event) - for stream_event in streaming_events: - self._event_collector.collect(stream_event) - - # Process edges and get ready nodes - node = self._graph.nodes[event.node_id] - if node.execution_type == NodeExecutionType.BRANCH: - ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion( - event.node_id, event.node_run_result.edge_source_handle - ) - else: - ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) - - # Collect streaming events from edge processing - for edge_event in edge_streaming_events: - self._event_collector.collect(edge_event) - - # Enqueue ready nodes - if self._graph_execution.is_paused: - for node_id in ready_nodes: - self._graph_runtime_state.register_deferred_node(node_id) - else: - for node_id in ready_nodes: - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Update execution tracking - self._state_manager.finish_execution(event.node_id) - - # Handle response node outputs - if node.execution_type == NodeExecutionType.RESPONSE: - self._update_response_outputs(event.node_run_result.outputs) - - # Collect the event - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunPauseRequestedEvent) -> None: - """Handle pause requests emitted by nodes.""" - - pause_reason = event.reason - self._graph_execution.pause(pause_reason) - self._state_manager.finish_execution(event.node_id) - if event.node_id in self._graph.nodes: - self._graph.nodes[event.node_id].state = NodeState.UNKNOWN - self._graph_runtime_state.register_paused_node(event.node_id) - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunFailedEvent) -> None: - """ - Handle node failure using error handler. 
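-
-        Any follow-up event produced by the error handler (a retry or an
-        exception event) is re-dispatched through this registry; with no
-        follow-up, the whole graph execution is marked failed.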
- - Args: - event: The node failed event - """ - # Update domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_failed(event.error) - self._graph_execution.record_node_failure() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - result = self._error_handler.handle_node_failure(event) - - if result: - # Process the resulting event (retry, exception, etc.) - self.dispatch(result) - else: - # Abort execution - self._graph_execution.fail(RuntimeError(event.error)) - self._event_collector.collect(event) - self._state_manager.finish_execution(event.node_id) - - @_dispatch.register - def _(self, event: NodeRunExceptionEvent) -> None: - """ - Handle node exception event (fail-branch strategy). - - Args: - event: The node exception event - """ - # Node continues via fail-branch/default-value, treat as completion - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_taken() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - # Persist outputs produced by the exception strategy (e.g. default values) - self._store_node_outputs(event.node_id, event.node_run_result.outputs) - - node = self._graph.nodes[event.node_id] - - if node.error_strategy == ErrorStrategy.DEFAULT_VALUE: - ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) - elif node.error_strategy == ErrorStrategy.FAIL_BRANCH: - ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion( - event.node_id, event.node_run_result.edge_source_handle - ) - else: - raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}") - - for edge_event in edge_streaming_events: - self._event_collector.collect(edge_event) - - for node_id in ready_nodes: - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Update response outputs if applicable - if node.execution_type == NodeExecutionType.RESPONSE: - self._update_response_outputs(event.node_run_result.outputs) - - self._state_manager.finish_execution(event.node_id) - - # Collect the exception event for observers - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunRetryEvent) -> None: - """ - Handle node retry event. - - Args: - event: The node retry event - """ - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.increment_retry() - - # Finish the previous attempt before re-queuing the node - self._state_manager.finish_execution(event.node_id) - - # Emit retry event for observers - self._event_collector.collect(event) - - # Re-queue node for execution - self._state_manager.enqueue_node(event.node_id) - self._state_manager.start_execution(event.node_id) - - def _accumulate_node_usage(self, usage: LLMUsage) -> None: - """Accumulate token usage into the shared runtime state.""" - if usage.total_tokens <= 0: - return - - self._graph_runtime_state.add_tokens(usage.total_tokens) - - current_usage = self._graph_runtime_state.llm_usage - if current_usage.total_tokens == 0: - self._graph_runtime_state.llm_usage = usage - else: - self._graph_runtime_state.llm_usage = current_usage.plus(usage) - - def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None: - """ - Store node outputs in the variable pool. 
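-
-        Each output value is stored under its `(node_id, variable_name)`
-        selector, e.g. outputs `{"answer": "hi"}` for node `llm` land at
-        `("llm", "answer")`.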
- - Args: - event: The node succeeded event containing outputs - """ - for variable_name, variable_value in outputs.items(): - self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value) - - def _update_response_outputs(self, outputs: Mapping[str, object]) -> None: - """Update response outputs for response nodes.""" - # TODO: Design a mechanism for nodes to notify the engine about how to update outputs - # in runtime state, rather than allowing nodes to directly access runtime state. - for key, value in outputs.items(): - if key == "answer": - existing = self._graph_runtime_state.get_output("answer", "") - if existing: - self._graph_runtime_state.set_output("answer", f"{existing}{value}") - else: - self._graph_runtime_state.set_output("answer", value) - else: - self._graph_runtime_state.set_output(key, value) diff --git a/api/graphon/graph_engine/event_management/event_manager.py b/api/graphon/graph_engine/event_management/event_manager.py deleted file mode 100644 index 5b2fb365e9b..00000000000 --- a/api/graphon/graph_engine/event_management/event_manager.py +++ /dev/null @@ -1,186 +0,0 @@ -""" -Unified event manager for collecting and emitting events. -""" - -import logging -import threading -import time -from collections.abc import Generator -from contextlib import contextmanager -from typing import final - -from graphon.graph_events import GraphEngineEvent - -from ..layers.base import GraphEngineLayer - -_logger = logging.getLogger(__name__) - - -@final -class ReadWriteLock: - """ - A read-write lock implementation that allows multiple concurrent readers - but only one writer at a time. - """ - - def __init__(self) -> None: - self._read_ready = threading.Condition(threading.RLock()) - self._readers = 0 - - def acquire_read(self) -> None: - """Acquire a read lock.""" - _ = self._read_ready.acquire() - try: - self._readers += 1 - finally: - self._read_ready.release() - - def release_read(self) -> None: - """Release a read lock.""" - _ = self._read_ready.acquire() - try: - self._readers -= 1 - if self._readers == 0: - self._read_ready.notify_all() - finally: - self._read_ready.release() - - def acquire_write(self) -> None: - """Acquire a write lock.""" - _ = self._read_ready.acquire() - while self._readers > 0: - _ = self._read_ready.wait() - - def release_write(self) -> None: - """Release a write lock.""" - self._read_ready.release() - - @contextmanager - def read_lock(self): - """Return a context manager for read locking.""" - self.acquire_read() - try: - yield - finally: - self.release_read() - - @contextmanager - def write_lock(self): - """Return a context manager for write locking.""" - self.acquire_write() - try: - yield - finally: - self.release_write() - - -@final -class EventManager: - """ - Unified event manager that collects, buffers, and emits events. - - This class combines event collection with event emission, providing - thread-safe event management with support for notifying layers and - streaming events to external consumers. - """ - - def __init__(self) -> None: - """Initialize the event manager.""" - self._events: list[GraphEngineEvent] = [] - self._lock = ReadWriteLock() - self._layers: list[GraphEngineLayer] = [] - self._execution_complete = threading.Event() - - def set_layers(self, layers: list[GraphEngineLayer]) -> None: - """ - Set the layers to notify on event collection. 
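-
-        Registered layers receive every collected event via `on_event`; a
-        sketch with hypothetical layer classes:
-
-            manager.set_layers([MetricsLayer(), TracingLayer()])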
- - Args: - layers: List of layers to notify - """ - self._layers = layers - - def notify_layers(self, event: GraphEngineEvent) -> None: - """Notify registered layers about an event without buffering it.""" - self._notify_layers(event) - - def collect(self, event: GraphEngineEvent) -> None: - """ - Thread-safe method to collect an event. - - Args: - event: The event to collect - """ - with self._lock.write_lock(): - self._events.append(event) - - # NOTE: `_notify_layers` is intentionally called outside the critical section - # to minimize lock contention and avoid blocking other readers or writers. - # - # The public `notify_layers` method also does not use a write lock, - # so protecting `_notify_layers` with a lock here is unnecessary. - self._notify_layers(event) - - def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]: - """ - Get new events starting from a specific index. - - Args: - start_index: The index to start from - - Returns: - List of new events - """ - with self._lock.read_lock(): - return list(self._events[start_index:]) - - def _event_count(self) -> int: - """ - Get the current count of collected events. - - Returns: - Number of collected events - """ - with self._lock.read_lock(): - return len(self._events) - - def mark_complete(self) -> None: - """Mark execution as complete to stop the event emission generator.""" - self._execution_complete.set() - - def emit_events(self) -> Generator[GraphEngineEvent, None, None]: - """ - Generator that yields events as they're collected. - - Yields: - GraphEngineEvent instances as they're processed - """ - yielded_count = 0 - - while not self._execution_complete.is_set() or yielded_count < self._event_count(): - # Get new events since last yield - new_events = self._get_new_events(yielded_count) - - # Yield any new events - for event in new_events: - yield event - yielded_count += 1 - - # Small sleep to avoid busy waiting - if not self._execution_complete.is_set() and not new_events: - time.sleep(0.001) - - def _notify_layers(self, event: GraphEngineEvent) -> None: - """ - Notify all layers of an event. - - Layer exceptions are caught and logged to prevent disrupting collection. - - Args: - event: The event to send to layers - """ - for layer in self._layers: - try: - layer.on_event(event) - except Exception: - _logger.exception("Error in layer on_event, layer_type=%s", type(layer)) diff --git a/api/graphon/graph_engine/graph_engine.py b/api/graphon/graph_engine/graph_engine.py deleted file mode 100644 index 32e0e60502f..00000000000 --- a/api/graphon/graph_engine/graph_engine.py +++ /dev/null @@ -1,377 +0,0 @@ -""" -QueueBasedGraphEngine - Main orchestrator for queue-based workflow execution. - -This engine uses a modular architecture with separated packages following -Domain-Driven Design principles for improved maintainability and testability. 
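-
-A typical driving loop (a sketch; the collaborators are assumed to be built
-elsewhere, and `consume` is a hypothetical event sink):
-
-    engine = GraphEngine(workflow_id, graph, graph_runtime_state, command_channel)
-    for event in engine.run():
-        consume(event)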
-""" - -from __future__ import annotations - -import logging -import queue -from collections.abc import Generator -from typing import TYPE_CHECKING, cast, final - -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.enums import NodeExecutionType -from graphon.graph import Graph -from graphon.graph_events import ( - GraphEngineEvent, - GraphNodeEventBase, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) -from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from graphon.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol - -if TYPE_CHECKING: # pragma: no cover - used only for static analysis - from graphon.runtime.graph_runtime_state import GraphProtocol - -from .command_processing import ( - AbortCommandHandler, - CommandProcessor, - PauseCommandHandler, - UpdateVariablesCommandHandler, -) -from .config import GraphEngineConfig -from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand -from .error_handler import ErrorHandler -from .event_management import EventHandler, EventManager -from .graph_state_manager import GraphStateManager -from .graph_traversal import EdgeProcessor, SkipPropagator -from .layers.base import GraphEngineLayer -from .orchestration import Dispatcher, ExecutionCoordinator -from .protocols.command_channel import CommandChannel -from .worker_management import WorkerPool - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.graph_engine.domain.graph_execution import GraphExecution - from graphon.graph_engine.response_coordinator import ResponseStreamCoordinator - -logger = logging.getLogger(__name__) - - -_DEFAULT_CONFIG = GraphEngineConfig() - - -@final -class GraphEngine: - """ - Queue-based graph execution engine. - - Uses a modular architecture that delegates responsibilities to specialized - subsystems, following Domain-Driven Design and SOLID principles. 
- """ - - def __init__( - self, - workflow_id: str, - graph: Graph, - graph_runtime_state: GraphRuntimeState, - command_channel: CommandChannel, - config: GraphEngineConfig = _DEFAULT_CONFIG, - child_engine_builder: ChildGraphEngineBuilderProtocol | None = None, - ) -> None: - """Initialize the graph engine with all subsystems and dependencies.""" - - # Bind runtime state to current workflow context - self._graph = graph - self._graph_runtime_state = graph_runtime_state - self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) - self._command_channel = command_channel - self._config = config - self._layers: list[GraphEngineLayer] = [] - self._child_engine_builder = child_engine_builder - if child_engine_builder is not None: - self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) - - # Graph execution tracks the overall execution state - self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) - self._graph_execution.workflow_id = workflow_id - - # === Execution Queues === - self._ready_queue = self._graph_runtime_state.ready_queue - - # Queue for events generated during execution - self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() - - # === State Management === - # Unified state manager handles all node state transitions and queue operations - self._state_manager = GraphStateManager(self._graph, self._ready_queue) - - # === Response Coordination === - # Coordinates response streaming from response nodes - self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator) - - # === Event Management === - # Event manager handles both collection and emission of events - self._event_manager = EventManager() - - # === Error Handling === - # Centralized error handler for graph execution errors - self._error_handler = ErrorHandler(self._graph, self._graph_execution) - - # === Graph Traversal Components === - # Propagates skip status through the graph when conditions aren't met - self._skip_propagator = SkipPropagator( - graph=self._graph, - state_manager=self._state_manager, - ) - - # Processes edges to determine next nodes after execution - # Also handles conditional branching and route selection - self._edge_processor = EdgeProcessor( - graph=self._graph, - state_manager=self._state_manager, - response_coordinator=self._response_coordinator, - skip_propagator=self._skip_propagator, - ) - - # === Command Processing === - # Processes external commands (e.g., abort requests) - self._command_processor = CommandProcessor( - command_channel=self._command_channel, - graph_execution=self._graph_execution, - ) - - # Register command handlers - abort_handler = AbortCommandHandler() - self._command_processor.register_handler(AbortCommand, abort_handler) - - pause_handler = PauseCommandHandler() - self._command_processor.register_handler(PauseCommand, pause_handler) - - update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool) - self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler) - - # === Worker Pool Setup === - # Create worker pool for parallel node execution - self._worker_pool = WorkerPool( - ready_queue=self._ready_queue, - event_queue=self._event_queue, - graph=self._graph, - layers=self._layers, - execution_context=self._graph_runtime_state.execution_context, - config=self._config, - ) - - # === Orchestration === - # Coordinates the overall execution lifecycle - self._execution_coordinator = 
ExecutionCoordinator( - graph_execution=self._graph_execution, - state_manager=self._state_manager, - command_processor=self._command_processor, - worker_pool=self._worker_pool, - ) - - # === Event Handler Registry === - # Central registry for handling all node execution events - self._event_handler_registry = EventHandler( - graph=self._graph, - graph_runtime_state=self._graph_runtime_state, - graph_execution=self._graph_execution, - response_coordinator=self._response_coordinator, - event_collector=self._event_manager, - edge_processor=self._edge_processor, - state_manager=self._state_manager, - error_handler=self._error_handler, - ) - - # Dispatches events and manages execution flow - self._dispatcher = Dispatcher( - event_queue=self._event_queue, - event_handler=self._event_handler_registry, - execution_coordinator=self._execution_coordinator, - event_emitter=self._event_manager, - ) - - # === Validation === - # Ensure all nodes share the same GraphRuntimeState instance - self._validate_graph_state_consistency() - - def _validate_graph_state_consistency(self) -> None: - """Validate that all nodes share the same GraphRuntimeState.""" - expected_state_id = id(self._graph_runtime_state) - for node in self._graph.nodes.values(): - if id(node.graph_runtime_state) != expected_state_id: - raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") - - def _bind_layer_context( - self, - layer: GraphEngineLayer, - ) -> None: - layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel) - - def layer(self, layer: GraphEngineLayer) -> GraphEngine: - """Add a layer for extending functionality.""" - self._layers.append(layer) - self._bind_layer_context(layer) - return self - - def request_abort(self, reason: str | None = None) -> None: - """Queue an abort command for this engine.""" - self._command_channel.send_command(AbortCommand(reason=reason or "User requested abort")) - - def create_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - root_node_id: str, - variable_pool: VariablePool | None = None, - ) -> GraphEngine: - return self._graph_runtime_state.create_child_engine( - workflow_id=workflow_id, - graph_init_params=graph_init_params, - root_node_id=root_node_id, - variable_pool=variable_pool, - ) - - def run(self) -> Generator[GraphEngineEvent, None, None]: - """ - Execute the graph using the modular architecture. - - Returns: - Generator yielding GraphEngineEvent instances - """ - try: - # Initialize layers - self._initialize_layers() - - is_resume = self._graph_execution.started - if not is_resume: - self._graph_execution.start() - else: - self._graph_execution.paused = False - self._graph_execution.pause_reasons = [] - - start_event = GraphRunStartedEvent( - reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL, - ) - self._event_manager.notify_layers(start_event) - yield start_event - - # Start subsystems - self._start_execution(resume=is_resume) - - # Yield events as they occur - yield from self._event_manager.emit_events() - - # Handle completion - if self._graph_execution.is_paused: - pause_reasons = self._graph_execution.pause_reasons - assert pause_reasons, "pause_reasons should not be empty when execution is paused." 
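For orientation while reading the completion branches that follow: callers consume `run()` as a generator and react to the terminal event. A minimal draining sketch — assuming only the event classes this module already imports and an engine constructed as above; error handling is deliberately simplified:

```python
from graphon.graph_events import (
    GraphRunAbortedEvent,
    GraphRunFailedEvent,
    GraphRunPausedEvent,
    GraphRunSucceededEvent,
)


def drain(engine) -> dict:
    """Run the engine to its terminal event and return any outputs."""
    outputs: dict = {}
    for event in engine.run():
        if isinstance(event, GraphRunSucceededEvent | GraphRunPausedEvent | GraphRunAbortedEvent):
            outputs = dict(event.outputs)
        elif isinstance(event, GraphRunFailedEvent):
            # run() re-raises after yielding this event; surface it early here.
            raise RuntimeError(event.error)
    return outputs
```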
- # Ensure we have a valid PauseReason for the event - paused_event = GraphRunPausedEvent( - reasons=pause_reasons, - outputs=self._graph_runtime_state.outputs, - ) - self._event_manager.notify_layers(paused_event) - yield paused_event - elif self._graph_execution.aborted: - abort_reason = "Workflow execution aborted by user command" - if self._graph_execution.error: - abort_reason = str(self._graph_execution.error) - aborted_event = GraphRunAbortedEvent( - reason=abort_reason, - outputs=self._graph_runtime_state.outputs, - ) - self._event_manager.notify_layers(aborted_event) - yield aborted_event - elif self._graph_execution.has_error: - if self._graph_execution.error: - raise self._graph_execution.error - else: - outputs = self._graph_runtime_state.outputs - exceptions_count = self._graph_execution.exceptions_count - if exceptions_count > 0: - partial_event = GraphRunPartialSucceededEvent( - exceptions_count=exceptions_count, - outputs=outputs, - ) - self._event_manager.notify_layers(partial_event) - yield partial_event - else: - succeeded_event = GraphRunSucceededEvent( - outputs=outputs, - ) - self._event_manager.notify_layers(succeeded_event) - yield succeeded_event - - except Exception as e: - failed_event = GraphRunFailedEvent( - error=str(e), - exceptions_count=self._graph_execution.exceptions_count, - ) - self._event_manager.notify_layers(failed_event) - yield failed_event - raise - - finally: - self._stop_execution() - - def _initialize_layers(self) -> None: - """Initialize layers with context.""" - self._event_manager.set_layers(self._layers) - for layer in self._layers: - try: - layer.on_graph_start() - except Exception: - logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__) - - def _start_execution(self, *, resume: bool = False) -> None: - """Start execution subsystems.""" - paused_nodes: list[str] = [] - deferred_nodes: list[str] = [] - if resume: - paused_nodes = self._graph_runtime_state.consume_paused_nodes() - deferred_nodes = self._graph_runtime_state.consume_deferred_nodes() - - # Start worker pool (it calculates initial workers internally) - self._worker_pool.start() - - # Register response nodes - for node in self._graph.nodes.values(): - if node.execution_type == NodeExecutionType.RESPONSE: - self._response_coordinator.register(node.id) - - if not resume: - # Enqueue root node - root_node = self._graph.root_node - self._state_manager.enqueue_node(root_node.id) - self._state_manager.start_execution(root_node.id) - else: - seen_nodes: set[str] = set() - for node_id in paused_nodes + deferred_nodes: - if node_id in seen_nodes: - continue - seen_nodes.add(node_id) - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Start dispatcher - self._dispatcher.start() - - def _stop_execution(self) -> None: - """Stop execution subsystems.""" - self._dispatcher.stop() - self._worker_pool.stop() - # Don't mark complete here as the dispatcher already does it - - # Notify layers - for layer in self._layers: - try: - layer.on_graph_end(self._graph_execution.error) - except Exception: - logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__) - - # Public property accessors for attributes that need external access - @property - def graph_runtime_state(self) -> GraphRuntimeState: - """Get the graph runtime state.""" - return self._graph_runtime_state diff --git a/api/graphon/graph_engine/graph_state_manager.py b/api/graphon/graph_engine/graph_state_manager.py deleted file mode 100644 index 
ade8e403a87..00000000000 --- a/api/graphon/graph_engine/graph_state_manager.py +++ /dev/null @@ -1,290 +0,0 @@ -""" -Graph state manager that combines node, edge, and execution tracking. -""" - -import threading -from collections.abc import Sequence -from typing import TypedDict, final - -from graphon.enums import NodeState -from graphon.graph import Edge, Graph - -from .ready_queue import ReadyQueue - - -class EdgeStateAnalysis(TypedDict): - """Analysis result for edge states.""" - - has_unknown: bool - has_taken: bool - all_skipped: bool - - -@final -class GraphStateManager: - def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None: - """ - Initialize the state manager. - - Args: - graph: The workflow graph - ready_queue: Queue for nodes ready to execute - """ - self._graph = graph - self._ready_queue = ready_queue - self._lock = threading.RLock() - - # Execution tracking state - self._executing_nodes: set[str] = set() - - # ============= Node State Operations ============= - - def enqueue_node(self, node_id: str) -> None: - """ - Mark a node as TAKEN and add it to the ready queue. - - This combines the state transition and enqueueing operations - that always occur together when preparing a node for execution. - - Args: - node_id: The ID of the node to enqueue - """ - with self._lock: - self._graph.nodes[node_id].state = NodeState.TAKEN - self._ready_queue.put(node_id) - - def mark_node_skipped(self, node_id: str) -> None: - """ - Mark a node as SKIPPED. - - Args: - node_id: The ID of the node to skip - """ - with self._lock: - self._graph.nodes[node_id].state = NodeState.SKIPPED - - def is_node_ready(self, node_id: str) -> bool: - """ - Check if a node is ready to be executed. - - A node is ready when all its incoming edges from taken branches - have been satisfied. - - Args: - node_id: The ID of the node to check - - Returns: - True if the node is ready for execution - """ - with self._lock: - # Get all incoming edges to this node - incoming_edges = self._graph.get_incoming_edges(node_id) - - # If no incoming edges, node is always ready - if not incoming_edges: - return True - - # If any edge is UNKNOWN, node is not ready - if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges): - return False - - # Node is ready if at least one edge is TAKEN - return any(edge.state == NodeState.TAKEN for edge in incoming_edges) - - def get_node_state(self, node_id: str) -> NodeState: - """ - Get the current state of a node. - - Args: - node_id: The ID of the node - - Returns: - The current node state - """ - with self._lock: - return self._graph.nodes[node_id].state - - # ============= Edge State Operations ============= - - def mark_edge_taken(self, edge_id: str) -> None: - """ - Mark an edge as TAKEN. - - Args: - edge_id: The ID of the edge to mark - """ - with self._lock: - self._graph.edges[edge_id].state = NodeState.TAKEN - - def mark_edge_skipped(self, edge_id: str) -> None: - """ - Mark an edge as SKIPPED. - - Args: - edge_id: The ID of the edge to mark - """ - with self._lock: - self._graph.edges[edge_id].state = NodeState.SKIPPED - - def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis: - """ - Analyze the states of edges and return summary flags. 
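The readiness rule that `is_node_ready` implements above is worth seeing in isolation. A standalone sketch over bare `NodeState` values (illustration only; the real method reads edge objects under a lock):

```python
from graphon.enums import NodeState


def ready(incoming: list[NodeState]) -> bool:
    # No incoming edges: always ready. Any UNKNOWN edge: not yet decidable.
    # Otherwise: ready iff at least one incoming edge was TAKEN.
    if not incoming:
        return True
    if NodeState.UNKNOWN in incoming:
        return False
    return NodeState.TAKEN in incoming


assert ready([]) is True
assert ready([NodeState.SKIPPED, NodeState.TAKEN]) is True
assert ready([NodeState.UNKNOWN, NodeState.TAKEN]) is False
assert ready([NodeState.SKIPPED]) is False
```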
- - Args: - edges: List of edges to analyze - - Returns: - Analysis result with state flags - """ - with self._lock: - states = {edge.state for edge in edges} - - return EdgeStateAnalysis( - has_unknown=NodeState.UNKNOWN in states, - has_taken=NodeState.TAKEN in states, - all_skipped=states == {NodeState.SKIPPED} if states else True, - ) - - def get_edge_state(self, edge_id: str) -> NodeState: - """ - Get the current state of an edge. - - Args: - edge_id: The ID of the edge - - Returns: - The current edge state - """ - with self._lock: - return self._graph.edges[edge_id].state - - def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]: - """ - Categorize branch edges into selected and unselected. - - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected edge - - Returns: - A tuple of (selected_edges, unselected_edges) - """ - with self._lock: - outgoing_edges = self._graph.get_outgoing_edges(node_id) - selected_edges: list[Edge] = [] - unselected_edges: list[Edge] = [] - - for edge in outgoing_edges: - if edge.source_handle == selected_handle: - selected_edges.append(edge) - else: - unselected_edges.append(edge) - - return selected_edges, unselected_edges - - # ============= Execution Tracking Operations ============= - - def start_execution(self, node_id: str) -> None: - """ - Mark a node as executing. - - Args: - node_id: The ID of the node starting execution - """ - with self._lock: - self._executing_nodes.add(node_id) - - def finish_execution(self, node_id: str) -> None: - """ - Mark a node as no longer executing. - - Args: - node_id: The ID of the node finishing execution - """ - with self._lock: - self._executing_nodes.discard(node_id) - - def is_executing(self, node_id: str) -> bool: - """ - Check if a node is currently executing. - - Args: - node_id: The ID of the node to check - - Returns: - True if the node is executing - """ - with self._lock: - return node_id in self._executing_nodes - - def get_executing_count(self) -> int: - """ - Get the count of currently executing nodes. - - Returns: - Number of executing nodes - """ - # This count is a best-effort snapshot and can change concurrently. - # Only use it for pause-drain checks where scheduling is already frozen. - with self._lock: - return len(self._executing_nodes) - - def get_executing_nodes(self) -> set[str]: - """ - Get a copy of the set of executing node IDs. - - Returns: - Set of node IDs currently executing - """ - with self._lock: - return self._executing_nodes.copy() - - def clear_executing(self) -> None: - """Clear all executing nodes.""" - with self._lock: - self._executing_nodes.clear() - - # ============= Composite Operations ============= - - def is_execution_complete(self) -> bool: - """ - Check if graph execution is complete. - - Execution is complete when: - - Ready queue is empty - - No nodes are executing - - Returns: - True if execution is complete - """ - with self._lock: - return self._ready_queue.empty() and len(self._executing_nodes) == 0 - - def get_queue_depth(self) -> int: - """ - Get the current depth of the ready queue. - - Returns: - Number of nodes in the ready queue - """ - return self._ready_queue.qsize() - - def get_execution_stats(self) -> dict[str, int]: - """ - Get execution statistics. 
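The three flags returned by `analyze_edge_states` form a small truth table; restated standalone over a plain set of states, with the vacuous-truth case called out:

```python
from graphon.enums import NodeState


def analyze(states: set[NodeState]) -> dict[str, bool]:
    return {
        "has_unknown": NodeState.UNKNOWN in states,
        "has_taken": NodeState.TAKEN in states,
        # An empty edge set counts as "all skipped" (vacuous truth).
        "all_skipped": states == {NodeState.SKIPPED} if states else True,
    }


assert analyze(set()) == {"has_unknown": False, "has_taken": False, "all_skipped": True}
assert analyze({NodeState.SKIPPED})["all_skipped"] is True
assert analyze({NodeState.SKIPPED, NodeState.TAKEN})["all_skipped"] is False
```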
- - Returns: - Dictionary with execution statistics - """ - with self._lock: - taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN) - skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED) - unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN) - - return { - "queue_depth": self._ready_queue.qsize(), - "executing": len(self._executing_nodes), - "taken_nodes": taken_nodes, - "skipped_nodes": skipped_nodes, - "unknown_nodes": unknown_nodes, - } diff --git a/api/graphon/graph_engine/graph_traversal/__init__.py b/api/graphon/graph_engine/graph_traversal/__init__.py deleted file mode 100644 index d629140d066..00000000000 --- a/api/graphon/graph_engine/graph_traversal/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Graph traversal subsystem for graph engine. - -This package handles graph navigation, edge processing, -and skip propagation logic. -""" - -from .edge_processor import EdgeProcessor -from .skip_propagator import SkipPropagator - -__all__ = [ - "EdgeProcessor", - "SkipPropagator", -] diff --git a/api/graphon/graph_engine/graph_traversal/edge_processor.py b/api/graphon/graph_engine/graph_traversal/edge_processor.py deleted file mode 100644 index e51eee8a69d..00000000000 --- a/api/graphon/graph_engine/graph_traversal/edge_processor.py +++ /dev/null @@ -1,201 +0,0 @@ -""" -Edge processing logic for graph traversal. -""" - -from collections.abc import Sequence -from typing import TYPE_CHECKING, final - -from graphon.enums import NodeExecutionType -from graphon.graph import Edge, Graph -from graphon.graph_events import NodeRunStreamChunkEvent - -from ..graph_state_manager import GraphStateManager -from ..response_coordinator import ResponseStreamCoordinator - -if TYPE_CHECKING: - from .skip_propagator import SkipPropagator - - -@final -class EdgeProcessor: - """ - Processes edges during graph execution. - - This handles marking edges as taken or skipped, notifying - the response coordinator, triggering downstream node execution, - and managing branch node logic. - """ - - def __init__( - self, - graph: Graph, - state_manager: GraphStateManager, - response_coordinator: ResponseStreamCoordinator, - skip_propagator: "SkipPropagator", - ) -> None: - """ - Initialize the edge processor. - - Args: - graph: The workflow graph - state_manager: Unified state manager - response_coordinator: Response stream coordinator - skip_propagator: Propagator for skip states - """ - self._graph = graph - self._state_manager = state_manager - self._response_coordinator = response_coordinator - self._skip_propagator = skip_propagator - - def process_node_success( - self, node_id: str, selected_handle: str | None = None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges after a node succeeds. - - Args: - node_id: The ID of the succeeded node - selected_handle: For branch nodes, the selected edge handle - - Returns: - Tuple of (list of downstream node IDs that are now ready, list of streaming events) - """ - node = self._graph.nodes[node_id] - - if node.execution_type == NodeExecutionType.BRANCH: - return self._process_branch_node_edges(node_id, selected_handle) - else: - return self._process_non_branch_node_edges(node_id) - - def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges for non-branch nodes (mark all as TAKEN). 
- - Args: - node_id: The ID of the succeeded node - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - """ - ready_nodes: list[str] = [] - all_streaming_events: list[NodeRunStreamChunkEvent] = [] - outgoing_edges = self._graph.get_outgoing_edges(node_id) - - for edge in outgoing_edges: - nodes, events = self._process_taken_edge(edge) - ready_nodes.extend(nodes) - all_streaming_events.extend(events) - - return ready_nodes, all_streaming_events - - def _process_branch_node_edges( - self, node_id: str, selected_handle: str | None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges for branch nodes. - - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected edge - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - - Raises: - ValueError: If no edge was selected - """ - if not selected_handle: - raise ValueError(f"Branch node {node_id} did not select any edge") - - ready_nodes: list[str] = [] - all_streaming_events: list[NodeRunStreamChunkEvent] = [] - - # Categorize edges - selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle) - - # Process unselected edges first (mark as skipped) - for edge in unselected_edges: - self._process_skipped_edge(edge) - - # Process selected edges - for edge in selected_edges: - nodes, events = self._process_taken_edge(edge) - ready_nodes.extend(nodes) - all_streaming_events.extend(events) - - return ready_nodes, all_streaming_events - - def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Mark edge as taken and check downstream node. - - Args: - edge: The edge to process - - Returns: - Tuple of (list containing downstream node ID if it's ready, list of streaming events) - """ - # Mark edge as taken - self._state_manager.mark_edge_taken(edge.id) - - # Notify response coordinator and get streaming events - streaming_events = self._response_coordinator.on_edge_taken(edge.id) - - # Check if downstream node is ready - ready_nodes: list[str] = [] - if self._state_manager.is_node_ready(edge.head): - ready_nodes.append(edge.head) - - return ready_nodes, streaming_events - - def _process_skipped_edge(self, edge: Edge) -> None: - """ - Mark edge as skipped. - - Args: - edge: The edge to skip - """ - self._state_manager.mark_edge_skipped(edge.id) - - def handle_branch_completion( - self, node_id: str, selected_handle: str | None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Handle completion of a branch node. - - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected branch - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - - Raises: - ValueError: If no branch was selected - """ - if not selected_handle: - raise ValueError(f"Branch node {node_id} completed without selecting a branch") - - # Categorize edges into selected and unselected - _, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle) - - # Skip all unselected paths - self._skip_propagator.skip_branch_paths(unselected_edges) - - # Process selected edges and get ready nodes and streaming events - return self.process_node_success(node_id, selected_handle) - - def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool: - """ - Validate that a branch selection is valid. 
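The branch bookkeeping above reduces to a pure partition by `source_handle`. A standalone sketch with a stand-in edge type (`FakeEdge` is ours, not the real `graphon.graph.Edge`):

```python
from dataclasses import dataclass


@dataclass
class FakeEdge:
    id: str
    source_handle: str


def partition(edges: list[FakeEdge], selected_handle: str) -> tuple[list[FakeEdge], list[FakeEdge]]:
    # Mirrors categorize_branch_edges: matching edges are taken,
    # every other outgoing edge is skipped (and skip then propagates).
    selected = [e for e in edges if e.source_handle == selected_handle]
    unselected = [e for e in edges if e.source_handle != selected_handle]
    return selected, unselected


taken, skipped = partition([FakeEdge("e1", "true"), FakeEdge("e2", "false")], "true")
assert [e.id for e in taken] == ["e1"]
assert [e.id for e in skipped] == ["e2"]
```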
- - Args: - node_id: The ID of the branch node - selected_handle: The handle to validate - - Returns: - True if the selection is valid - """ - outgoing_edges = self._graph.get_outgoing_edges(node_id) - valid_handles = {edge.source_handle for edge in outgoing_edges} - return selected_handle in valid_handles diff --git a/api/graphon/graph_engine/graph_traversal/skip_propagator.py b/api/graphon/graph_engine/graph_traversal/skip_propagator.py deleted file mode 100644 index bdb83b38ad4..00000000000 --- a/api/graphon/graph_engine/graph_traversal/skip_propagator.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -Skip state propagation through the graph. -""" - -from collections.abc import Sequence -from typing import final - -from graphon.graph import Edge, Graph - -from ..graph_state_manager import GraphStateManager - - -@final -class SkipPropagator: - """ - Propagates skip states through the graph. - - When a node is skipped, this ensures all downstream nodes - that depend solely on it are also skipped. - """ - - def __init__( - self, - graph: Graph, - state_manager: GraphStateManager, - ) -> None: - """ - Initialize the skip propagator. - - Args: - graph: The workflow graph - state_manager: Unified state manager - """ - self._graph = graph - self._state_manager = state_manager - - def propagate_skip_from_edge(self, edge_id: str) -> None: - """ - Recursively propagate skip state from a skipped edge. - - Rules: - - If a node has any UNKNOWN incoming edges, stop processing - - If all incoming edges are SKIPPED, skip the node and its edges - - If any incoming edge is TAKEN, the node may still execute - - Args: - edge_id: The ID of the skipped edge to start from - """ - downstream_node_id = self._graph.edges[edge_id].head - incoming_edges = self._graph.get_incoming_edges(downstream_node_id) - - # Analyze edge states - edge_states = self._state_manager.analyze_edge_states(incoming_edges) - - # Stop if there are unknown edges (not yet processed) - if edge_states["has_unknown"]: - return - - # If any edge is taken, node may still execute - if edge_states["has_taken"]: - # Enqueue node - self._state_manager.enqueue_node(downstream_node_id) - self._state_manager.start_execution(downstream_node_id) - return - - # All edges are skipped, propagate skip to this node - if edge_states["all_skipped"]: - self._propagate_skip_to_node(downstream_node_id) - - def _propagate_skip_to_node(self, node_id: str) -> None: - """ - Mark a node and all its outgoing edges as skipped. - - Args: - node_id: The ID of the node to skip - """ - # Mark node as skipped - self._state_manager.mark_node_skipped(node_id) - - # Mark all outgoing edges as skipped and propagate - outgoing_edges = self._graph.get_outgoing_edges(node_id) - for edge in outgoing_edges: - self._state_manager.mark_edge_skipped(edge.id) - # Recursively propagate skip - self.propagate_skip_from_edge(edge.id) - - def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None: - """ - Skip all paths from unselected branch edges. - - Args: - unselected_edges: List of edges not taken by the branch - """ - for edge in unselected_edges: - self._state_manager.mark_edge_skipped(edge.id) - self.propagate_skip_from_edge(edge.id) diff --git a/api/graphon/graph_engine/layers/README.md b/api/graphon/graph_engine/layers/README.md deleted file mode 100644 index b0f295037c0..00000000000 --- a/api/graphon/graph_engine/layers/README.md +++ /dev/null @@ -1,55 +0,0 @@ -# Layers - -Pluggable middleware for engine extensions. 
- -## Components - -### Layer (base) - -Abstract base class for layers. - -- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks) -- `on_graph_start()` - Execution start hook -- `on_event()` - Process all events -- `on_graph_end()` - Execution end hook - -### DebugLoggingLayer - -Comprehensive execution logging. - -- Configurable detail levels -- Tracks execution statistics -- Truncates long values - -## Usage - -```python -debug_layer = DebugLoggingLayer( - level="INFO", - include_outputs=True -) - -engine = GraphEngine(graph) -engine.layer(debug_layer) -engine.run() -``` - -`engine.layer()` binds the read-only runtime state before execution, so -`graph_runtime_state` is always available inside layer hooks. - -## Custom Layers - -```python -class MetricsLayer(Layer): - def on_event(self, event): - if isinstance(event, NodeRunSucceededEvent): - self.metrics[event.node_id] = event.elapsed_time -``` - -## Configuration - -**DebugLoggingLayer Options:** - -- `level` - Log level (INFO, DEBUG, ERROR) -- `include_inputs/outputs` - Log data values -- `max_value_length` - Truncate long values diff --git a/api/graphon/graph_engine/layers/__init__.py b/api/graphon/graph_engine/layers/__init__.py deleted file mode 100644 index 0a29a529936..00000000000 --- a/api/graphon/graph_engine/layers/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Layer system for GraphEngine extensibility. - -This module provides the layer infrastructure for extending GraphEngine functionality -with middleware-like components that can observe events and interact with execution. -""" - -from .base import GraphEngineLayer -from .debug_logging import DebugLoggingLayer -from .execution_limits import ExecutionLimitsLayer - -__all__ = [ - "DebugLoggingLayer", - "ExecutionLimitsLayer", - "GraphEngineLayer", -] diff --git a/api/graphon/graph_engine/layers/base.py b/api/graphon/graph_engine/layers/base.py deleted file mode 100644 index 605615d3470..00000000000 --- a/api/graphon/graph_engine/layers/base.py +++ /dev/null @@ -1,128 +0,0 @@ -""" -Base layer class for GraphEngine extensions. - -This module provides the abstract base class for implementing layers that can -intercept and respond to GraphEngine events. -""" - -from abc import ABC, abstractmethod - -from graphon.graph_engine.protocols.command_channel import CommandChannel -from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.runtime import ReadOnlyGraphRuntimeState - - -class GraphEngineLayerNotInitializedError(Exception): - """Raised when a layer's runtime state is accessed before initialization.""" - - def __init__(self, layer_name: str | None = None) -> None: - name = layer_name or "GraphEngineLayer" - super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.") - - -class GraphEngineLayer(ABC): - """ - Abstract base class for GraphEngine layers. - - Layers are middleware-like components that can: - - Observe all events emitted by the GraphEngine - - Access the graph runtime state - - Send commands to control execution - - Subclasses should override the constructor to accept configuration parameters, - then implement the three lifecycle methods. - """ - - def __init__(self) -> None: - """Initialize the layer. 
Subclasses can override with custom parameters.""" - self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None - self.command_channel: CommandChannel | None = None - - @property - def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState: - if self._graph_runtime_state is None: - raise GraphEngineLayerNotInitializedError(type(self).__name__) - return self._graph_runtime_state - - def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None: - """ - Initialize the layer with engine dependencies. - - Called by GraphEngine to inject the read-only runtime state and command channel. - This is invoked when the layer is registered with a `GraphEngine` instance. - Implementations should be idempotent. - Args: - graph_runtime_state: Read-only view of the runtime state - command_channel: Channel for sending commands to the engine - """ - self._graph_runtime_state = graph_runtime_state - self.command_channel = command_channel - - @abstractmethod - def on_graph_start(self) -> None: - """ - Called when graph execution starts. - - This is called after the engine has been initialized but before any nodes - are executed. Layers can use this to set up resources or log start information. - """ - pass - - @abstractmethod - def on_event(self, event: GraphEngineEvent) -> None: - """ - Called for every event emitted by the engine. - - This method receives all events generated during graph execution, including: - - Graph lifecycle events (start, success, failure) - - Node execution events (start, success, failure, retry) - - Stream events for response nodes - - Container events (iteration, loop) - - Args: - event: The event emitted by the engine - """ - pass - - @abstractmethod - def on_graph_end(self, error: Exception | None) -> None: - """ - Called when graph execution ends. - - This is called after all nodes have been executed or when execution is - aborted. Layers can use this to clean up resources or log final state. - - Args: - error: The exception that caused execution to fail, or None if successful - """ - pass - - def on_node_run_start(self, node: Node) -> None: - """ - Called immediately before a node begins execution. - - Layers can override to inject behavior (e.g., start spans) prior to node execution. - The node's execution ID is available via `node._node_execution_id` and will be - consistent with all events emitted by this node execution. - - Args: - node: The node instance about to be executed - """ - return - - def on_node_run_end( - self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None - ) -> None: - """ - Called after a node finishes execution. - - The node's execution ID is available via `node._node_execution_id` and matches - the `id` field in all events emitted by this node execution. - - Args: - node: The node instance that just finished execution - error: Exception instance if the node failed, otherwise None - result_event: The final result event from node execution (succeeded/failed/paused), if any - """ - return diff --git a/api/graphon/graph_engine/layers/debug_logging.py b/api/graphon/graph_engine/layers/debug_logging.py deleted file mode 100644 index e6585fb3b94..00000000000 --- a/api/graphon/graph_engine/layers/debug_logging.py +++ /dev/null @@ -1,247 +0,0 @@ -""" -Debug logging layer for GraphEngine. - -This module provides a layer that logs all events and state changes during -graph execution for debugging purposes. 
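The base class above is the whole layer contract: implement the three abstract hooks and the engine does the rest. A hypothetical counting layer as a complete example (class and attribute names are ours):

```python
from typing_extensions import override

from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, NodeRunSucceededEvent


class SuccessCountingLayer(GraphEngineLayer):
    """Counts nodes that succeeded during a single graph run."""

    def __init__(self) -> None:
        super().__init__()
        self.succeeded = 0

    @override
    def on_graph_start(self) -> None:
        self.succeeded = 0  # reset per run

    @override
    def on_event(self, event: GraphEngineEvent) -> None:
        if isinstance(event, NodeRunSucceededEvent):
            self.succeeded += 1

    @override
    def on_graph_end(self, error: Exception | None) -> None:
        pass  # nothing to clean up; the count remains readable
```

Registration works as the layers README shows: `engine.layer(SuccessCountingLayer())`.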
-""" - -import logging -from collections.abc import Mapping -from typing import Any, final - -from typing_extensions import override - -from graphon.graph_events import ( - GraphEngineEvent, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .base import GraphEngineLayer - - -@final -class DebugLoggingLayer(GraphEngineLayer): - """ - A layer that provides comprehensive logging of GraphEngine execution. - - This layer logs all events with configurable detail levels, helping developers - debug workflow execution and understand the flow of events. - """ - - def __init__( - self, - level: str = "INFO", - include_inputs: bool = False, - include_outputs: bool = True, - include_process_data: bool = False, - logger_name: str = "GraphEngine.Debug", - max_value_length: int = 500, - ) -> None: - """ - Initialize the debug logging layer. - - Args: - level: Logging level (DEBUG, INFO, WARNING, ERROR) - include_inputs: Whether to log node input values - include_outputs: Whether to log node output values - include_process_data: Whether to log node process data - logger_name: Name of the logger to use - max_value_length: Maximum length of logged values (truncated if longer) - """ - super().__init__() - self.level = level - self.include_inputs = include_inputs - self.include_outputs = include_outputs - self.include_process_data = include_process_data - self.max_value_length = max_value_length - - # Set up logger - self.logger = logging.getLogger(logger_name) - log_level = getattr(logging, level.upper(), logging.INFO) - self.logger.setLevel(log_level) - - # Track execution stats - self.node_count = 0 - self.success_count = 0 - self.failure_count = 0 - self.retry_count = 0 - - def _truncate_value(self, value: Any) -> str: - """Truncate long values for logging.""" - str_value = str(value) - if len(str_value) > self.max_value_length: - return str_value[: self.max_value_length] + "... 
(truncated)" - return str_value - - def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str: - """Format a dictionary or mapping for logging with truncation.""" - if not data: - return "{}" - - formatted_items: list[str] = [] - for key, value in data.items(): - formatted_value = self._truncate_value(value) - formatted_items.append(f" {key}: {formatted_value}") - - return "{\n" + ",\n".join(formatted_items) + "\n}" - - @override - def on_graph_start(self) -> None: - """Log graph execution start.""" - self.logger.info("=" * 80) - self.logger.info("๐Ÿš€ GRAPH EXECUTION STARTED") - self.logger.info("=" * 80) - # Log initial state - self.logger.info("Initial State:") - - @override - def on_event(self, event: GraphEngineEvent) -> None: - """Log individual events based on their type.""" - event_class = event.__class__.__name__ - - # Graph-level events - if isinstance(event, GraphRunStartedEvent): - self.logger.debug("Graph run started event") - - elif isinstance(event, GraphRunSucceededEvent): - self.logger.info("โœ… Graph run succeeded") - if self.include_outputs and event.outputs: - self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, GraphRunPartialSucceededEvent): - self.logger.warning("โš ๏ธ Graph run partially succeeded") - if event.exceptions_count > 0: - self.logger.warning(" Total exceptions: %s", event.exceptions_count) - if self.include_outputs and event.outputs: - self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, GraphRunFailedEvent): - self.logger.error("โŒ Graph run failed: %s", event.error) - if event.exceptions_count > 0: - self.logger.error(" Total exceptions: %s", event.exceptions_count) - - elif isinstance(event, GraphRunAbortedEvent): - self.logger.warning("โš ๏ธ Graph run aborted: %s", event.reason) - if event.outputs: - self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs)) - - # Node-level events - # Retry before Started because Retry subclasses Started; - elif isinstance(event, NodeRunRetryEvent): - self.retry_count += 1 - self.logger.warning("๐Ÿ”„ Node retry: %s (attempt %s)", event.node_id, event.retry_index) - self.logger.warning(" Previous error: %s", event.error) - - elif isinstance(event, NodeRunStartedEvent): - self.node_count += 1 - self.logger.info('โ–ถ๏ธ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type) - - if self.include_inputs and event.node_run_result.inputs: - self.logger.debug(" Inputs: %s", self._format_dict(event.node_run_result.inputs)) - - elif isinstance(event, NodeRunSucceededEvent): - self.success_count += 1 - self.logger.info("โœ… Node succeeded: %s", event.node_id) - - if self.include_outputs and event.node_run_result.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.node_run_result.outputs)) - - if self.include_process_data and event.node_run_result.process_data: - self.logger.debug(" Process data: %s", self._format_dict(event.node_run_result.process_data)) - - elif isinstance(event, NodeRunFailedEvent): - self.failure_count += 1 - self.logger.error("โŒ Node failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - if event.node_run_result.error: - self.logger.error(" Details: %s", event.node_run_result.error) - - elif isinstance(event, NodeRunExceptionEvent): - self.logger.warning("โš ๏ธ Node exception handled: %s", event.node_id) - self.logger.warning(" Error: %s", event.error) - - elif isinstance(event, NodeRunStreamChunkEvent): - # 
Log stream chunks at debug level to avoid spam - final_indicator = " (FINAL)" if event.is_final else "" - self.logger.debug( - "๐Ÿ“ Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk) - ) - - # Iteration events - elif isinstance(event, NodeRunIterationStartedEvent): - self.logger.info("๐Ÿ” Iteration started: %s", event.node_id) - - elif isinstance(event, NodeRunIterationNextEvent): - self.logger.debug(" Iteration next: %s (index: %s)", event.node_id, event.index) - - elif isinstance(event, NodeRunIterationSucceededEvent): - self.logger.info("โœ… Iteration succeeded: %s", event.node_id) - if self.include_outputs and event.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, NodeRunIterationFailedEvent): - self.logger.error("โŒ Iteration failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - # Loop events - elif isinstance(event, NodeRunLoopStartedEvent): - self.logger.info("๐Ÿ”„ Loop started: %s", event.node_id) - - elif isinstance(event, NodeRunLoopNextEvent): - self.logger.debug(" Loop iteration: %s (index: %s)", event.node_id, event.index) - - elif isinstance(event, NodeRunLoopSucceededEvent): - self.logger.info("โœ… Loop succeeded: %s", event.node_id) - if self.include_outputs and event.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, NodeRunLoopFailedEvent): - self.logger.error("โŒ Loop failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - else: - # Log unknown events at debug level - self.logger.debug("Event: %s", event_class) - - @override - def on_graph_end(self, error: Exception | None) -> None: - """Log graph execution end with summary statistics.""" - self.logger.info("=" * 80) - - if error: - self.logger.error("๐Ÿ”ด GRAPH EXECUTION FAILED") - self.logger.error(" Error: %s", error) - else: - self.logger.info("๐ŸŽ‰ GRAPH EXECUTION COMPLETED SUCCESSFULLY") - - # Log execution statistics - self.logger.info("Execution Statistics:") - self.logger.info(" Total nodes executed: %s", self.node_count) - self.logger.info(" Successful nodes: %s", self.success_count) - self.logger.info(" Failed nodes: %s", self.failure_count) - self.logger.info(" Node retries: %s", self.retry_count) - - # Log final state if available - if self.include_outputs and self.graph_runtime_state.outputs: - self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) - - self.logger.info("=" * 80) diff --git a/api/graphon/graph_engine/layers/execution_limits.py b/api/graphon/graph_engine/layers/execution_limits.py deleted file mode 100644 index 2742b3acd32..00000000000 --- a/api/graphon/graph_engine/layers/execution_limits.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Execution limits layer for GraphEngine. - -This layer monitors workflow execution to enforce limits on: -- Maximum execution steps -- Maximum execution time - -When limits are exceeded, the layer automatically aborts execution. 
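Attaching the limits layer is one line per engine; a construction sketch (the numbers are arbitrary examples, and `engine.layer(...)` is the registration call shown earlier in this patch):

```python
from graphon.graph_engine.layers import ExecutionLimitsLayer

# Abort after 500 node executions or 10 minutes, whichever trips first.
limits = ExecutionLimitsLayer(max_steps=500, max_time=600)
assert limits.max_steps == 500 and limits.max_time == 600
# engine.layer(limits)  # register on a GraphEngine built as shown earlier
```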
-""" - -import logging -import time -from enum import StrEnum -from typing import final - -from typing_extensions import override - -from graphon.graph_engine.entities.commands import AbortCommand, CommandType -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import ( - GraphEngineEvent, - NodeRunStartedEvent, -) -from graphon.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent - - -class LimitType(StrEnum): - """Types of execution limits that can be exceeded.""" - - STEP_LIMIT = "step_limit" - TIME_LIMIT = "time_limit" - - -@final -class ExecutionLimitsLayer(GraphEngineLayer): - """ - Layer that enforces execution limits for workflows. - - Monitors: - - Step count: Tracks number of node executions - - Time limit: Monitors total execution time - - Automatically aborts execution when limits are exceeded. - """ - - def __init__(self, max_steps: int, max_time: int) -> None: - """ - Initialize the execution limits layer. - - Args: - max_steps: Maximum number of execution steps allowed - max_time: Maximum execution time in seconds allowed - """ - super().__init__() - self.max_steps = max_steps - self.max_time = max_time - - # Runtime tracking - self.start_time: float | None = None - self.step_count = 0 - self.logger = logging.getLogger(__name__) - - # State tracking - self._execution_started = False - self._execution_ended = False - self._abort_sent = False # Track if abort command has been sent - - @override - def on_graph_start(self) -> None: - """Called when graph execution starts.""" - self.start_time = time.time() - self.step_count = 0 - self._execution_started = True - self._execution_ended = False - self._abort_sent = False - - self.logger.debug("Execution limits monitoring started") - - @override - def on_event(self, event: GraphEngineEvent) -> None: - """ - Called for every event emitted by the engine. - - Monitors execution progress and enforces limits. - """ - if not self._execution_started or self._execution_ended or self._abort_sent: - return - - # Track step count for node execution events - if isinstance(event, NodeRunStartedEvent): - self.step_count += 1 - self.logger.debug("Step %d started: %s", self.step_count, event.node_id) - - # Check step limit when node execution completes - if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent): - if self._reached_step_limitation(): - self._send_abort_command(LimitType.STEP_LIMIT) - - if self._reached_time_limitation(): - self._send_abort_command(LimitType.TIME_LIMIT) - - @override - def on_graph_end(self, error: Exception | None) -> None: - """Called when graph execution ends.""" - if self._execution_started and not self._execution_ended: - self._execution_ended = True - - if self.start_time: - total_time = time.time() - self.start_time - self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time) - - def _reached_step_limitation(self) -> bool: - """Check if step count limit has been exceeded.""" - return self.step_count > self.max_steps - - def _reached_time_limitation(self) -> bool: - """Check if time limit has been exceeded.""" - return self.start_time is not None and (time.time() - self.start_time) > self.max_time - - def _send_abort_command(self, limit_type: LimitType) -> None: - """ - Send abort command due to limit violation. 
- - Args: - limit_type: Type of limit exceeded - """ - if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent: - return - - # Format detailed reason message - if limit_type == LimitType.STEP_LIMIT: - reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}" - elif limit_type == LimitType.TIME_LIMIT: - elapsed_time = time.time() - self.start_time if self.start_time else 0 - reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s" - - self.logger.warning("Execution limit exceeded: %s", reason) - - try: - # Send abort command to the engine - abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason) - self.command_channel.send_command(abort_command) - - # Mark that abort has been sent to prevent duplicate commands - self._abort_sent = True - - self.logger.debug("Abort command sent to engine") - - except Exception: - self.logger.exception("Failed to send abort command") diff --git a/api/graphon/graph_engine/manager.py b/api/graphon/graph_engine/manager.py deleted file mode 100644 index c728ff6986d..00000000000 --- a/api/graphon/graph_engine/manager.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -GraphEngine Manager for sending control commands via Redis channel. - -This module provides a simplified interface for controlling workflow executions -using the new Redis command channel, without requiring user permission checks. -Callers must provide a Redis client dependency from outside the workflow package. -""" - -import logging -from collections.abc import Sequence -from typing import final - -from graphon.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol -from graphon.graph_engine.entities.commands import ( - AbortCommand, - GraphEngineCommand, - PauseCommand, - UpdateVariablesCommand, - VariableUpdate, -) - -logger = logging.getLogger(__name__) - - -@final -class GraphEngineManager: - """ - Manager for sending control commands to GraphEngine instances. - - This class provides a simple interface for controlling workflow executions - by sending commands through Redis channels, without user validation. - """ - - _redis_client: RedisClientProtocol - - def __init__(self, redis_client: RedisClientProtocol) -> None: - self._redis_client = redis_client - - def send_stop_command(self, task_id: str, reason: str | None = None) -> None: - """ - Send a stop command to a running workflow. 
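A usage sketch for the manager. `RedisClientProtocol` is structural, so we assume here (this patch does not confirm it) that a standard redis-py client satisfies it; note that `_send_command` swallows and logs transport failures rather than raising:

```python
import redis

from graphon.graph_engine.manager import GraphEngineManager

# Assumption: redis.Redis is structurally compatible with RedisClientProtocol.
manager = GraphEngineManager(redis_client=redis.Redis(host="localhost", port=6379))

# Publishes an AbortCommand on the "workflow:<task_id>:commands" channel key.
manager.send_stop_command("task-123", reason="operator requested stop")
```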
- - Args: - task_id: The task ID of the workflow to stop - reason: Optional reason for stopping (defaults to "User requested stop") - """ - abort_command = AbortCommand(reason=reason or "User requested stop") - self._send_command(task_id, abort_command) - - def send_pause_command(self, task_id: str, reason: str | None = None) -> None: - """Send a pause command to a running workflow.""" - - pause_command = PauseCommand(reason=reason or "User requested pause") - self._send_command(task_id, pause_command) - - def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None: - """Send a command to update variables in a running workflow.""" - - if not updates: - return - - update_command = UpdateVariablesCommand(updates=updates) - self._send_command(task_id, update_command) - - def _send_command(self, task_id: str, command: GraphEngineCommand) -> None: - """Send a command to the workflow-specific Redis channel.""" - - if not task_id: - return - - channel_key = f"workflow:{task_id}:commands" - channel = RedisChannel(self._redis_client, channel_key) - - try: - channel.send_command(command) - except Exception: - # Silently fail if Redis is unavailable - # The legacy control mechanisms will still work - logger.exception("Failed to send graph engine command %s for task %s", command.__class__.__name__, task_id) diff --git a/api/graphon/graph_engine/orchestration/__init__.py b/api/graphon/graph_engine/orchestration/__init__.py deleted file mode 100644 index de08e942fb3..00000000000 --- a/api/graphon/graph_engine/orchestration/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Orchestration subsystem for graph engine. - -This package coordinates the overall execution flow between -different subsystems. -""" - -from .dispatcher import Dispatcher -from .execution_coordinator import ExecutionCoordinator - -__all__ = [ - "Dispatcher", - "ExecutionCoordinator", -] diff --git a/api/graphon/graph_engine/orchestration/dispatcher.py b/api/graphon/graph_engine/orchestration/dispatcher.py deleted file mode 100644 index f75bbee08e0..00000000000 --- a/api/graphon/graph_engine/orchestration/dispatcher.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Main dispatcher for processing events from workers. -""" - -import logging -import queue -import threading -import time -from typing import TYPE_CHECKING, final - -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunSucceededEvent, -) - -from ..event_management import EventManager -from .execution_coordinator import ExecutionCoordinator - -if TYPE_CHECKING: - from ..event_management import EventHandler - -logger = logging.getLogger(__name__) - - -@final -class Dispatcher: - """ - Main dispatcher that processes events from the event queue. - - This runs in a separate thread and coordinates event processing - with timeout and completion detection. - """ - - _COMMAND_TRIGGER_EVENTS = ( - NodeRunSucceededEvent, - NodeRunFailedEvent, - NodeRunExceptionEvent, - ) - - def __init__( - self, - event_queue: queue.Queue[GraphNodeEventBase], - event_handler: "EventHandler", - execution_coordinator: ExecutionCoordinator, - event_emitter: EventManager | None = None, - ) -> None: - """ - Initialize the dispatcher. 
- - Args: - event_queue: Queue of events from workers - event_handler: Event handler registry for processing events - execution_coordinator: Coordinator for execution flow - event_emitter: Optional event manager to signal completion - """ - self._event_queue = event_queue - self._event_handler = event_handler - self._execution_coordinator = execution_coordinator - self._event_emitter = event_emitter - - self._thread: threading.Thread | None = None - self._stop_event = threading.Event() - self._start_time: float | None = None - - def start(self) -> None: - """Start the dispatcher thread.""" - if self._thread and self._thread.is_alive(): - return - - self._stop_event.clear() - self._start_time = time.time() - self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True) - self._thread.start() - - def stop(self) -> None: - """Stop the dispatcher thread.""" - self._stop_event.set() - if self._thread and self._thread.is_alive(): - self._thread.join(timeout=2.0) - - def _dispatcher_loop(self) -> None: - """Main dispatcher loop.""" - try: - self._process_commands() - paused = False - while not self._stop_event.is_set(): - if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete: - break - if self._execution_coordinator.paused: - paused = True - break - - self._execution_coordinator.check_scaling() - try: - event = self._event_queue.get(timeout=0.1) - self._event_handler.dispatch(event) - self._event_queue.task_done() - self._process_commands(event) - except queue.Empty: - time.sleep(0.1) - - self._process_commands() - if paused: - self._drain_events_until_idle() - else: - self._drain_event_queue() - - except Exception as e: - logger.exception("Dispatcher error") - self._execution_coordinator.mark_failed(e) - - finally: - self._execution_coordinator.mark_complete() - # Signal the event emitter that execution is complete - if self._event_emitter: - self._event_emitter.mark_complete() - - def _process_commands(self, event: GraphNodeEventBase | None = None): - if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS): - self._execution_coordinator.process_commands() - - def _drain_event_queue(self) -> None: - while True: - try: - event = self._event_queue.get(block=False) - self._event_handler.dispatch(event) - self._event_queue.task_done() - except queue.Empty: - break - - def _drain_events_until_idle(self) -> None: - while not self._stop_event.is_set(): - try: - event = self._event_queue.get(timeout=0.1) - self._event_handler.dispatch(event) - self._event_queue.task_done() - self._process_commands(event) - except queue.Empty: - if not self._execution_coordinator.has_executing_nodes(): - break - self._drain_event_queue() diff --git a/api/graphon/graph_engine/orchestration/execution_coordinator.py b/api/graphon/graph_engine/orchestration/execution_coordinator.py deleted file mode 100644 index 0f8550eb123..00000000000 --- a/api/graphon/graph_engine/orchestration/execution_coordinator.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -Execution coordinator for managing overall workflow execution. -""" - -from typing import final - -from ..command_processing import CommandProcessor -from ..domain import GraphExecution -from ..graph_state_manager import GraphStateManager -from ..worker_management import WorkerPool - - -@final -class ExecutionCoordinator: - """ - Coordinates overall execution flow between subsystems. - - This provides high-level coordination methods used by the - dispatcher to manage execution state. 
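The dispatcher's queue handling is a standard non-blocking drain; the same pattern standalone (sketch only — the real loop additionally routes commands and honors pause/abort state):

```python
import queue
from collections.abc import Callable


def drain_events(event_queue: "queue.Queue[object]", dispatch: Callable[[object], None]) -> None:
    # Mirrors Dispatcher._drain_event_queue: consume without blocking until empty.
    while True:
        try:
            event = event_queue.get(block=False)
        except queue.Empty:
            break
        dispatch(event)
        event_queue.task_done()


q: "queue.Queue[object]" = queue.Queue()
q.put("event-1")
drain_events(q, dispatch=print)  # prints "event-1", leaves the queue empty
```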
- """ - - def __init__( - self, - graph_execution: GraphExecution, - state_manager: GraphStateManager, - command_processor: CommandProcessor, - worker_pool: WorkerPool, - ) -> None: - """ - Initialize the execution coordinator. - - Args: - graph_execution: Graph execution aggregate - state_manager: Unified state manager - command_processor: Processor for commands - worker_pool: Pool of workers - """ - self._graph_execution = graph_execution - self._state_manager = state_manager - self._command_processor = command_processor - self._worker_pool = worker_pool - - def process_commands(self) -> None: - """Process any pending commands.""" - self._command_processor.process_commands() - - def check_scaling(self) -> None: - """Check and perform worker scaling if needed.""" - self._worker_pool.check_and_scale() - - @property - def execution_complete(self): - return self._state_manager.is_execution_complete() - - @property - def aborted(self): - return self._graph_execution.aborted or self._graph_execution.has_error - - @property - def paused(self) -> bool: - """Expose whether the underlying graph execution is paused.""" - return self._graph_execution.is_paused - - def mark_complete(self) -> None: - """Mark execution as complete.""" - if self._graph_execution.is_paused: - return - if not self._graph_execution.completed: - self._graph_execution.complete() - - def mark_failed(self, error: Exception) -> None: - """ - Mark execution as failed. - - Args: - error: The error that caused failure - """ - self._graph_execution.fail(error) - - def handle_pause_if_needed(self) -> None: - """If the execution has been paused, stop workers immediately.""" - - if not self._graph_execution.is_paused: - return - - self._worker_pool.stop() - self._state_manager.clear_executing() - - def handle_abort_if_needed(self) -> None: - """If the execution has been aborted, stop workers immediately.""" - - if not self._graph_execution.aborted: - return - - self._worker_pool.stop() - self._state_manager.clear_executing() - - def has_executing_nodes(self) -> bool: - """Return True if any nodes are currently marked as executing.""" - # This check is only safe once execution has already paused. - # Before pause, executing state can change concurrently, which makes the result unreliable. - if not self._graph_execution.is_paused: - raise AssertionError("has_executing_nodes should only be called after execution is paused") - return self._state_manager.get_executing_count() > 0 diff --git a/api/graphon/graph_engine/protocols/command_channel.py b/api/graphon/graph_engine/protocols/command_channel.py deleted file mode 100644 index fabd8634c8b..00000000000 --- a/api/graphon/graph_engine/protocols/command_channel.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -CommandChannel protocol for GraphEngine command communication. - -This protocol defines the interface for sending and receiving commands -to/from a GraphEngine instance, supporting both local and distributed scenarios. -""" - -from typing import Protocol - -from ..entities.commands import GraphEngineCommand - - -class CommandChannel(Protocol): - """ - Protocol for bidirectional command communication with GraphEngine. - - Since each GraphEngine instance processes only one workflow execution, - this channel is dedicated to that single execution. - """ - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch pending commands for this GraphEngine instance. - - Called by GraphEngine to poll for commands that need to be processed. 
- - Returns: - List of pending commands (may be empty) - """ - ... - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to be processed by this GraphEngine instance. - - Called by external systems to send control commands to the running workflow. - - Args: - command: The command to send - """ - ... diff --git a/api/graphon/graph_engine/ready_queue/__init__.py b/api/graphon/graph_engine/ready_queue/__init__.py deleted file mode 100644 index acba0e961c8..00000000000 --- a/api/graphon/graph_engine/ready_queue/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Ready queue implementations for GraphEngine. - -This package contains the protocol and implementations for managing -the queue of nodes ready for execution. -""" - -from .factory import create_ready_queue_from_state -from .in_memory import InMemoryReadyQueue -from .protocol import ReadyQueue, ReadyQueueState - -__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"] diff --git a/api/graphon/graph_engine/ready_queue/factory.py b/api/graphon/graph_engine/ready_queue/factory.py deleted file mode 100644 index a9d4f470e53..00000000000 --- a/api/graphon/graph_engine/ready_queue/factory.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Factory for creating ReadyQueue instances from serialized state. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .in_memory import InMemoryReadyQueue -from .protocol import ReadyQueueState - -if TYPE_CHECKING: - from .protocol import ReadyQueue - - -def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue: - """ - Create a ReadyQueue instance from a serialized state. - - Args: - state: The serialized queue state (Pydantic model, dict, or JSON string), or None for a new empty queue - - Returns: - A ReadyQueue instance initialized with the given state - - Raises: - ValueError: If the queue type is unknown or version is unsupported - """ - if state.type == "InMemoryReadyQueue": - if state.version != "1.0": - raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}") - queue = InMemoryReadyQueue() - # Always pass as JSON string to loads() - queue.loads(state.model_dump_json()) - return queue - else: - raise ValueError(f"Unknown ready queue type: {state.type}") diff --git a/api/graphon/graph_engine/ready_queue/in_memory.py b/api/graphon/graph_engine/ready_queue/in_memory.py deleted file mode 100644 index f2c265ece09..00000000000 --- a/api/graphon/graph_engine/ready_queue/in_memory.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -In-memory implementation of the ReadyQueue protocol. - -This implementation wraps Python's standard queue.Queue and adds -serialization capabilities for state storage. -""" - -import queue -from typing import final - -from .protocol import ReadyQueue, ReadyQueueState - - -@final -class InMemoryReadyQueue(ReadyQueue): - """ - In-memory ready queue implementation with serialization support. - - This implementation uses Python's queue.Queue internally and provides - methods to serialize and restore the queue state. - """ - - def __init__(self, maxsize: int = 0) -> None: - """ - Initialize the in-memory ready queue. - - Args: - maxsize: Maximum size of the queue (0 for unlimited) - """ - self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize) - - def put(self, item: str) -> None: - """ - Add a node ID to the ready queue. 
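Since `CommandChannel` is a `Protocol`, a test double needs no inheritance. A minimal in-memory implementation sketch (our naming; the codebase's own Redis-backed channel lives in `command_channels/redis_channel.py`):

```python
import queue

from graphon.graph_engine.entities.commands import GraphEngineCommand


class InMemoryCommandChannel:
    """Thread-safe FIFO that structurally satisfies CommandChannel."""

    def __init__(self) -> None:
        self._commands: "queue.Queue[GraphEngineCommand]" = queue.Queue()

    def fetch_commands(self) -> list[GraphEngineCommand]:
        pending: list[GraphEngineCommand] = []
        while True:
            try:
                pending.append(self._commands.get_nowait())
            except queue.Empty:
                return pending

    def send_command(self, command: GraphEngineCommand) -> None:
        self._commands.put(command)
```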
- - Args: - item: The node ID to add to the queue - """ - self._queue.put(item) - - def get(self, timeout: float | None = None) -> str: - """ - Retrieve and remove a node ID from the queue. - - Args: - timeout: Maximum time to wait for an item (None for blocking) - - Returns: - The node ID retrieved from the queue - - Raises: - queue.Empty: If timeout expires and no item is available - """ - if timeout is None: - return self._queue.get(block=True) - return self._queue.get(timeout=timeout) - - def task_done(self) -> None: - """ - Indicate that a previously retrieved task is complete. - - Used by worker threads to signal task completion for - join() synchronization. - """ - self._queue.task_done() - - def empty(self) -> bool: - """ - Check if the queue is empty. - - Returns: - True if the queue has no items, False otherwise - """ - return self._queue.empty() - - def qsize(self) -> int: - """ - Get the approximate size of the queue. - - Returns: - The approximate number of items in the queue - """ - return self._queue.qsize() - - def dumps(self) -> str: - """ - Serialize the queue state to a JSON string for storage. - - Returns: - A JSON string containing the serialized queue state - """ - # Extract all items from the queue without removing them - items: list[str] = [] - temp_items: list[str] = [] - - # Drain the queue temporarily to get all items - while not self._queue.empty(): - try: - item = self._queue.get_nowait() - temp_items.append(item) - items.append(item) - except queue.Empty: - break - - # Put items back in the same order - for item in temp_items: - self._queue.put(item) - - state = ReadyQueueState( - type="InMemoryReadyQueue", - version="1.0", - items=items, - ) - return state.model_dump_json() - - def loads(self, data: str) -> None: - """ - Restore the queue state from a JSON string. - - Args: - data: The JSON string containing the serialized queue state to restore - """ - state = ReadyQueueState.model_validate_json(data) - - if state.type != "InMemoryReadyQueue": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported version: {state.version}") - - # Clear the current queue - while not self._queue.empty(): - try: - self._queue.get_nowait() - except queue.Empty: - break - - # Restore items - for item in state.items: - self._queue.put(item) diff --git a/api/graphon/graph_engine/ready_queue/protocol.py b/api/graphon/graph_engine/ready_queue/protocol.py deleted file mode 100644 index 97d3ea6dd2c..00000000000 --- a/api/graphon/graph_engine/ready_queue/protocol.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -ReadyQueue protocol for GraphEngine node execution queue. - -This protocol defines the interface for managing the queue of nodes ready -for execution, supporting both in-memory and persistent storage scenarios. -""" - -from collections.abc import Sequence -from typing import Protocol - -from pydantic import BaseModel, Field - - -class ReadyQueueState(BaseModel): - """ - Pydantic model for serialized ready queue state. - - This defines the structure of the data returned by dumps() - and expected by loads() for ready queue serialization. - """ - - type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')") - version: str = Field(description="Serialization format version") - items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue") - - -class ReadyQueue(Protocol): - """ - Protocol for managing nodes ready for execution in GraphEngine. 
- - This protocol defines the interface that any ready queue implementation - must provide, enabling both in-memory queues and persistent queues - that can be serialized for state storage. - """ - - def put(self, item: str) -> None: - """ - Add a node ID to the ready queue. - - Args: - item: The node ID to add to the queue - """ - ... - - def get(self, timeout: float | None = None) -> str: - """ - Retrieve and remove a node ID from the queue. - - Args: - timeout: Maximum time to wait for an item (None for blocking) - - Returns: - The node ID retrieved from the queue - - Raises: - queue.Empty: If timeout expires and no item is available - """ - ... - - def task_done(self) -> None: - """ - Indicate that a previously retrieved task is complete. - - Used by worker threads to signal task completion for - join() synchronization. - """ - ... - - def empty(self) -> bool: - """ - Check if the queue is empty. - - Returns: - True if the queue has no items, False otherwise - """ - ... - - def qsize(self) -> int: - """ - Get the approximate size of the queue. - - Returns: - The approximate number of items in the queue - """ - ... - - def dumps(self) -> str: - """ - Serialize the queue state to a JSON string for storage. - - Returns: - A JSON string containing the serialized queue state - that can be persisted and later restored - """ - ... - - def loads(self, data: str) -> None: - """ - Restore the queue state from a JSON string. - - Args: - data: The JSON string containing the serialized queue state to restore - """ - ... diff --git a/api/graphon/graph_engine/response_coordinator/__init__.py b/api/graphon/graph_engine/response_coordinator/__init__.py deleted file mode 100644 index e11d31199c2..00000000000 --- a/api/graphon/graph_engine/response_coordinator/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -ResponseStreamCoordinator - Coordinates streaming output from response nodes - -This component manages response streaming sessions and ensures ordered streaming -of responses based on upstream node outputs and constants. -""" - -from .coordinator import ResponseStreamCoordinator - -__all__ = ["ResponseStreamCoordinator"] diff --git a/api/graphon/graph_engine/response_coordinator/coordinator.py b/api/graphon/graph_engine/response_coordinator/coordinator.py deleted file mode 100644 index a6562f02232..00000000000 --- a/api/graphon/graph_engine/response_coordinator/coordinator.py +++ /dev/null @@ -1,697 +0,0 @@ -""" -Main ResponseStreamCoordinator implementation. - -This module contains the public ResponseStreamCoordinator class that manages -response streaming sessions and ensures ordered streaming of responses. 
-""" - -import logging -from collections import deque -from collections.abc import Sequence -from threading import RLock -from typing import Literal, TypeAlias, final -from uuid import uuid4 - -from pydantic import BaseModel, Field - -from graphon.enums import NodeExecutionType, NodeState -from graphon.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent -from graphon.nodes.base.template import TextSegment, VariableSegment -from graphon.runtime import VariablePool -from graphon.runtime.graph_runtime_state import GraphProtocol - -from .path import Path -from .session import ResponseSession - -logger = logging.getLogger(__name__) - -# Type definitions -NodeID: TypeAlias = str -EdgeID: TypeAlias = str - - -class ResponseSessionState(BaseModel): - """Serializable representation of a response session.""" - - node_id: str - index: int = Field(default=0, ge=0) - - -class StreamBufferState(BaseModel): - """Serializable representation of buffered stream chunks.""" - - selector: tuple[str, ...] - events: list[NodeRunStreamChunkEvent] = Field(default_factory=list) - - -class StreamPositionState(BaseModel): - """Serializable representation for stream read positions.""" - - selector: tuple[str, ...] - position: int = Field(default=0, ge=0) - - -class ResponseStreamCoordinatorState(BaseModel): - """Serialized snapshot of ResponseStreamCoordinator.""" - - type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator") - version: str = Field(default="1.0") - response_nodes: Sequence[str] = Field(default_factory=list) - active_session: ResponseSessionState | None = None - waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) - pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) - node_execution_ids: dict[str, str] = Field(default_factory=dict) - paths_map: dict[str, list[list[str]]] = Field(default_factory=dict) - stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list) - stream_positions: Sequence[StreamPositionState] = Field(default_factory=list) - closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list) - - -@final -class ResponseStreamCoordinator: - """ - Manages response streaming sessions without relying on global state. - - Ensures ordered streaming of responses based on upstream node outputs and constants. - """ - - def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None: - """ - Initialize coordinator with variable pool. 
- - Args: - variable_pool: VariablePool instance for accessing node variables - graph: Graph instance for looking up node information - """ - self._variable_pool = variable_pool - self._graph = graph - self._active_session: ResponseSession | None = None - self._waiting_sessions: deque[ResponseSession] = deque() - self._lock = RLock() - - # Internal stream management (replacing OutputRegistry) - self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {} - self._stream_positions: dict[tuple[str, ...], int] = {} - self._closed_streams: set[tuple[str, ...]] = set() - - # Track response nodes - self._response_nodes: set[NodeID] = set() - - # Store paths for each response node - self._paths_maps: dict[NodeID, list[Path]] = {} - - # Track node execution IDs and types for proper event forwarding - self._node_execution_ids: dict[NodeID, str] = {} # node_id -> execution_id - - # Track response sessions to ensure only one per node - self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session - - def register(self, response_node_id: NodeID) -> None: - with self._lock: - if response_node_id in self._response_nodes: - return - self._response_nodes.add(response_node_id) - - # Build and save paths map for this response node - paths_map = self._build_paths_map(response_node_id) - self._paths_maps[response_node_id] = paths_map - - # Create and store response session for this node - response_node = self._graph.nodes[response_node_id] - session = ResponseSession.from_node(response_node) - self._response_sessions[response_node_id] = session - - def track_node_execution(self, node_id: NodeID, execution_id: str) -> None: - """Track the execution ID for a node when it starts executing. - - Args: - node_id: The ID of the node - execution_id: The execution ID from NodeRunStartedEvent - """ - with self._lock: - self._node_execution_ids[node_id] = execution_id - - def _get_or_create_execution_id(self, node_id: NodeID) -> str: - """Get the execution ID for a node, creating one if it doesn't exist. - - Args: - node_id: The ID of the node - - Returns: - The execution ID for the node - """ - with self._lock: - if node_id not in self._node_execution_ids: - self._node_execution_ids[node_id] = str(uuid4()) - return self._node_execution_ids[node_id] - - def _build_paths_map(self, response_node_id: NodeID) -> list[Path]: - """ - Build a paths map for a response node by finding all paths from root node - to the response node, recording branch edges along each path. 
- - Args: - response_node_id: ID of the response node to analyze - - Returns: - List of Path objects, where each path contains branch edge IDs - """ - # Get root node ID - root_node_id = self._graph.root_node.id - - # If root is the response node, return empty path - if root_node_id == response_node_id: - return [Path()] - - # Extract variable selectors from the response node's template - response_node = self._graph.nodes[response_node_id] - response_session = ResponseSession.from_node(response_node) - template = response_session.template - - # Collect all variable selectors from the template - variable_selectors: set[tuple[str, ...]] = set() - for segment in template.segments: - if isinstance(segment, VariableSegment): - variable_selectors.add(tuple(segment.selector[:2])) - - # Step 1: Find all complete paths from root to response node - all_complete_paths: list[list[EdgeID]] = [] - - def find_paths( - current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID] - ) -> None: - """Recursively find all paths from current node to target node.""" - if current_node_id == target_node_id: - # Found a complete path, store it - all_complete_paths.append(current_path.copy()) - return - - # Mark as visited to avoid cycles - visited.add(current_node_id) - - # Explore outgoing edges - outgoing_edges = self._graph.get_outgoing_edges(current_node_id) - for edge in outgoing_edges: - edge_id = edge.id - next_node_id = edge.head - - # Skip if already visited in this path - if next_node_id not in visited: - # Add edge to path and recurse - new_path = current_path + [edge_id] - find_paths(next_node_id, target_node_id, new_path, visited.copy()) - - # Start searching from root node - find_paths(root_node_id, response_node_id, [], set()) - - # Step 2: For each complete path, filter edges based on node blocking behavior - filtered_paths: list[Path] = [] - for path in all_complete_paths: - blocking_edges: list[str] = [] - for edge_id in path: - edge = self._graph.edges[edge_id] - source_node = self._graph.nodes[edge.tail] - - # Check if node is a branch, container, or response node - if source_node.execution_type in { - NodeExecutionType.BRANCH, - NodeExecutionType.CONTAINER, - NodeExecutionType.RESPONSE, - } or source_node.blocks_variable_output(variable_selectors): - blocking_edges.append(edge_id) - - # Keep the path even if it's empty - filtered_paths.append(Path(edges=blocking_edges)) - - return filtered_paths - - def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]: - """ - Handle when an edge is taken (selected by a branch node). - - This method updates the paths for all response nodes by removing - the taken edge. If any response node has an empty path after removal, - it means the node is now deterministically reachable and should start. 
- - Args: - edge_id: The ID of the edge that was taken - - Returns: - List of events to emit from starting new sessions - """ - events: list[NodeRunStreamChunkEvent] = [] - - with self._lock: - # Check each response node in order - for response_node_id in self._response_nodes: - if response_node_id not in self._paths_maps: - continue - - paths = self._paths_maps[response_node_id] - has_reachable_path = False - - # Update each path by removing the taken edge - for path in paths: - # Remove the taken edge from this path - path.remove_edge(edge_id) - - # Check if this path is now empty (node is reachable) - if path.is_empty(): - has_reachable_path = True - - # If node is now reachable (has empty path), start/queue session - if has_reachable_path: - # Pass the node_id to the activation method - # The method will handle checking and removing from map - events.extend(self._active_or_queue_session(response_node_id)) - return events - - def _active_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]: - """ - Start a session immediately if no active session, otherwise queue it. - Only activates sessions that exist in the _response_sessions map. - - Args: - node_id: The ID of the response node to activate - - Returns: - List of events from flush attempt if session started immediately - """ - events: list[NodeRunStreamChunkEvent] = [] - - # Get the session from our map (only activate if it exists) - session = self._response_sessions.get(node_id) - if not session: - return events - - # Remove from map to ensure it won't be activated again - del self._response_sessions[node_id] - - if self._active_session is None: - self._active_session = session - - # Try to flush immediately - events.extend(self.try_flush()) - else: - # Queue the session if another is active - self._waiting_sessions.append(session) - - return events - - def intercept_event( - self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent - ) -> Sequence[NodeRunStreamChunkEvent]: - with self._lock: - if isinstance(event, NodeRunStreamChunkEvent): - self._append_stream_chunk(event.selector, event) - if event.is_final: - self._close_stream(event.selector) - return self.try_flush() - else: - # Skip cause we share the same variable pool. - # - # for variable_name, variable_value in event.node_run_result.outputs.items(): - # self._variable_pool.add((event.node_id, variable_name), variable_value) - return self.try_flush() - - def _create_stream_chunk_event( - self, - node_id: str, - execution_id: str, - selector: Sequence[str], - chunk: str, - is_final: bool = False, - ) -> NodeRunStreamChunkEvent: - """Create a stream chunk event with consistent structure. - - For selectors with special prefixes (sys, env, conversation), we use the - active response node's information since these are not actual node IDs. 
- """ - # Check if this is a special selector that doesn't correspond to a node - if selector and selector[0] not in self._graph.nodes and self._active_session: - # Use the active response node for special selectors - response_node = self._graph.nodes[self._active_session.node_id] - return NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=selector, - chunk=chunk, - is_final=is_final, - ) - - # Standard case: selector refers to an actual node - node = self._graph.nodes[node_id] - return NodeRunStreamChunkEvent( - id=execution_id, - node_id=node.id, - node_type=node.node_type, - selector=selector, - chunk=chunk, - is_final=is_final, - ) - - def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]: - """Process a variable segment. Returns (events, is_complete). - - Handles both regular node selectors and special system selectors (sys, env, conversation). - For special selectors, we attribute the output to the active response node. - """ - events: list[NodeRunStreamChunkEvent] = [] - source_selector_prefix = segment.selector[0] if segment.selector else "" - is_complete = False - - # Determine which node to attribute the output to - # For special selectors (sys, env, conversation), use the active response node - # For regular selectors, use the source node - if self._active_session and source_selector_prefix not in self._graph.nodes: - # Special selector - use active response node - output_node_id = self._active_session.node_id - else: - # Regular node selector - output_node_id = source_selector_prefix - execution_id = self._get_or_create_execution_id(output_node_id) - - # Stream all available chunks - while self._has_unread_stream(segment.selector): - if event := self._pop_stream_chunk(segment.selector): - # For special selectors, we need to update the event to use - # the active response node's information - if self._active_session and source_selector_prefix not in self._graph.nodes: - response_node = self._graph.nodes[self._active_session.node_id] - # Create a new event with the response node's information - # but keep the original selector - updated_event = NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=event.selector, # Keep original selector - chunk=event.chunk, - is_final=event.is_final, - ) - events.append(updated_event) - else: - # Regular node selector - use event as is - events.append(event) - - # Check if this is the last chunk by looking ahead - stream_closed = self._is_stream_closed(segment.selector) - # Check if stream is closed to determine if segment is complete - if stream_closed: - is_complete = True - - elif value := self._variable_pool.get(segment.selector): - # Process scalar value - is_last_segment = bool( - self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1 - ) - events.append( - self._create_stream_chunk_event( - node_id=output_node_id, - execution_id=execution_id, - selector=segment.selector, - chunk=value.markdown, - is_final=is_last_segment, - ) - ) - is_complete = True - - return events, is_complete - - def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]: - """Process a text segment. 
-        Returns the stream chunk events produced for the text."""
-        assert self._active_session is not None
-        current_response_node = self._graph.nodes[self._active_session.node_id]
-
-        # Use get_or_create_execution_id to ensure we have a consistent ID
-        execution_id = self._get_or_create_execution_id(current_response_node.id)
-
-        is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
-        event = self._create_stream_chunk_event(
-            node_id=current_response_node.id,
-            execution_id=execution_id,
-            selector=[current_response_node.id, "answer"],  # FIXME(-LAN-)
-            chunk=segment.text,
-            is_final=is_last_segment,
-        )
-        return [event]
-
-    def try_flush(self) -> list[NodeRunStreamChunkEvent]:
-        with self._lock:
-            if not self._active_session:
-                return []
-
-            template = self._active_session.template
-            response_node_id = self._active_session.node_id
-
-            events: list[NodeRunStreamChunkEvent] = []
-
-            # Process segments sequentially from current index
-            while self._active_session.index < len(template.segments):
-                segment = template.segments[self._active_session.index]
-
-                if isinstance(segment, VariableSegment):
-                    # Check if the source node for this variable is skipped
-                    # Only check for actual nodes, not special selectors (sys, env, conversation)
-                    source_selector_prefix = segment.selector[0] if segment.selector else ""
-                    if source_selector_prefix in self._graph.nodes:
-                        source_node = self._graph.nodes[source_selector_prefix]
-
-                        if source_node.state == NodeState.SKIPPED:
-                            # Skip this variable segment if the source node is skipped
-                            self._active_session.index += 1
-                            continue
-
-                    segment_events, is_complete = self._process_variable_segment(segment)
-                    events.extend(segment_events)
-
-                    # Only advance index if this variable segment is complete
-                    if is_complete:
-                        self._active_session.index += 1
-                    else:
-                        # Wait for more data
-                        break
-
-                else:
-                    segment_events = self._process_text_segment(segment)
-                    events.extend(segment_events)
-                    self._active_session.index += 1
-
-            if self._active_session.is_complete():
-                # End current session and get events from starting next session
-                next_session_events = self.end_session(response_node_id)
-                events.extend(next_session_events)
-
-            return events
-
-    def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]:
-        """
-        End the active session for a response node.
-        Automatically starts the next waiting session if available.
-
-        Args:
-            node_id: ID of the response node ending its session
-
-        Returns:
-            List of events from starting the next session
-        """
-        with self._lock:
-            events: list[NodeRunStreamChunkEvent] = []
-
-            if self._active_session and self._active_session.node_id == node_id:
-                self._active_session = None
-
-                # Try to start next waiting session
-                if self._waiting_sessions:
-                    next_session = self._waiting_sessions.popleft()
-                    self._active_session = next_session
-
-                    # Immediately try to flush any available segments
-                    events = self.try_flush()
-
-            return events
-
-    # ============= Internal Stream Management Methods =============
-
-    def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
-        """
-        Append a stream chunk to the internal buffer.
- - Args: - selector: List of strings identifying the stream location - event: The NodeRunStreamChunkEvent to append - - Raises: - ValueError: If the stream is already closed - """ - key = tuple(selector) - - if key in self._closed_streams: - raise ValueError(f"Stream {'.'.join(selector)} is already closed") - - if key not in self._stream_buffers: - self._stream_buffers[key] = [] - self._stream_positions[key] = 0 - - self._stream_buffers[key].append(event) - - def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None: - """ - Pop the next unread stream chunk from the buffer. - - Args: - selector: List of strings identifying the stream location - - Returns: - The next event, or None if no unread events available - """ - key = tuple(selector) - - if key not in self._stream_buffers: - return None - - position = self._stream_positions.get(key, 0) - buffer = self._stream_buffers[key] - - if position >= len(buffer): - return None - - event = buffer[position] - self._stream_positions[key] = position + 1 - return event - - def _has_unread_stream(self, selector: Sequence[str]) -> bool: - """ - Check if the stream has unread events. - - Args: - selector: List of strings identifying the stream location - - Returns: - True if there are unread events, False otherwise - """ - key = tuple(selector) - - if key not in self._stream_buffers: - return False - - position = self._stream_positions.get(key, 0) - return position < len(self._stream_buffers[key]) - - def _close_stream(self, selector: Sequence[str]) -> None: - """ - Mark a stream as closed (no more chunks can be appended). - - Args: - selector: List of strings identifying the stream location - """ - key = tuple(selector) - self._closed_streams.add(key) - - def _is_stream_closed(self, selector: Sequence[str]) -> bool: - """ - Check if a stream is closed. 
- - Args: - selector: List of strings identifying the stream location - - Returns: - True if the stream is closed, False otherwise - """ - key = tuple(selector) - return key in self._closed_streams - - def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None: - """Convert an in-memory session into its serializable form.""" - - if session is None: - return None - return ResponseSessionState(node_id=session.node_id, index=session.index) - - def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession: - """Rebuild a response session from serialized data.""" - - node = self._graph.nodes.get(session_state.node_id) - if node is None: - raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state") - - session = ResponseSession.from_node(node) - session.index = session_state.index - return session - - def dumps(self) -> str: - """Serialize coordinator state to JSON.""" - - with self._lock: - state = ResponseStreamCoordinatorState( - response_nodes=sorted(self._response_nodes), - active_session=self._serialize_session(self._active_session), - waiting_sessions=[ - session_state - for session in list(self._waiting_sessions) - if (session_state := self._serialize_session(session)) is not None - ], - pending_sessions=[ - session_state - for _, session in sorted(self._response_sessions.items()) - if (session_state := self._serialize_session(session)) is not None - ], - node_execution_ids=dict(sorted(self._node_execution_ids.items())), - paths_map={ - node_id: [path.edges.copy() for path in paths] - for node_id, paths in sorted(self._paths_maps.items()) - }, - stream_buffers=[ - StreamBufferState( - selector=selector, - events=[event.model_copy(deep=True) for event in events], - ) - for selector, events in sorted(self._stream_buffers.items()) - ], - stream_positions=[ - StreamPositionState(selector=selector, position=position) - for selector, position in sorted(self._stream_positions.items()) - ], - closed_streams=sorted(self._closed_streams), - ) - return state.model_dump_json() - - def loads(self, data: str) -> None: - """Restore coordinator state from JSON.""" - - state = ResponseStreamCoordinatorState.model_validate_json(data) - - if state.type != "ResponseStreamCoordinator": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported serialized version: {state.version}") - - with self._lock: - self._response_nodes = set(state.response_nodes) - self._paths_maps = { - node_id: [Path(edges=list(path_edges)) for path_edges in paths] - for node_id, paths in state.paths_map.items() - } - self._node_execution_ids = dict(state.node_execution_ids) - - self._stream_buffers = { - tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events] - for buffer in state.stream_buffers - } - self._stream_positions = { - tuple(position.selector): position.position for position in state.stream_positions - } - for selector in self._stream_buffers: - self._stream_positions.setdefault(selector, 0) - - self._closed_streams = {tuple(selector) for selector in state.closed_streams} - - self._waiting_sessions = deque( - self._session_from_state(session_state) for session_state in state.waiting_sessions - ) - self._response_sessions = { - session_state.node_id: self._session_from_state(session_state) - for session_state in state.pending_sessions - } - self._active_session = self._session_from_state(state.active_session) if state.active_session else None 
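A note on the ordering guarantee implemented by the coordinator above: the paths map records, for each response node, every root-to-node route together with the branch/container edges that still block it, and the node may start (or queue) its streaming session as soon as any one of those routes has had all of its blocking edges taken. The following is a minimal self-contained sketch of that pruning rule; the Path class is a simplified stand-in for the one deleted below, and the edge names are made up for illustration:

    from dataclasses import dataclass, field


    @dataclass
    class Path:
        """Simplified stand-in for response_coordinator.path.Path."""

        edges: list[str] = field(default_factory=list)

        def remove_edge(self, edge_id: str) -> None:
            if edge_id in self.edges:
                self.edges.remove(edge_id)

        def is_empty(self) -> bool:
            return not self.edges


    # One response node, reachable via two branch routes; each path lists the
    # blocking edges that must be taken before streaming may begin.
    paths = [Path(edges=["if-true"]), Path(edges=["if-false", "guard"])]


    def on_edge_taken(edge_id: str) -> bool:
        """Prune the taken edge from every path; True means the node is now reachable."""
        for path in paths:
            path.remove_edge(edge_id)
        return any(path.is_empty() for path in paths)


    assert not on_edge_taken("guard")  # every path still has a pending blocking edge
    assert on_edge_taken("if-true")  # first path drained -> start/queue the session

Pruning accumulates across calls, which is why the real coordinator keeps the paths map as long-lived state rather than recomputing reachability on every event.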
diff --git a/api/graphon/graph_engine/response_coordinator/path.py b/api/graphon/graph_engine/response_coordinator/path.py deleted file mode 100644 index 50f2f4eb217..00000000000 --- a/api/graphon/graph_engine/response_coordinator/path.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Internal path representation for response coordinator. - -This module contains the private Path class used internally by ResponseStreamCoordinator -to track execution paths to response nodes. -""" - -from dataclasses import dataclass, field -from typing import TypeAlias - -EdgeID: TypeAlias = str - - -@dataclass -class Path: - """ - Represents a path of branch edges that must be taken to reach a response node. - - Note: This is an internal class not exposed in the public API. - """ - - edges: list[EdgeID] = field(default_factory=list[EdgeID]) - - def contains_edge(self, edge_id: EdgeID) -> bool: - """Check if this path contains the given edge.""" - return edge_id in self.edges - - def remove_edge(self, edge_id: EdgeID) -> None: - """Remove the given edge from this path in place.""" - if self.contains_edge(edge_id): - self.edges.remove(edge_id) - - def is_empty(self) -> bool: - """Check if the path has no edges (node is reachable).""" - return len(self.edges) == 0 diff --git a/api/graphon/graph_engine/response_coordinator/session.py b/api/graphon/graph_engine/response_coordinator/session.py deleted file mode 100644 index cb877f15046..00000000000 --- a/api/graphon/graph_engine/response_coordinator/session.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Internal response session management for response coordinator. - -This module contains the private ResponseSession class used internally -by ResponseStreamCoordinator to manage streaming sessions. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Protocol, cast - -from graphon.nodes.base.template import Template -from graphon.runtime.graph_runtime_state import NodeProtocol - - -class _ResponseSessionNodeProtocol(NodeProtocol, Protocol): - """Structural contract required from nodes that can open a response session.""" - - def get_streaming_template(self) -> Template: ... - - -@dataclass -class ResponseSession: - """ - Represents an active response streaming session. - - Note: This is an internal class not exposed in the public API. - """ - - node_id: str - template: Template # Template object from the response node - index: int = 0 # Current position in the template segments - - @classmethod - def from_node(cls, node: NodeProtocol) -> ResponseSession: - """ - Create a ResponseSession from a response-capable node. - - The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer. - At runtime this must be a node that implements `get_streaming_template()`. The coordinator decides which - graph nodes should be treated as response-capable before they reach this factory. - - Args: - node: Node from the materialized workflow graph. - - Returns: - ResponseSession configured with the node's streaming template - - Raises: - TypeError: If node does not implement the response-session streaming contract. 
- """ - response_node = cast(_ResponseSessionNodeProtocol, node) - try: - template = response_node.get_streaming_template() - except AttributeError as exc: - raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes") from exc - - return cls( - node_id=node.id, - template=template, - ) - - def is_complete(self) -> bool: - """Check if all segments in the template have been processed.""" - return self.index >= len(self.template.segments) diff --git a/api/graphon/graph_engine/worker.py b/api/graphon/graph_engine/worker.py deleted file mode 100644 index a0844ee48ea..00000000000 --- a/api/graphon/graph_engine/worker.py +++ /dev/null @@ -1,204 +0,0 @@ -""" -Worker - Thread implementation for queue-based node execution - -Workers pull node IDs from the ready_queue, execute nodes, and push events -to the event_queue for the dispatcher to process. -""" - -import queue -import threading -import time -from collections.abc import Sequence -from contextlib import AbstractContextManager -from datetime import UTC, datetime -from typing import TYPE_CHECKING, final - -from typing_extensions import override - -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node - -from .ready_queue import ReadyQueue - -if TYPE_CHECKING: - pass - - -@final -class Worker(threading.Thread): - """ - Worker thread that executes nodes from the ready queue. - - Workers continuously pull node IDs from the ready_queue, execute the - corresponding nodes, and push the resulting events to the event_queue - for the dispatcher to process. - """ - - def __init__( - self, - ready_queue: ReadyQueue, - event_queue: queue.Queue[GraphNodeEventBase], - graph: Graph, - layers: Sequence[GraphEngineLayer], - worker_id: int = 0, - execution_context: AbstractContextManager[object] | None = None, - ) -> None: - """ - Initialize worker thread. 
- - Args: - ready_queue: Ready queue containing node IDs ready for execution - event_queue: Queue for pushing execution events - graph: Graph containing nodes to execute - layers: Graph engine layers for node execution hooks - worker_id: Unique identifier for this worker - execution_context: Optional execution context for context preservation - """ - super().__init__(name=f"GraphWorker-{worker_id}", daemon=True) - self._ready_queue = ready_queue - self._event_queue = event_queue - self._graph = graph - self._worker_id = worker_id - self._execution_context = execution_context - self._stop_event = threading.Event() - self._layers = layers if layers is not None else [] - self._last_task_time = time.time() - self._current_node_started_at: datetime | None = None - - def stop(self) -> None: - """Signal the worker to stop processing.""" - self._stop_event.set() - - @property - def is_idle(self) -> bool: - """Check if the worker is currently idle.""" - # Worker is idle if it hasn't processed a task recently (within 0.2 seconds) - return (time.time() - self._last_task_time) > 0.2 - - @property - def idle_duration(self) -> float: - """Get the duration in seconds since the worker last processed a task.""" - return time.time() - self._last_task_time - - @property - def worker_id(self) -> int: - """Get the worker's ID.""" - return self._worker_id - - @override - def run(self) -> None: - """ - Main worker loop. - - Continuously pulls node IDs from ready_queue, executes them, - and pushes events to event_queue until stopped. - """ - while not self._stop_event.is_set(): - # Try to get a node ID from the ready queue (with timeout) - try: - node_id = self._ready_queue.get(timeout=0.1) - except queue.Empty: - continue - - self._last_task_time = time.time() - node = self._graph.nodes[node_id] - try: - self._current_node_started_at = None - self._execute_node(node) - self._ready_queue.task_done() - except Exception as e: - self._event_queue.put( - self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at) - ) - finally: - self._current_node_started_at = None - - def _execute_node(self, node: Node) -> None: - """ - Execute a single node and handle its events. 
- - Args: - node: The node instance to execute - """ - node.ensure_execution_id() - - error: Exception | None = None - result_event: GraphNodeEventBase | None = None - - # Execute the node with preserved context if execution context is provided - if self._execution_context is not None: - with self._execution_context: - self._invoke_node_run_start_hooks(node) - try: - node_events = node.run() - for event in node_events: - if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: - self._current_node_started_at = event.start_at - self._event_queue.put(event) - if is_node_result_event(event): - result_event = event - except Exception as exc: - error = exc - raise - finally: - self._invoke_node_run_end_hooks(node, error, result_event) - else: - self._invoke_node_run_start_hooks(node) - try: - node_events = node.run() - for event in node_events: - if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: - self._current_node_started_at = event.start_at - self._event_queue.put(event) - if is_node_result_event(event): - result_event = event - except Exception as exc: - error = exc - raise - finally: - self._invoke_node_run_end_hooks(node, error, result_event) - - def _invoke_node_run_start_hooks(self, node: Node) -> None: - """Invoke on_node_run_start hooks for all layers.""" - for layer in self._layers: - try: - layer.on_node_run_start(node) - except Exception: - # Silently ignore layer errors to prevent disrupting node execution - continue - - def _invoke_node_run_end_hooks( - self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None - ) -> None: - """Invoke on_node_run_end hooks for all layers.""" - for layer in self._layers: - try: - layer.on_node_run_end(node, error, result_event) - except Exception: - # Silently ignore layer errors to prevent disrupting node execution - continue - - def _build_fallback_failure_event( - self, node: Node, error: Exception, *, started_at: datetime | None = None - ) -> NodeRunFailedEvent: - """Build a failed event when worker-level execution aborts before a node emits its own result event.""" - failure_time = datetime.now(UTC).replace(tzinfo=None) - error_message = str(error) - return NodeRunFailedEvent( - id=node.execution_id, - node_id=node.id, - node_type=node.node_type, - in_iteration_id=None, - error=error_message, - start_at=started_at or failure_time, - finished_at=failure_time, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error_message, - error_type=type(error).__name__, - ), - ) diff --git a/api/graphon/graph_engine/worker_management/__init__.py b/api/graphon/graph_engine/worker_management/__init__.py deleted file mode 100644 index 03de1f6daa7..00000000000 --- a/api/graphon/graph_engine/worker_management/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Worker management subsystem for graph engine. - -This package manages the worker pool, including creation, -scaling, and activity tracking. -""" - -from .worker_pool import WorkerPool - -__all__ = [ - "WorkerPool", -] diff --git a/api/graphon/graph_engine/worker_management/worker_pool.py b/api/graphon/graph_engine/worker_management/worker_pool.py deleted file mode 100644 index 85cdf1ca211..00000000000 --- a/api/graphon/graph_engine/worker_management/worker_pool.py +++ /dev/null @@ -1,277 +0,0 @@ -""" -Simple worker pool that consolidates functionality. - -This is a simpler implementation that merges WorkerPool, ActivityTracker, -DynamicScaler, and WorkerFactory into a single class. 
-""" - -import logging -import queue -import threading -from contextlib import AbstractContextManager -from typing import final - -from graphon.graph import Graph -from graphon.graph_events import GraphNodeEventBase - -from ..config import GraphEngineConfig -from ..layers.base import GraphEngineLayer -from ..ready_queue import ReadyQueue -from ..worker import Worker - -logger = logging.getLogger(__name__) - - -@final -class WorkerPool: - """ - Simple worker pool with integrated management. - - This class consolidates all worker management functionality into - a single, simpler implementation without excessive abstraction. - """ - - def __init__( - self, - ready_queue: ReadyQueue, - event_queue: queue.Queue[GraphNodeEventBase], - graph: Graph, - layers: list[GraphEngineLayer], - config: GraphEngineConfig, - execution_context: AbstractContextManager[object] | None = None, - ) -> None: - """ - Initialize the simple worker pool. - - Args: - ready_queue: Ready queue for nodes ready for execution - event_queue: Queue for worker events - graph: The workflow graph - layers: Graph engine layers for node execution hooks - config: GraphEngine worker pool configuration - execution_context: Optional execution context for context preservation - """ - self._ready_queue = ready_queue - self._event_queue = event_queue - self._graph = graph - self._execution_context = execution_context - self._layers = layers - self._config = config - - # Worker management - self._workers: list[Worker] = [] - self._worker_counter = 0 - self._lock = threading.RLock() - self._running = False - - # No longer tracking worker states with callbacks to avoid lock contention - - def start(self, initial_count: int | None = None) -> None: - """ - Start the worker pool. - - Args: - initial_count: Number of workers to start with (auto-calculated if None) - """ - with self._lock: - if self._running: - return - - self._running = True - - # Calculate initial worker count - if initial_count is None: - node_count = len(self._graph.nodes) - if node_count < 10: - initial_count = self._config.min_workers - elif node_count < 50: - initial_count = min(self._config.min_workers + 1, self._config.max_workers) - else: - initial_count = min(self._config.min_workers + 2, self._config.max_workers) - - logger.debug( - "Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)", - initial_count, - node_count, - self._config.min_workers, - self._config.max_workers, - ) - - # Create initial workers - for _ in range(initial_count): - self._create_worker() - - def stop(self) -> None: - """Stop all workers in the pool.""" - with self._lock: - self._running = False - worker_count = len(self._workers) - - if worker_count > 0: - logger.debug("Stopping worker pool: %d workers", worker_count) - - # Stop all workers - for worker in self._workers: - worker.stop() - - # Wait for workers to finish - for worker in self._workers: - if worker.is_alive(): - worker.join(timeout=2.0) - - self._workers.clear() - - def _create_worker(self) -> None: - """Create and start a new worker.""" - worker_id = self._worker_counter - self._worker_counter += 1 - - worker = Worker( - ready_queue=self._ready_queue, - event_queue=self._event_queue, - graph=self._graph, - layers=self._layers, - worker_id=worker_id, - execution_context=self._execution_context, - ) - - worker.start() - self._workers.append(worker) - - def _remove_worker(self, worker: Worker, worker_id: int) -> None: - """Remove a specific worker from the pool.""" - # Stop the worker - worker.stop() - - # Wait for it to 
finish - if worker.is_alive(): - worker.join(timeout=2.0) - - # Remove from list - if worker in self._workers: - self._workers.remove(worker) - - def _try_scale_up(self, queue_depth: int, current_count: int) -> bool: - """ - Try to scale up workers if needed. - - Args: - queue_depth: Current queue depth - current_count: Current number of workers - - Returns: - True if scaled up, False otherwise - """ - if queue_depth > self._config.scale_up_threshold and current_count < self._config.max_workers: - old_count = current_count - self._create_worker() - - logger.debug( - "Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)", - old_count, - len(self._workers), - queue_depth, - self._config.scale_up_threshold, - ) - return True - return False - - def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool: - """ - Try to scale down workers if we have excess capacity. - - Args: - queue_depth: Current queue depth - current_count: Current number of workers - active_count: Number of active workers - idle_count: Number of idle workers - - Returns: - True if scaled down, False otherwise - """ - # Skip if we're at minimum or have no idle workers - if current_count <= self._config.min_workers or idle_count == 0: - return False - - # Check if we have excess capacity - has_excess_capacity = ( - queue_depth <= active_count # Active workers can handle current queue - or idle_count > active_count # More idle than active workers - or (queue_depth == 0 and idle_count > 0) # No work and have idle workers - ) - - if not has_excess_capacity: - return False - - # Find and remove idle workers that have been idle long enough - workers_to_remove: list[tuple[Worker, int]] = [] - - for worker in self._workers: - # Check if worker is idle and has exceeded idle time threshold - if worker.is_idle and worker.idle_duration >= self._config.scale_down_idle_time: - # Don't remove if it would leave us unable to handle the queue - remaining_workers = current_count - len(workers_to_remove) - 1 - if remaining_workers >= self._config.min_workers and remaining_workers >= max(1, queue_depth // 2): - workers_to_remove.append((worker, worker.worker_id)) - # Only remove one worker per check to avoid aggressive scaling - break - - # Remove idle workers if any found - if workers_to_remove: - old_count = current_count - for worker, worker_id in workers_to_remove: - self._remove_worker(worker, worker_id) - - logger.debug( - "Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, " - "queue_depth=%d, active=%d, idle=%d)", - old_count, - len(self._workers), - len(workers_to_remove), - self._config.scale_down_idle_time, - queue_depth, - active_count, - idle_count - len(workers_to_remove), - ) - return True - - return False - - def check_and_scale(self) -> None: - """Check and perform scaling if needed.""" - with self._lock: - if not self._running: - return - - current_count = len(self._workers) - queue_depth = self._ready_queue.qsize() - - # Count active vs idle workers by querying their state directly - idle_count = sum(1 for worker in self._workers if worker.is_idle) - active_count = current_count - idle_count - - # Try to scale up if queue is backing up - self._try_scale_up(queue_depth, current_count) - - # Try to scale down if we have excess capacity - self._try_scale_down(queue_depth, current_count, active_count, idle_count) - - def get_worker_count(self) -> int: - """Get current number of workers.""" - with self._lock: - return len(self._workers) - - def 
get_status(self) -> dict[str, int]: - """ - Get pool status information. - - Returns: - Dictionary with status information - """ - with self._lock: - return { - "total_workers": len(self._workers), - "queue_depth": self._ready_queue.qsize(), - "min_workers": self._config.min_workers, - "max_workers": self._config.max_workers, - } diff --git a/api/graphon/graph_events/__init__.py b/api/graphon/graph_events/__init__.py deleted file mode 100644 index 7cec587a053..00000000000 --- a/api/graphon/graph_events/__init__.py +++ /dev/null @@ -1,84 +0,0 @@ -# Agent events -from .agent import NodeRunAgentLogEvent - -# Base events -from .base import ( - BaseGraphEvent, - GraphEngineEvent, - GraphNodeEventBase, -) - -# Graph events -from .graph import ( - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) - -# Iteration events -from .iteration import ( - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, -) - -# Loop events -from .loop import ( - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, -) - -# Node events -from .node import ( - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, - is_node_result_event, -) - -__all__ = [ - "BaseGraphEvent", - "GraphEngineEvent", - "GraphNodeEventBase", - "GraphRunAbortedEvent", - "GraphRunFailedEvent", - "GraphRunPartialSucceededEvent", - "GraphRunPausedEvent", - "GraphRunStartedEvent", - "GraphRunSucceededEvent", - "NodeRunAgentLogEvent", - "NodeRunExceptionEvent", - "NodeRunFailedEvent", - "NodeRunHumanInputFormFilledEvent", - "NodeRunHumanInputFormTimeoutEvent", - "NodeRunIterationFailedEvent", - "NodeRunIterationNextEvent", - "NodeRunIterationStartedEvent", - "NodeRunIterationSucceededEvent", - "NodeRunLoopFailedEvent", - "NodeRunLoopNextEvent", - "NodeRunLoopStartedEvent", - "NodeRunLoopSucceededEvent", - "NodeRunPauseRequestedEvent", - "NodeRunRetrieverResourceEvent", - "NodeRunRetryEvent", - "NodeRunStartedEvent", - "NodeRunStreamChunkEvent", - "NodeRunSucceededEvent", - "NodeRunVariableUpdatedEvent", - "is_node_result_event", -] diff --git a/api/graphon/graph_events/agent.py b/api/graphon/graph_events/agent.py deleted file mode 100644 index 759fe3a71c7..00000000000 --- a/api/graphon/graph_events/agent.py +++ /dev/null @@ -1,17 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import Field - -from .base import GraphAgentNodeEventBase - - -class NodeRunAgentLogEvent(GraphAgentNodeEventBase): - message_id: str = Field(..., description="message id") - label: str = Field(..., description="label") - node_execution_id: str = Field(..., description="node execution id") - parent_id: str | None = Field(..., description="parent id") - error: str | None = Field(..., description="error") - status: str = Field(..., description="status") - data: Mapping[str, Any] = Field(..., description="data") - metadata: Mapping[str, object] = Field(default_factory=dict) diff --git a/api/graphon/graph_events/base.py b/api/graphon/graph_events/base.py deleted file mode 100644 index 4ea9787b9ac..00000000000 --- a/api/graphon/graph_events/base.py +++ /dev/null @@ -1,31 +0,0 
@@ -1,31 +0,0 @@
-from pydantic import BaseModel, Field
-
-from graphon.enums import NodeType
-from graphon.node_events import NodeRunResult
-
-
-class GraphEngineEvent(BaseModel):
-    pass
-
-
-class BaseGraphEvent(GraphEngineEvent):
-    pass
-
-
-class GraphNodeEventBase(GraphEngineEvent):
-    id: str = Field(..., description="node execution id")
-    node_id: str
-    node_type: NodeType
-
-    in_iteration_id: str | None = None
-    """iteration id if node is in iteration"""
-    in_loop_id: str | None = None
-    """loop id if node is in loop"""
-
-    # The version of the node, or "1" if not specified.
-    node_version: str = "1"
-    node_run_result: NodeRunResult = Field(default_factory=NodeRunResult)
-
-
-class GraphAgentNodeEventBase(GraphNodeEventBase):
-    pass
diff --git a/api/graphon/graph_events/graph.py b/api/graphon/graph_events/graph.py
deleted file mode 100644
index 3782cb49bce..00000000000
--- a/api/graphon/graph_events/graph.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from pydantic import Field
-
-from graphon.entities.pause_reason import PauseReason
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.graph_events import BaseGraphEvent
-
-
-class GraphRunStartedEvent(BaseGraphEvent):
-    # Reason is emitted for workflow start events and is always set.
-    reason: WorkflowStartReason = Field(
-        default=WorkflowStartReason.INITIAL,
-        description="reason for workflow start",
-    )
-
-
-class GraphRunSucceededEvent(BaseGraphEvent):
-    """Event emitted when a run completes successfully with final outputs."""
-
-    outputs: dict[str, object] = Field(
-        default_factory=dict,
-        description="Final workflow outputs keyed by output selector.",
-    )
-
-
-class GraphRunFailedEvent(BaseGraphEvent):
-    error: str = Field(..., description="failed reason")
-    exceptions_count: int = Field(description="exception count", default=0)
-
-
-class GraphRunPartialSucceededEvent(BaseGraphEvent):
-    """Event emitted when a run finishes with partial success and failures."""
-
-    exceptions_count: int = Field(..., description="exception count")
-    outputs: dict[str, object] = Field(
-        default_factory=dict,
-        description="Outputs that were materialised before failures occurred.",
-    )
-
-
-class GraphRunAbortedEvent(BaseGraphEvent):
-    """Event emitted when a graph run is aborted by user command."""
-
-    reason: str | None = Field(default=None, description="reason for abort")
-    outputs: dict[str, object] = Field(
-        default_factory=dict,
-        description="Outputs produced before the abort was requested.",
-    )
-
-
-class GraphRunPausedEvent(BaseGraphEvent):
-    """Event emitted when a graph run is paused by user command."""
-
-    reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list)
-    outputs: dict[str, object] = Field(
-        default_factory=dict,
-        description="Outputs available to the client while the run is paused.",
-    )
diff --git a/api/graphon/graph_events/human_input.py b/api/graphon/graph_events/human_input.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/api/graphon/graph_events/iteration.py b/api/graphon/graph_events/iteration.py
deleted file mode 100644
index 28627395fd8..00000000000
--- a/api/graphon/graph_events/iteration.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from collections.abc import Mapping
-from datetime import datetime
-from typing import Any
-
-from pydantic import Field
-
-from .base import GraphNodeEventBase
-
-
-class NodeRunIterationStartedEvent(GraphNodeEventBase):
-    node_title: str
-    start_at: datetime = Field(..., description="start at")
-    inputs: Mapping[str, object] = Field(default_factory=dict)
-    metadata: Mapping[str, object] = Field(default_factory=dict)
-    predecessor_node_id: str | None = None
-
-
-class NodeRunIterationNextEvent(GraphNodeEventBase):
-    node_title: str
-    index: int = Field(..., description="index")
-    pre_iteration_output: Any = None
-
-
-class NodeRunIterationSucceededEvent(GraphNodeEventBase):
-    node_title: str
-    start_at: datetime = Field(..., description="start at")
-    inputs: Mapping[str, object] = Field(default_factory=dict)
-    outputs: Mapping[str, object] = Field(default_factory=dict)
-    metadata: Mapping[str, object] = Field(default_factory=dict)
-    steps: int = 0
-
-
-class NodeRunIterationFailedEvent(GraphNodeEventBase):
-    node_title: str
-    start_at: datetime = Field(..., description="start at")
-    inputs: Mapping[str, object] = Field(default_factory=dict)
-    outputs: Mapping[str, object] = Field(default_factory=dict)
-    metadata: Mapping[str, object] = Field(default_factory=dict)
-    steps: int = 0
-    error: str = Field(..., description="failed reason")
diff --git a/api/graphon/graph_events/loop.py b/api/graphon/graph_events/loop.py
deleted file mode 100644
index 7cdc5427e2b..00000000000
--- a/api/graphon/graph_events/loop.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from collections.abc import Mapping
-from datetime import datetime
-from typing import Any
-
-from pydantic import Field
-
-from .base import GraphNodeEventBase
-
-
-class NodeRunLoopStartedEvent(GraphNodeEventBase):
-    node_title: str
-    start_at: datetime = Field(..., description="start at")
-    inputs: Mapping[str, object] = Field(default_factory=dict)
-    metadata: Mapping[str, object] = Field(default_factory=dict)
-    predecessor_node_id: str | None = None
-
-
-class NodeRunLoopNextEvent(GraphNodeEventBase):
-    node_title: str
-    index: int = Field(..., description="index")
-    pre_loop_output: Any = None
-
-
-class NodeRunLoopSucceededEvent(GraphNodeEventBase):
-    node_title: str
-    start_at: datetime = Field(..., description="start at")
-    inputs: Mapping[str, object] = Field(default_factory=dict)
-    outputs: Mapping[str, object] = Field(default_factory=dict)
-    metadata: Mapping[str, object] = Field(default_factory=dict)
-    steps: int = 0
-
-
-class NodeRunLoopFailedEvent(GraphNodeEventBase):
-    node_title: str
-    start_at: datetime = Field(..., description="start at")
-    inputs: Mapping[str, object] = Field(default_factory=dict)
-    outputs: Mapping[str, object] = Field(default_factory=dict)
-    metadata: Mapping[str, object] = Field(default_factory=dict)
-    steps: int = 0
-    error: str = Field(..., description="failed reason")
diff --git a/api/graphon/graph_events/node.py b/api/graphon/graph_events/node.py
deleted file mode 100644
index 471ae08ee7f..00000000000
--- a/api/graphon/graph_events/node.py
+++ /dev/null
@@ -1,106 +0,0 @@
-from collections.abc import Mapping, Sequence
-from datetime import datetime
-from typing import Any
-
-from pydantic import Field
-
-from graphon.entities.pause_reason import PauseReason
-from graphon.variables.variables import Variable
-
-from .base import GraphNodeEventBase
-
-
-class NodeRunStartedEvent(GraphNodeEventBase):
-    node_title: str
-    predecessor_node_id: str | None = None
-    start_at: datetime = Field(..., description="node start time")
-    extras: dict[str, object] = Field(default_factory=dict)
-
-    # FIXME(-LAN-): only for ToolNode
-    provider_type: str = ""
-    provider_id: str = ""
-
-
-class NodeRunStreamChunkEvent(GraphNodeEventBase):
-    # Spec-compliant fields
-    selector: Sequence[str] = Field(
-        ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
-    )
-    chunk: str = Field(..., description="the actual chunk content")
-    is_final: bool = Field(default=False, description="indicates if this is the last chunk")
-
-
-class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
-    retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources")
-    context: str = Field(..., description="context")
-
-
-class NodeRunSucceededEvent(GraphNodeEventBase):
-    start_at: datetime = Field(..., description="node start time")
-    finished_at: datetime | None = Field(default=None, description="node finish time")
-
-
-class NodeRunVariableUpdatedEvent(GraphNodeEventBase):
-    """Request that the engine apply a variable update before downstream observers continue."""
-
-    variable: Variable = Field(..., description="Updated variable payload to apply.")
-
-
-class NodeRunFailedEvent(GraphNodeEventBase):
-    error: str = Field(..., description="error")
-    start_at: datetime = Field(..., description="node start time")
-    finished_at: datetime | None = Field(default=None, description="node finish time")
-
-
-class NodeRunExceptionEvent(GraphNodeEventBase):
-    error: str = Field(..., description="error")
-    start_at: datetime = Field(..., description="node start time")
-    finished_at: datetime | None = Field(default=None, description="node finish time")
-
-
-class NodeRunRetryEvent(NodeRunStartedEvent):
-    error: str = Field(..., description="error")
-    retry_index: int = Field(..., description="which retry attempt is about to be performed")
-
-
-class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase):
-    """Emitted when a HumanInput form is submitted and before the node finishes."""
-
-    node_title: str = Field(..., description="HumanInput node title")
-    rendered_content: str = Field(..., description="Markdown content rendered with user inputs.")
-    action_id: str = Field(..., description="User action identifier chosen in the form.")
-    action_text: str = Field(..., description="Display text of the chosen action button.")
-
-
-class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase):
-    """Emitted when a HumanInput form times out."""
-
-    node_title: str = Field(..., description="HumanInput node title")
-    expiration_time: datetime = Field(..., description="Form expiration time")
-
-
-class NodeRunPauseRequestedEvent(GraphNodeEventBase):
-    reason: PauseReason = Field(..., description="pause reason")
-
-
-def is_node_result_event(event: GraphNodeEventBase) -> bool:
-    """
-    Check if an event is a final result event from node execution.
-
-    A result event indicates the completion of a node execution and contains
-    runtime information such as inputs, outputs, or error details.
-
-    Args:
-        event: The event to check
-
-    Returns:
-        True if the event is a node result event (succeeded/failed/paused), False otherwise
-    """
-    return isinstance(
-        event,
-        (
-            NodeRunSucceededEvent,
-            NodeRunFailedEvent,
-            NodeRunPauseRequestedEvent,
-        ),
-    )
diff --git a/api/graphon/model_runtime/README.md b/api/graphon/model_runtime/README.md
deleted file mode 100644
index b9d2c552105..00000000000
--- a/api/graphon/model_runtime/README.md
+++ /dev/null
@@ -1,51 +0,0 @@
-# Model Runtime
-
-This module provides the interface for invoking and authenticating various models, and offers Dify a unified information and credentials form rule for model providers.
-
-- On one hand, it decouples models from upstream and downstream processes, facilitating horizontal expansion for developers,
-- On the other hand, it allows for direct display of providers and models in the frontend interface by simply defining them in the backend, eliminating the need to modify frontend logic.
-
-## Features
-
-- Supports capability invocation for 6 types of models
-
-  - `LLM` - LLM text completion, dialogue, pre-computed tokens capability
-  - `Text Embedding Model` - Text Embedding, pre-computed tokens capability
-  - `Rerank Model` - Segment Rerank capability
-  - `Speech-to-text Model` - Speech to text capability
-  - `Text-to-speech Model` - Text to speech capability
-  - `Moderation` - Moderation capability
-
-- Model provider display
-
-  Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc.
-
-- Selectable model list display
-
-  After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models.
-
-  In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models.
-
-- Provider/model credential authentication
-
-  The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface.
-
-## Structure
-
-Model Runtime is divided into three layers:
-
-- The outermost layer is the factory method
-
-  It provides methods for obtaining all providers, all model lists, getting provider instances, and authenticating provider/model credentials.
-
-- The second layer is the provider layer
-
-  It provides the current provider's model list, model instance obtaining, provider credential authentication, and provider configuration rule information, **allowing horizontal expansion** to support different providers.
-
-- The bottom layer is the model layer
-
-  It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
-
-## Documentation
-
-For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).
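The deleted README above describes a factory-provider-model call chain. For readers tracking what this removal takes away, a minimal sketch of that three-layer flow follows; the `get_provider`, `validate_credentials`, `get_model_instance`, and `invoke` entry points are illustrative assumptions for this sketch, not the removed module's actual API.

from graphon.model_runtime.entities.model_entities import ModelType


def invoke_llm_sketch(factory, credentials: dict, prompt: str) -> str:
    # Factory layer: resolve a provider by name (hypothetical accessor).
    provider = factory.get_provider("openai")
    # Provider layer: authenticate the credentials before any invocation (hypothetical).
    provider.validate_credentials(credentials)
    # Model layer: obtain a model instance and perform the actual call (hypothetical).
    llm = provider.get_model_instance(ModelType.LLM, "gpt-4o")
    result = llm.invoke(prompt=prompt, credentials=credentials)
    return result.message.content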
diff --git a/api/graphon/model_runtime/README_CN.md b/api/graphon/model_runtime/README_CN.md
deleted file mode 100644
index 0a8b56b3fed..00000000000
--- a/api/graphon/model_runtime/README_CN.md
+++ /dev/null
@@ -1,64 +0,0 @@
-# Model Runtime
-
-This module provides the invocation and authentication interfaces for the various models, and provides Dify with unified form rules for model provider information and credentials.
-
-- On one hand, it decouples models from upstream and downstream processes, making it easy for developers to extend models horizontally,
-- On the other hand, providers and models only need to be defined in the backend to be displayed directly on the frontend pages, without modifying frontend logic.
-
-## Features
-
-- Supports capability invocation for 6 types of models
-
-  - `LLM` - LLM text completion, dialogue, pre-computed tokens capability
-  - `Text Embedding Model` - text Embedding, pre-computed tokens capability
-  - `Rerank Model` - segment Rerank capability
-  - `Speech-to-text Model` - speech-to-text capability
-  - `Text-to-speech Model` - text-to-speech capability
-  - `Moderation` - Moderation capability
-
-- Model provider display
-
-  Displays a list of all supported providers; besides the provider name and icon, it also returns the supported model types, the predefined model list, the configuration method, and the form rules for configuring credentials.
-
-- Selectable model list display
-
-  After provider/model credentials are configured, the available LLM list can be viewed in this dropdown (application orchestration interface/default model); greyed-out entries are the predefined model lists of providers whose credentials are not configured, making it easy for users to see which models are supported.
-
-  In addition, this list also returns the configurable parameter information and rules for LLMs. Unlike the previous fixed set of 5 parameters, these parameters are all defined in the backend, so the parameters supported by different models can be configured individually.
-
-- Provider/model credential authentication
-
-  The provider list returns the configuration information for the credential form, and credentials can be authenticated through the interface provided by Runtime.
-
-## Structure
-
-Model Runtime is divided into three layers:
-
-- The outermost layer is the factory method
-
-  It provides methods for obtaining all providers, obtaining all model lists, obtaining provider instances, and authenticating provider/model credentials.
-
-- The second layer is the provider layer
-
-  It provides the current provider's model list, model instance obtaining, provider credential authentication, and provider configuration rule information, and **can be extended horizontally** to support different providers.
-
-  For provider/model credentials, there are two cases:
-
-  - Centralized providers such as OpenAI require authentication credentials such as **api_key** to be defined.
-  - Locally deployed providers such as [**Xinference**](https://github.com/xorbitsai/inference) require address credentials such as **server_url** to be defined, and sometimes also model-type credentials such as **model_uid**. Once these credentials are defined at the provider layer, they can be displayed directly on the frontend pages without modifying frontend logic.
-
-  Once the credentials are configured, the **Schema** (credential form rules) required by the corresponding provider can be obtained directly through DifyRuntime's external interface, so new providers/models can be supported without modifying frontend logic.
-
-- The bottom layer is the model layer
-
-  It provides direct invocation of the various model types, predefined model configuration information, obtaining predefined/remote model lists, and model credential authentication methods. Different models additionally provide special methods, such as the LLM's pre-computed tokens method and cost information obtaining method, and it **can be extended horizontally** for different models under the same provider (within the supported model types).
-
-  Here we first need to distinguish model parameters from model credentials.
-
-  - Model parameters (**defined at this layer**): parameters that often need to change and may be adjusted at any time, such as the LLM's **max_tokens** and **temperature**. They are adjusted by users on the frontend pages, so the backend must define the rules for these parameters so the frontend can display and adjust them. In DifyRuntime, their parameter name is generally **model_parameters: dict[str, any]**.
-
-  - Model credentials (**defined at the provider layer**): parameters that do not change often and generally stay fixed once configured, such as **api_key** and **server_url**. In DifyRuntime, their parameter name is generally **credentials: dict[str, any]**; the provider layer's credentials are passed directly to this layer and do not need to be defined separately.
-
-## Documentation
-
-For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).
diff --git a/api/graphon/model_runtime/__init__.py b/api/graphon/model_runtime/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/api/graphon/model_runtime/callbacks/__init__.py b/api/graphon/model_runtime/callbacks/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/api/graphon/model_runtime/callbacks/base_callback.py b/api/graphon/model_runtime/callbacks/base_callback.py
deleted file mode 100644
index cd85cf63016..00000000000
--- a/api/graphon/model_runtime/callbacks/base_callback.py
+++ /dev/null
@@ -1,159 +0,0 @@
-from abc import ABC, abstractmethod
-from collections.abc import Mapping, Sequence
-
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
-from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from graphon.model_runtime.model_providers.__base.ai_model import AIModel
-
-_TEXT_COLOR_MAPPING = {
-    "blue": "36;1",
-    "yellow": "33;1",
-    "pink": "38;5;200",
-    "green": "32;1",
-    "red": "31;1",
-}
-
-
-class Callback(ABC):
-    """
-    Base class for callbacks.
-    Only for LLM.
-    """
-
-    raise_error: bool = False
-
-    @abstractmethod
-    def on_before_invoke(
-        self,
-        llm_instance: AIModel,
-        model: str,
-        credentials: dict,
-        prompt_messages: list[PromptMessage],
-        model_parameters: dict,
-        tools: list[PromptMessageTool] | None = None,
-        stop: Sequence[str] | None = None,
-        stream: bool = True,
-        user: str | None = None,
-        invocation_context: Mapping[str, object] | None = None,
-    ):
-        """
-        Before invoke callback
-
-        :param llm_instance: LLM instance
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :param stream: is stream response
-        :param user: optional end-user identifier for the invocation
-        :param invocation_context: opaque request metadata for the current invocation
-        """
-        raise NotImplementedError()
-
-    @abstractmethod
-    def on_new_chunk(
-        self,
-        llm_instance: AIModel,
-        chunk: LLMResultChunk,
-        model: str,
-        credentials: dict,
-        prompt_messages: Sequence[PromptMessage],
-        model_parameters: dict,
-        tools: list[PromptMessageTool] | None = None,
-        stop: Sequence[str] | None = None,
-        stream: bool = True,
-        user: str | None = None,
-        invocation_context: Mapping[str, object] | None = None,
-    ):
-        """
-        On new chunk callback
-
-        :param llm_instance: LLM instance
-        :param chunk: chunk
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :param stream: is stream response
-        :param user: optional end-user identifier for the invocation
-        :param invocation_context: opaque request metadata for the current invocation
-        """
-        raise NotImplementedError()
-
-    @abstractmethod
-    def on_after_invoke(
-        self,
-        llm_instance: AIModel,
-        result: LLMResult,
-        model: str,
-        credentials: dict,
-        prompt_messages: Sequence[PromptMessage],
-        model_parameters: dict,
-        tools: list[PromptMessageTool] | None = None,
-        stop: Sequence[str] | None = None,
-        stream: bool = True,
-        user: str | None = None,
-        invocation_context: Mapping[str, object] | None = None,
-    ):
-        """
-        After invoke callback
-
-        :param llm_instance: LLM instance
-        :param result: result
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :param stream: is stream response
-        :param user: optional end-user identifier for the invocation
-        :param invocation_context: opaque request metadata for the current invocation
-        """
-        raise NotImplementedError()
-
-    @abstractmethod
-    def on_invoke_error(
-        self,
-        llm_instance: AIModel,
-        ex: Exception,
-        model: str,
-        credentials: dict,
-        prompt_messages: list[PromptMessage],
-        model_parameters: dict,
-        tools: list[PromptMessageTool] | None = None,
-        stop: Sequence[str] | None = None,
-        stream: bool = True,
-        user: str | None = None,
-        invocation_context: Mapping[str, object] | None = None,
-    ):
-        """
-        Invoke error callback
-
-        :param llm_instance: LLM instance
-        :param ex: exception
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :param stream: is stream response
-        :param user: optional end-user identifier for the invocation
-        :param invocation_context: opaque request metadata for the current invocation
-        """
-        raise NotImplementedError()
-
-    def print_text(self, text: str, color: str | None = None, end: str = ""):
-        """Print text with highlighting and no end characters."""
-        text_to_print = self._get_colored_text(text, color) if color else text
-        print(text_to_print, end=end)
-
-    def _get_colored_text(self, text: str, color: str) -> str:
-        """Get colored text."""
-        color_str = _TEXT_COLOR_MAPPING[color]
-        return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
diff --git a/api/graphon/model_runtime/callbacks/logging_callback.py b/api/graphon/model_runtime/callbacks/logging_callback.py
deleted file mode 100644
index f96eb446fc8..00000000000
--- a/api/graphon/model_runtime/callbacks/logging_callback.py
+++ /dev/null
@@ -1,180 +0,0 @@
-import json
-import logging
-import sys
-from collections.abc import Mapping, Sequence
-from typing import cast
-
-from graphon.model_runtime.callbacks.base_callback import Callback
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
-from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from graphon.model_runtime.model_providers.__base.ai_model import AIModel
-
-logger = logging.getLogger(__name__)
-
-
-class LoggingCallback(Callback):
-    def on_before_invoke(
-        self,
-        llm_instance: AIModel,
-        model: str,
-        credentials: dict,
-        prompt_messages: list[PromptMessage],
-        model_parameters: dict,
-        tools: list[PromptMessageTool] | None = None,
-        stop: Sequence[str] | None = None,
-        stream: bool = True,
-        user: str | None = None,
-        invocation_context: Mapping[str, object] | None = None,
-    ):
-        """
-        Before invoke callback
-
-        :param llm_instance: LLM instance
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :param stream: is stream response
-        :param user: optional end-user identifier for the invocation
-        :param invocation_context: opaque request metadata for the current invocation
-        """
-        self.print_text("\n[on_llm_before_invoke]\n", color="blue")
-        self.print_text(f"Model: {model}\n", color="blue")
-        self.print_text("Parameters:\n", color="blue")
-        for key, value in model_parameters.items():
-            self.print_text(f"\t{key}: {value}\n", color="blue")
-
-        if stop:
-            self.print_text(f"\tstop: {stop}\n", color="blue")
-
-        if tools:
-            self.print_text("\tTools:\n", color="blue")
-            for tool in tools:
-                self.print_text(f"\t\t{tool.name}\n", color="blue")
-
-        self.print_text(f"Stream: {stream}\n", color="blue")
-        if user:
-            self.print_text(f"User: {user}\n", color="blue")
-
-        if invocation_context:
-            self.print_text(f"Invocation context: {dict(invocation_context)}\n", color="blue")
-
-        self.print_text("Prompt messages:\n", color="blue")
-        for prompt_message in prompt_messages:
-            if prompt_message.name:
-                self.print_text(f"\tname: {prompt_message.name}\n", color="blue")
-
-            self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue")
-            self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue")
-
-        if stream:
-            self.print_text("\n[on_llm_new_chunk]")
-
-    def on_new_chunk(
-        self,
-        llm_instance: AIModel,
-        chunk: LLMResultChunk,
-        model: str,
-        credentials: dict,
-        prompt_messages: Sequence[PromptMessage],
-        model_parameters: dict,
-        tools: list[PromptMessageTool] | None = None,
-        stop: Sequence[str] | None = None,
-        stream: bool = True,
-        user: str | None = None,
-        invocation_context: Mapping[str, object] | None = None,
-    ):
-        """
-        On new chunk callback
-
-        :param llm_instance: LLM instance
-        :param chunk: chunk
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :param stream: is stream response
-        :param invocation_context: opaque request metadata for the current invocation
-        """
-        _ = user, invocation_context
-        sys.stdout.write(cast(str, chunk.delta.message.content))
-        sys.stdout.flush()
-
-    def on_after_invoke(
-        self,
-        llm_instance: AIModel,
-        result: LLMResult,
-        model: str,
-        credentials: dict,
-        prompt_messages: Sequence[PromptMessage],
-        model_parameters: dict,
-        tools: list[PromptMessageTool] | None = None,
-        stop: Sequence[str] | None = None,
-        stream: bool = True,
-        user: str | None = None,
-        invocation_context: Mapping[str, object] | None = None,
-    ):
-        """
-        After invoke callback
-
-        :param llm_instance: LLM instance
-        :param result: result
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :param stream: is stream response
-        :param invocation_context: opaque request metadata for the current invocation
-        """
-        _ = user, invocation_context
-        self.print_text("\n[on_llm_after_invoke]\n", color="yellow")
-        self.print_text(f"Content: {result.message.content}\n", color="yellow")
-
-        if result.message.tool_calls:
-            self.print_text("Tool calls:\n", color="yellow")
-            for tool_call in result.message.tool_calls:
-                self.print_text(f"\t{tool_call.id}\n", color="yellow")
-                self.print_text(f"\t{tool_call.function.name}\n", color="yellow")
-                self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow")
-
-        self.print_text(f"Model: {result.model}\n", color="yellow")
-        self.print_text(f"Usage: {result.usage}\n", color="yellow")
-        self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow")
-
-    def on_invoke_error(
-        self,
-        llm_instance: AIModel,
-        ex: Exception,
-        model: str,
-        credentials: dict,
-        prompt_messages: list[PromptMessage],
-        model_parameters: dict,
-        tools: list[PromptMessageTool] | None = None,
-        stop: Sequence[str] | None = None,
-        stream: bool = True,
-        user: str | None = None,
-        invocation_context: Mapping[str, object] | None = None,
-    ):
-        """
-        Invoke error callback
-
-        :param llm_instance: LLM instance
-        :param ex: exception
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :param stream: is stream response
-        :param invocation_context: opaque request metadata for the current invocation
-        """
-        _ = user, invocation_context
-        self.print_text("\n[on_llm_invoke_error]\n", color="red")
-        logger.exception(ex)
diff --git a/api/graphon/model_runtime/entities/__init__.py b/api/graphon/model_runtime/entities/__init__.py
deleted file mode 100644
index a24e437d48e..00000000000
--- a/api/graphon/model_runtime/entities/__init__.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from .llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from .message_entities import (
-    AssistantPromptMessage,
-    AudioPromptMessageContent,
-    DocumentPromptMessageContent,
-    ImagePromptMessageContent,
-    MultiModalPromptMessageContent,
-    PromptMessage,
-    PromptMessageContent,
-    PromptMessageContentType,
-    PromptMessageRole,
-    PromptMessageTool,
-    SystemPromptMessage,
-    TextPromptMessageContent,
-    ToolPromptMessage,
-    UserPromptMessage,
-    VideoPromptMessageContent,
-)
-from .model_entities import ModelPropertyKey
-
-__all__ = [
-    "AssistantPromptMessage",
-    "AudioPromptMessageContent",
-    "DocumentPromptMessageContent",
-    "ImagePromptMessageContent",
-    "LLMMode",
-    "LLMResult",
-    "LLMResultChunk",
-    "LLMResultChunkDelta",
-    "LLMUsage",
-    "ModelPropertyKey",
-    "MultiModalPromptMessageContent",
-    "PromptMessage",
-    "PromptMessageContent",
-    "PromptMessageContentType",
-    "PromptMessageRole",
-    "PromptMessageTool",
-    "SystemPromptMessage",
-    "TextPromptMessageContent",
-    "ToolPromptMessage",
-    "UserPromptMessage",
-    "VideoPromptMessageContent",
-]
diff --git a/api/graphon/model_runtime/entities/common_entities.py b/api/graphon/model_runtime/entities/common_entities.py
deleted file mode 100644
index b673efae228..00000000000
--- a/api/graphon/model_runtime/entities/common_entities.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from pydantic import BaseModel, model_validator
-
-
-class I18nObject(BaseModel):
-    """
-    Model class for i18n object.
-    """
-
-    zh_Hans: str | None = None
-    en_US: str
-
-    @model_validator(mode="after")
-    def _(self):
-        if not self.zh_Hans:
-            self.zh_Hans = self.en_US
-        return self
diff --git a/api/graphon/model_runtime/entities/defaults.py b/api/graphon/model_runtime/entities/defaults.py
deleted file mode 100644
index bcce17c5d5b..00000000000
--- a/api/graphon/model_runtime/entities/defaults.py
+++ /dev/null
@@ -1,130 +0,0 @@
-from graphon.model_runtime.entities.model_entities import DefaultParameterName
-
-PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
-    DefaultParameterName.TEMPERATURE: {
-        "label": {
-            "en_US": "Temperature",
-            "zh_Hans": "ๆธฉๅบฆ",
-        },
-        "type": "float",
-        "help": {
-            "en_US": "Controls randomness. Lower temperature results in less random completions."
-            " As the temperature approaches zero, the model will become deterministic and repetitive."
-            " Higher temperature results in more random completions.",
-            "zh_Hans": "ๆธฉๅบฆๆŽงๅˆถ้šๆœบๆ€งใ€‚่พƒไฝŽ็š„ๆธฉๅบฆไผšๅฏผ่‡ด่พƒๅฐ‘็š„้šๆœบๅฎŒๆˆใ€‚้š็€ๆธฉๅบฆๆŽฅ่ฟ‘้›ถ๏ผŒๆจกๅž‹ๅฐ†ๅ˜ๅพ—็กฎๅฎšๆ€งๅ’Œ้‡ๅคๆ€งใ€‚"
-            "่พƒ้ซ˜็š„ๆธฉๅบฆไผšๅฏผ่‡ดๆ›ดๅคš็š„้šๆœบๅฎŒๆˆใ€‚",
-        },
-        "required": False,
-        "default": 0.0,
-        "min": 0.0,
-        "max": 1.0,
-        "precision": 2,
-    },
-    DefaultParameterName.TOP_P: {
-        "label": {
-            "en_US": "Top P",
-            "zh_Hans": "Top P",
-        },
-        "type": "float",
-        "help": {
-            "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options"
-            " are considered.",
-            "zh_Hans": "้€š่ฟ‡ๆ ธๅฟƒ้‡‡ๆ ทๆŽงๅˆถๅคšๆ ทๆ€ง๏ผš0.5 ่กจ็คบ่€ƒ่™‘ไบ†ไธ€ๅŠ็š„ๆ‰€ๆœ‰ๅฏ่ƒฝๆ€งๅŠ ๆƒ้€‰้กนใ€‚",
-        },
-        "required": False,
-        "default": 1.0,
-        "min": 0.0,
-        "max": 1.0,
-        "precision": 2,
-    },
-    DefaultParameterName.TOP_K: {
-        "label": {
-            "en_US": "Top K",
-            "zh_Hans": "Top K",
-        },
-        "type": "int",
-        "help": {
-            "en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.",
-            "zh_Hans": "้€š่ฟ‡ๅชไฟ็•™ๆฏไธ€ๆญฅไธญๆœ€ๅฏ่ƒฝ็š„ k ไธชๆ ‡่ฎฐๆฅ้™ๅˆถ่ฆ่€ƒ่™‘็š„ๆ ‡่ฎฐๆ•ฐ้‡ใ€‚",
-        },
-        "required": False,
-        "default": 50,
-        "min": 1,
-        "max": 100,
-        "precision": 0,
-    },
-    DefaultParameterName.PRESENCE_PENALTY: {
-        "label": {
-            "en_US": "Presence Penalty",
-            "zh_Hans": "ๅญ˜ๅœจๆƒฉ็ฝš",
-        },
-        "type": "float",
-        "help": {
-            "en_US": "Applies a penalty to the log-probability of tokens already in the text.",
-            "zh_Hans": "ๅฏนๆ–‡ๆœฌไธญๅทฒๆœ‰็š„ๆ ‡่ฎฐ็š„ๅฏนๆ•ฐๆฆ‚็އๆ–ฝๅŠ ๆƒฉ็ฝšใ€‚",
-        },
-        "required": False,
-        "default": 0.0,
-        "min": 0.0,
-        "max": 1.0,
-        "precision": 2,
-    },
-    DefaultParameterName.FREQUENCY_PENALTY: {
-        "label": {
-            "en_US": "Frequency Penalty",
-            "zh_Hans": "้ข‘็އๆƒฉ็ฝš",
-        },
-        "type": "float",
-        "help": {
-            "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.",
-            "zh_Hans": "ๅฏนๆ–‡ๆœฌไธญๅ‡บ็Žฐ็š„ๆ ‡่ฎฐ็š„ๅฏนๆ•ฐๆฆ‚็އๆ–ฝๅŠ ๆƒฉ็ฝšใ€‚",
-        },
-        "required": False,
-        "default": 0.0,
-        "min": 0.0,
-        "max": 1.0,
-        "precision": 2,
-    },
-    DefaultParameterName.MAX_TOKENS: {
-        "label": {
-            "en_US": "Max Tokens",
-            "zh_Hans": "ๆœ€ๅคง Token ๆ•ฐ",
-        },
-        "type": "int",
-        "help": {
-            "en_US": "Specifies the upper limit on the length of generated results."
-            " If the generated results are truncated, you can increase this parameter.",
-            "zh_Hans": "ๆŒ‡ๅฎš็”Ÿๆˆ็ป“ๆžœ้•ฟๅบฆ็š„ไธŠ้™ใ€‚ๅฆ‚ๆžœ็”Ÿๆˆ็ป“ๆžœๆˆชๆ–ญ๏ผŒๅฏไปฅ่ฐƒๅคง่ฏฅๅ‚ๆ•ฐใ€‚",
-        },
-        "required": False,
-        "default": 64,
-        "min": 1,
-        "max": 2048,
-        "precision": 0,
-    },
-    DefaultParameterName.RESPONSE_FORMAT: {
-        "label": {
-            "en_US": "Response Format",
-            "zh_Hans": "ๅ›žๅคๆ ผๅผ",
-        },
-        "type": "string",
-        "help": {
-            "en_US": "Set a response format, ensure the output from llm is a valid code block as possible,"
-            " such as JSON, XML, etc.",
-            "zh_Hans": "่ฎพ็ฝฎไธ€ไธช่ฟ”ๅ›žๆ ผๅผ๏ผŒ็กฎไฟ llm ็š„่พ“ๅ‡บๅฐฝๅฏ่ƒฝๆ˜ฏๆœ‰ๆ•ˆ็š„ไปฃ็ ๅ—๏ผŒๅฆ‚ JSONใ€XML ็ญ‰",
-        },
-        "required": False,
-        "options": ["JSON", "XML"],
-    },
-    DefaultParameterName.JSON_SCHEMA: {
-        "label": {
-            "en_US": "JSON Schema",
-        },
-        "type": "text",
-        "help": {
-            "en_US": "Set a response json schema will ensure LLM to adhere it.",
-            "zh_Hans": "่ฎพ็ฝฎ่ฟ”ๅ›ž็š„ json schema๏ผŒllm ๅฐ†ๆŒ‰็…งๅฎƒ่ฟ”ๅ›ž",
-        },
-        "required": False,
-    },
-}
diff --git a/api/graphon/model_runtime/entities/llm_entities.py b/api/graphon/model_runtime/entities/llm_entities.py
deleted file mode 100644
index bfc80f21c5a..00000000000
--- a/api/graphon/model_runtime/entities/llm_entities.py
+++ /dev/null
@@ -1,219 +0,0 @@
-from __future__ import annotations
-
-from collections.abc import Mapping, Sequence
-from decimal import Decimal
-from enum import StrEnum
-from typing import Any, TypedDict, Union
-
-from pydantic import BaseModel, Field
-
-from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
-from graphon.model_runtime.entities.model_entities import ModelUsage, PriceInfo
-
-
-class LLMMode(StrEnum):
-    """
-    Enum class for large language model mode.
-    """
-
-    COMPLETION = "completion"
-    CHAT = "chat"
-
-
-class LLMUsageMetadata(TypedDict, total=False):
-    """
-    TypedDict for LLM usage metadata.
-    All fields are optional.
-    """
-
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-    prompt_unit_price: Union[float, str]
-    completion_unit_price: Union[float, str]
-    total_price: Union[float, str]
-    currency: str
-    prompt_price_unit: Union[float, str]
-    completion_price_unit: Union[float, str]
-    prompt_price: Union[float, str]
-    completion_price: Union[float, str]
-    latency: float
-    time_to_first_token: float
-    time_to_generate: float
-
-
-class LLMUsage(ModelUsage):
-    """
-    Model class for llm usage.
-    """
-
-    prompt_tokens: int
-    prompt_unit_price: Decimal
-    prompt_price_unit: Decimal
-    prompt_price: Decimal
-    completion_tokens: int
-    completion_unit_price: Decimal
-    completion_price_unit: Decimal
-    completion_price: Decimal
-    total_tokens: int
-    total_price: Decimal
-    currency: str
-    latency: float
-    time_to_first_token: float | None = None
-    time_to_generate: float | None = None
-
-    @classmethod
-    def empty_usage(cls):
-        return cls(
-            prompt_tokens=0,
-            prompt_unit_price=Decimal("0.0"),
-            prompt_price_unit=Decimal("0.0"),
-            prompt_price=Decimal("0.0"),
-            completion_tokens=0,
-            completion_unit_price=Decimal("0.0"),
-            completion_price_unit=Decimal("0.0"),
-            completion_price=Decimal("0.0"),
-            total_tokens=0,
-            total_price=Decimal("0.0"),
-            currency="USD",
-            latency=0.0,
-            time_to_first_token=None,
-            time_to_generate=None,
-        )
-
-    @classmethod
-    def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage:
-        """
-        Create LLMUsage instance from metadata dictionary with default values.
-
-        Args:
-            metadata: TypedDict containing usage metadata
-
-        Returns:
-            LLMUsage instance with values from metadata or defaults
-        """
-        prompt_tokens = metadata.get("prompt_tokens", 0)
-        completion_tokens = metadata.get("completion_tokens", 0)
-        total_tokens = metadata.get("total_tokens", 0)
-
-        # If total_tokens is not provided but prompt and completion tokens are,
-        # calculate total_tokens
-        if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0):
-            total_tokens = prompt_tokens + completion_tokens
-
-        return cls(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-            prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
-            completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))),
-            total_price=Decimal(str(metadata.get("total_price", 0))),
-            currency=metadata.get("currency", "USD"),
-            prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))),
-            completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))),
-            prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
-            completion_price=Decimal(str(metadata.get("completion_price", 0))),
-            latency=metadata.get("latency", 0.0),
-            time_to_first_token=metadata.get("time_to_first_token"),
-            time_to_generate=metadata.get("time_to_generate"),
-        )
-
-    def plus(self, other: LLMUsage) -> LLMUsage:
-        """
-        Add two LLMUsage instances together.
-
-        :param other: Another LLMUsage instance to add
-        :return: A new LLMUsage instance with summed values
-        """
-        if self.total_tokens == 0:
-            return other
-        else:
-            return LLMUsage(
-                prompt_tokens=self.prompt_tokens + other.prompt_tokens,
-                prompt_unit_price=other.prompt_unit_price,
-                prompt_price_unit=other.prompt_price_unit,
-                prompt_price=self.prompt_price + other.prompt_price,
-                completion_tokens=self.completion_tokens + other.completion_tokens,
-                completion_unit_price=other.completion_unit_price,
-                completion_price_unit=other.completion_price_unit,
-                completion_price=self.completion_price + other.completion_price,
-                total_tokens=self.total_tokens + other.total_tokens,
-                total_price=self.total_price + other.total_price,
-                currency=other.currency,
-                latency=self.latency + other.latency,
-                time_to_first_token=other.time_to_first_token,
-                time_to_generate=other.time_to_generate,
-            )
-
-    def __add__(self, other: LLMUsage) -> LLMUsage:
-        """
-        Overload the + operator to add two LLMUsage instances.
-
-        :param other: Another LLMUsage instance to add
-        :return: A new LLMUsage instance with summed values
-        """
-        return self.plus(other)
-
-
-class LLMResult(BaseModel):
-    """
-    Model class for llm result.
-    """
-
-    id: str | None = None
-    model: str
-    prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
-    message: AssistantPromptMessage
-    usage: LLMUsage
-    system_fingerprint: str | None = None
-    reasoning_content: str | None = None
-
-
-class LLMStructuredOutput(BaseModel):
-    """
-    Model class for llm structured output.
-    """
-
-    structured_output: Mapping[str, Any] | None = None
-
-
-class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput):
-    """
-    Model class for llm result with structured output.
-    """
-
-
-class LLMResultChunkDelta(BaseModel):
-    """
-    Model class for llm result chunk delta.
-    """
-
-    index: int
-    message: AssistantPromptMessage
-    usage: LLMUsage | None = None
-    finish_reason: str | None = None
-
-
-class LLMResultChunk(BaseModel):
-    """
-    Model class for llm result chunk.
-    """
-
-    model: str
-    prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
-    system_fingerprint: str | None = None
-    delta: LLMResultChunkDelta
-
-
-class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput):
-    """
-    Model class for llm result chunk with structured output.
-    """
-
-
-class NumTokensResult(PriceInfo):
-    """
-    Model class for number of tokens result.
-    """
-
-    tokens: int
diff --git a/api/graphon/model_runtime/entities/message_entities.py b/api/graphon/model_runtime/entities/message_entities.py
deleted file mode 100644
index 402bfdc6065..00000000000
--- a/api/graphon/model_runtime/entities/message_entities.py
+++ /dev/null
@@ -1,279 +0,0 @@
-from __future__ import annotations
-
-from abc import ABC
-from collections.abc import Mapping, Sequence
-from enum import StrEnum, auto
-from typing import Annotated, Any, Literal, Union
-
-from pydantic import BaseModel, Field, field_serializer, field_validator
-
-
-class PromptMessageRole(StrEnum):
-    """
-    Enum class for prompt message.
-    """
-
-    SYSTEM = auto()
-    USER = auto()
-    ASSISTANT = auto()
-    TOOL = auto()
-
-    @classmethod
-    def value_of(cls, value: str) -> PromptMessageRole:
-        """
-        Get value of given mode.
-
-        :param value: mode value
-        :return: mode
-        """
-        for mode in cls:
-            if mode.value == value:
-                return mode
-        raise ValueError(f"invalid prompt message type value {value}")
-
-
-class PromptMessageTool(BaseModel):
-    """
-    Model class for prompt message tool.
-    """
-
-    name: str
-    description: str
-    parameters: dict
-
-
-class PromptMessageFunction(BaseModel):
-    """
-    Model class for prompt message function.
-    """
-
-    type: str = "function"
-    function: PromptMessageTool
-
-
-class PromptMessageContentType(StrEnum):
-    """
-    Enum class for prompt message content type.
-    """
-
-    TEXT = auto()
-    IMAGE = auto()
-    AUDIO = auto()
-    VIDEO = auto()
-    DOCUMENT = auto()
-
-
-class PromptMessageContent(ABC, BaseModel):
-    """
-    Model class for prompt message content.
-    """
-
-    type: PromptMessageContentType
-
-
-class TextPromptMessageContent(PromptMessageContent):
-    """
-    Model class for text prompt message content.
-    """
-
-    type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT  # type: ignore
-    data: str
-
-
-class MultiModalPromptMessageContent(PromptMessageContent):
-    """
-    Model class for multi-modal prompt message content.
-    """
-
-    format: str = Field(default=..., description="the format of multi-modal file")
-    base64_data: str = Field(default="", description="the base64 data of multi-modal file")
-    url: str = Field(default="", description="the url of multi-modal file")
-    mime_type: str = Field(default=..., description="the mime type of multi-modal file")
-    filename: str = Field(default="", description="the filename of multi-modal file")
-
-    @property
-    def data(self):
-        return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
-
-
-class VideoPromptMessageContent(MultiModalPromptMessageContent):
-    type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO  # type: ignore
-
-
-class AudioPromptMessageContent(MultiModalPromptMessageContent):
-    type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO  # type: ignore
-
-
-class ImagePromptMessageContent(MultiModalPromptMessageContent):
-    """
-    Model class for image prompt message content.
-    """
-
-    class DETAIL(StrEnum):
-        LOW = auto()
-        HIGH = auto()
-
-    type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE  # type: ignore
-    detail: DETAIL = DETAIL.LOW
-
-
-class DocumentPromptMessageContent(MultiModalPromptMessageContent):
-    type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT  # type: ignore
-
-
-PromptMessageContentUnionTypes = Annotated[
-    Union[
-        TextPromptMessageContent,
-        ImagePromptMessageContent,
-        DocumentPromptMessageContent,
-        AudioPromptMessageContent,
-        VideoPromptMessageContent,
-    ],
-    Field(discriminator="type"),
-]
-
-
-CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = {
-    PromptMessageContentType.TEXT: TextPromptMessageContent,
-    PromptMessageContentType.IMAGE: ImagePromptMessageContent,
-    PromptMessageContentType.AUDIO: AudioPromptMessageContent,
-    PromptMessageContentType.VIDEO: VideoPromptMessageContent,
-    PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent,
-}
-
-
-class PromptMessage(ABC, BaseModel):
-    """
-    Model class for prompt message.
-    """
-
-    role: PromptMessageRole
-    content: str | list[PromptMessageContentUnionTypes] | None = None
-    name: str | None = None
-
-    def is_empty(self) -> bool:
-        """
-        Check if prompt message is empty.
-
-        :return: True if prompt message is empty, False otherwise
-        """
-        return not self.content
-
-    def get_text_content(self) -> str:
-        """
-        Get text content from prompt message.
-
-        :return: Text content as string, empty string if no text content
-        """
-        if isinstance(self.content, str):
-            return self.content
-        elif isinstance(self.content, list):
-            text_parts = []
-            for item in self.content:
-                if isinstance(item, TextPromptMessageContent):
-                    text_parts.append(item.data)
-            return "".join(text_parts)
-        else:
-            return ""
-
-    @field_validator("content", mode="before")
-    @classmethod
-    def validate_content(cls, v):
-        if isinstance(v, list):
-            prompts = []
-            for prompt in v:
-                if isinstance(prompt, PromptMessageContent):
-                    if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent):
-                        prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
-                elif isinstance(prompt, dict):
-                    prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt)
-                else:
-                    raise ValueError(f"invalid prompt message {prompt}")
-                prompts.append(prompt)
-            return prompts
-        return v
-
-    @field_serializer("content")
-    def serialize_content(
-        self, content: Union[str, Sequence[PromptMessageContent]] | None
-    ) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None:
-        if content is None or isinstance(content, str):
-            return content
-        if isinstance(content, list):
-            return [item.model_dump() if hasattr(item, "model_dump") else item for item in content]
-        return content
-
-
-class UserPromptMessage(PromptMessage):
-    """
-    Model class for user prompt message.
-    """
-
-    role: PromptMessageRole = PromptMessageRole.USER
-
-
-class AssistantPromptMessage(PromptMessage):
-    """
-    Model class for assistant prompt message.
-    """
-
-    class ToolCall(BaseModel):
-        """
-        Model class for assistant prompt message tool call.
-        """
-
-        class ToolCallFunction(BaseModel):
-            """
-            Model class for assistant prompt message tool call function.
-            """
-
-            name: str
-            arguments: str
-
-        id: str
-        type: str
-        function: ToolCallFunction
-
-        @field_validator("id", mode="before")
-        @classmethod
-        def transform_id_to_str(cls, value) -> str:
-            if not isinstance(value, str):
-                return str(value)
-            else:
-                return value
-
-    role: PromptMessageRole = PromptMessageRole.ASSISTANT
-    tool_calls: list[ToolCall] = []
-
-    def is_empty(self) -> bool:
-        """
-        Check if prompt message is empty.
-
-        :return: True if prompt message is empty, False otherwise
-        """
-        return super().is_empty() and not self.tool_calls
-
-
-class SystemPromptMessage(PromptMessage):
-    """
-    Model class for system prompt message.
-    """
-
-    role: PromptMessageRole = PromptMessageRole.SYSTEM
-
-
-class ToolPromptMessage(PromptMessage):
-    """
-    Model class for tool prompt message.
-    """
-
-    role: PromptMessageRole = PromptMessageRole.TOOL
-    tool_call_id: str
-
-    def is_empty(self) -> bool:
-        """
-        Check if prompt message is empty.
-
-        :return: True if prompt message is empty, False otherwise
-        """
-        return super().is_empty() and not self.tool_call_id
diff --git a/api/graphon/model_runtime/entities/model_entities.py b/api/graphon/model_runtime/entities/model_entities.py
deleted file mode 100644
index 5ec4970faf9..00000000000
--- a/api/graphon/model_runtime/entities/model_entities.py
+++ /dev/null
@@ -1,242 +0,0 @@
-from __future__ import annotations
-
-from decimal import Decimal
-from enum import StrEnum, auto
-from typing import Any
-
-from pydantic import BaseModel, ConfigDict, model_validator
-
-from graphon.model_runtime.entities.common_entities import I18nObject
-
-
-class ModelType(StrEnum):
-    """
-    Enum class for model type.
-    """
-
-    LLM = auto()
-    TEXT_EMBEDDING = "text-embedding"
-    RERANK = auto()
-    SPEECH2TEXT = auto()
-    MODERATION = auto()
-    TTS = auto()
-
-    @classmethod
-    def value_of(cls, origin_model_type: str) -> ModelType:
-        """
-        Get model type from origin model type.
-
-        :return: model type
-        """
-        if origin_model_type in {"text-generation", cls.LLM}:
-            return cls.LLM
-        elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}:
-            return cls.TEXT_EMBEDDING
-        elif origin_model_type in {"reranking", cls.RERANK}:
-            return cls.RERANK
-        elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}:
-            return cls.SPEECH2TEXT
-        elif origin_model_type in {"tts", cls.TTS}:
-            return cls.TTS
-        elif origin_model_type == cls.MODERATION:
-            return cls.MODERATION
-        else:
-            raise ValueError(f"invalid origin model type {origin_model_type}")
-
-    def to_origin_model_type(self) -> str:
-        """
-        Get origin model type from model type.
-
-        :return: origin model type
-        """
-        if self == self.LLM:
-            return "text-generation"
-        elif self == self.TEXT_EMBEDDING:
-            return "embeddings"
-        elif self == self.RERANK:
-            return "reranking"
-        elif self == self.SPEECH2TEXT:
-            return "speech2text"
-        elif self == self.TTS:
-            return "tts"
-        elif self == self.MODERATION:
-            return "moderation"
-        else:
-            raise ValueError(f"invalid model type {self}")
-
-
-class FetchFrom(StrEnum):
-    """
-    Enum class for fetch from.
-    """
-
-    PREDEFINED_MODEL = "predefined-model"
-    CUSTOMIZABLE_MODEL = "customizable-model"
-
-
-class ModelFeature(StrEnum):
-    """
-    Enum class for llm feature.
-    """
-
-    TOOL_CALL = "tool-call"
-    MULTI_TOOL_CALL = "multi-tool-call"
-    AGENT_THOUGHT = "agent-thought"
-    VISION = auto()
-    STREAM_TOOL_CALL = "stream-tool-call"
-    DOCUMENT = auto()
-    VIDEO = auto()
-    AUDIO = auto()
-    STRUCTURED_OUTPUT = "structured-output"
-
-
-class DefaultParameterName(StrEnum):
-    """
-    Enum class for parameter template variable.
-    """
-
-    TEMPERATURE = auto()
-    TOP_P = auto()
-    TOP_K = auto()
-    PRESENCE_PENALTY = auto()
-    FREQUENCY_PENALTY = auto()
-    MAX_TOKENS = auto()
-    RESPONSE_FORMAT = auto()
-    JSON_SCHEMA = auto()
-
-    @classmethod
-    def value_of(cls, value: Any) -> DefaultParameterName:
-        """
-        Get parameter name from value.
-
-        :param value: parameter value
-        :return: parameter name
-        """
-        for name in cls:
-            if name.value == value:
-                return name
-        raise ValueError(f"invalid parameter name {value}")
-
-
-class ParameterType(StrEnum):
-    """
-    Enum class for parameter type.
-    """
-
-    FLOAT = auto()
-    INT = auto()
-    STRING = auto()
-    BOOLEAN = auto()
-    TEXT = auto()
-
-
-class ModelPropertyKey(StrEnum):
-    """
-    Enum class for model property key.
-    """
-
-    MODE = auto()
-    CONTEXT_SIZE = auto()
-    MAX_CHUNKS = auto()
-    FILE_UPLOAD_LIMIT = auto()
-    SUPPORTED_FILE_EXTENSIONS = auto()
-    MAX_CHARACTERS_PER_CHUNK = auto()
-    DEFAULT_VOICE = auto()
-    VOICES = auto()
-    WORD_LIMIT = auto()
-    AUDIO_TYPE = auto()
-    MAX_WORKERS = auto()
-
-
-class ProviderModel(BaseModel):
-    """
-    Model class for provider model.
-    """
-
-    model: str
-    label: I18nObject
-    model_type: ModelType
-    features: list[ModelFeature] | None = None
-    fetch_from: FetchFrom
-    model_properties: dict[ModelPropertyKey, Any]
-    deprecated: bool = False
-    model_config = ConfigDict(protected_namespaces=())
-
-    @property
-    def support_structure_output(self) -> bool:
-        return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features
-
-
-class ParameterRule(BaseModel):
-    """
-    Model class for parameter rule.
-    """
-
-    name: str
-    use_template: str | None = None
-    label: I18nObject
-    type: ParameterType
-    help: I18nObject | None = None
-    required: bool = False
-    default: Any | None = None
-    min: float | None = None
-    max: float | None = None
-    precision: int | None = None
-    options: list[str] = []
-
-
-class PriceConfig(BaseModel):
-    """
-    Model class for pricing info.
-    """
-
-    input: Decimal
-    output: Decimal | None = None
-    unit: Decimal
-    currency: str
-
-
-class AIModelEntity(ProviderModel):
-    """
-    Model class for AI model.
-    """
-
-    parameter_rules: list[ParameterRule] = []
-    pricing: PriceConfig | None = None
-
-    @model_validator(mode="after")
-    def validate_model(self):
-        supported_schema_keys = ["json_schema"]
-        schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
-        if not schema_key:
-            return self
-        if self.features is None:
-            self.features = [ModelFeature.STRUCTURED_OUTPUT]
-        else:
-            if ModelFeature.STRUCTURED_OUTPUT not in self.features:
-                self.features.append(ModelFeature.STRUCTURED_OUTPUT)
-        return self
-
-
-class ModelUsage(BaseModel):
-    pass
-
-
-class PriceType(StrEnum):
-    """
-    Enum class for price type.
-    """
-
-    INPUT = auto()
-    OUTPUT = auto()
-
-
-class PriceInfo(BaseModel):
-    """
-    Model class for price info.
-    """
-
-    unit_price: Decimal
-    unit: Decimal
-    total_amount: Decimal
-    currency: str
diff --git a/api/graphon/model_runtime/entities/provider_entities.py b/api/graphon/model_runtime/entities/provider_entities.py
deleted file mode 100644
index 8e6c516fb9c..00000000000
--- a/api/graphon/model_runtime/entities/provider_entities.py
+++ /dev/null
@@ -1,179 +0,0 @@
-from collections.abc import Sequence
-from enum import StrEnum, auto
-
-from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
-
-from graphon.model_runtime.entities.common_entities import I18nObject
-from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
-
-
-class ConfigurateMethod(StrEnum):
-    """
-    Enum class for configurate method of provider model.
-    """
-
-    PREDEFINED_MODEL = "predefined-model"
-    CUSTOMIZABLE_MODEL = "customizable-model"
-
-
-class FormType(StrEnum):
-    """
-    Enum class for form type.
-    """
-
-    TEXT_INPUT = "text-input"
-    SECRET_INPUT = "secret-input"
-    SELECT = auto()
-    RADIO = auto()
-    SWITCH = auto()
-
-
-class FormShowOnObject(BaseModel):
-    """
-    Model class for form show on.
-    """
-
-    variable: str
-    value: str
-
-
-class FormOption(BaseModel):
-    """
-    Model class for form option.
-    """
-
-    label: I18nObject
-    value: str
-    show_on: list[FormShowOnObject] = []
-
-    @model_validator(mode="after")
-    def _(self):
-        if not self.label:
-            self.label = I18nObject(en_US=self.value)
-        return self
-
-
-class CredentialFormSchema(BaseModel):
-    """
-    Model class for credential form schema.
-    """
-
-    variable: str
-    label: I18nObject
-    type: FormType
-    required: bool = True
-    default: str | None = None
-    options: list[FormOption] | None = None
-    placeholder: I18nObject | None = None
-    max_length: int = 0
-    show_on: list[FormShowOnObject] = []
-
-
-class ProviderCredentialSchema(BaseModel):
-    """
-    Model class for provider credential schema.
-    """
-
-    credential_form_schemas: list[CredentialFormSchema]
-
-
-class FieldModelSchema(BaseModel):
-    label: I18nObject
-    placeholder: I18nObject | None = None
-
-
-class ModelCredentialSchema(BaseModel):
-    """
-    Model class for model credential schema.
-    """
-
-    model: FieldModelSchema
-    credential_form_schemas: list[CredentialFormSchema]
-
-
-class SimpleProviderEntity(BaseModel):
-    """
-    Simplified provider schema exposed to callers.
-
-    `provider` is the canonical runtime identifier. `provider_name` is an optional
-    compatibility alias for short-name lookups and is empty when no alias exists.
-    """
-
-    provider: str
-    provider_name: str = ""
-    label: I18nObject
-    icon_small: I18nObject | None = None
-    icon_small_dark: I18nObject | None = None
-    supported_model_types: Sequence[ModelType]
-    models: list[AIModelEntity] = []
-
-
-class ProviderHelpEntity(BaseModel):
-    """
-    Model class for provider help.
-    """
-
-    title: I18nObject
-    url: I18nObject
-
-
-class ProviderEntity(BaseModel):
-    """
-    Runtime-native provider schema.
-
-    `provider` is the canonical runtime identifier. `provider_name` is a
-    compatibility alias for callers that still resolve providers by short name and
-    is empty when no alias exists.
-    """
-
-    provider: str
-    provider_name: str = ""
-    label: I18nObject
-    description: I18nObject | None = None
-    icon_small: I18nObject | None = None
-    icon_small_dark: I18nObject | None = None
-    background: str | None = None
-    help: ProviderHelpEntity | None = None
-    supported_model_types: Sequence[ModelType]
-    configurate_methods: list[ConfigurateMethod]
-    models: list[AIModelEntity] = Field(default_factory=list)
-    provider_credential_schema: ProviderCredentialSchema | None = None
-    model_credential_schema: ModelCredentialSchema | None = None
-
-    # pydantic configs
-    model_config = ConfigDict(protected_namespaces=())
-
-    # position from plugin _position.yaml
-    position: dict[str, list[str]] | None = {}
-
-    @field_validator("models", mode="before")
-    @classmethod
-    def validate_models(cls, v):
-        # returns EmptyList if v is empty
-        if not v:
-            return []
-        return v
-
-    def to_simple_provider(self) -> SimpleProviderEntity:
-        """
-        Convert to simple provider.
-
-        :return: simple provider
-        """
-        return SimpleProviderEntity(
-            provider=self.provider,
-            provider_name=self.provider_name,
-            label=self.label,
-            icon_small=self.icon_small,
-            supported_model_types=self.supported_model_types,
-            models=self.models,
-        )
-
-
-class ProviderConfig(BaseModel):
-    """
-    Model class for provider config.
-    """
-
-    provider: str
-    credentials: dict
diff --git a/api/graphon/model_runtime/entities/rerank_entities.py b/api/graphon/model_runtime/entities/rerank_entities.py
deleted file mode 100644
index 8a0bb5fac2b..00000000000
--- a/api/graphon/model_runtime/entities/rerank_entities.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from typing import TypedDict
-
-from pydantic import BaseModel
-
-
-class MultimodalRerankInput(TypedDict):
-    content: str
-    content_type: str
-
-
-class RerankDocument(BaseModel):
-    """
-    Model class for rerank document.
-    """
-
-    index: int
-    text: str
-    score: float
-
-
-class RerankResult(BaseModel):
-    """
-    Model class for rerank result.
-    """
-
-    model: str
-    docs: list[RerankDocument]
diff --git a/api/graphon/model_runtime/entities/text_embedding_entities.py b/api/graphon/model_runtime/entities/text_embedding_entities.py
deleted file mode 100644
index 08ffd83b5be..00000000000
--- a/api/graphon/model_runtime/entities/text_embedding_entities.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from decimal import Decimal
-from enum import StrEnum, auto
-
-from pydantic import BaseModel
-
-from graphon.model_runtime.entities.model_entities import ModelUsage
-
-
-class EmbeddingInputType(StrEnum):
-    """Embedding request input variants understood by the model runtime."""
-
-    DOCUMENT = auto()
-    QUERY = auto()
-
-
-class EmbeddingUsage(ModelUsage):
-    """
-    Model class for embedding usage.
-    """
-
-    tokens: int
-    total_tokens: int
-    unit_price: Decimal
-    price_unit: Decimal
-    total_price: Decimal
-    currency: str
-    latency: float
-
-
-class EmbeddingResult(BaseModel):
-    """
-    Model class for text embedding result.
-    """
-
-    model: str
-    embeddings: list[list[float]]
-    usage: EmbeddingUsage
-
-
-class FileEmbeddingResult(BaseModel):
-    """
-    Model class for file embedding result.
-    """
-
-    model: str
-    embeddings: list[list[float]]
-    usage: EmbeddingUsage
diff --git a/api/graphon/model_runtime/errors/__init__.py b/api/graphon/model_runtime/errors/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/api/graphon/model_runtime/errors/invoke.py b/api/graphon/model_runtime/errors/invoke.py
deleted file mode 100644
index 1a57078b988..00000000000
--- a/api/graphon/model_runtime/errors/invoke.py
+++ /dev/null
@@ -1,41 +0,0 @@
-class InvokeError(ValueError):
-    """Base class for all LLM exceptions."""
-
-    description: str | None = None
-
-    def __init__(self, description: str | None = None):
-        if description is not None:
-            self.description = description
-
-    def __str__(self):
-        return self.description or self.__class__.__name__
-
-
-class InvokeConnectionError(InvokeError):
-    """Raised when the Invoke returns connection error."""
-
-    description = "Connection Error"
-
-
-class InvokeServerUnavailableError(InvokeError):
-    """Raised when the Invoke returns server unavailable error."""
-
-    description = "Server Unavailable Error"
-
-
-class InvokeRateLimitError(InvokeError):
-    """Raised when the Invoke returns rate limit error."""
-
-    description = "Rate Limit Error"
-
-
-class InvokeAuthorizationError(InvokeError):
-    """Raised when the Invoke returns authorization error."""
-
-    description = "Incorrect model credentials provided, please check and try again. "
-
-
-class InvokeBadRequestError(InvokeError):
-    """Raised when the Invoke returns bad request."""
-
-    description = "Bad Request Error"
diff --git a/api/graphon/model_runtime/errors/validate.py b/api/graphon/model_runtime/errors/validate.py
deleted file mode 100644
index 16bebcc67db..00000000000
--- a/api/graphon/model_runtime/errors/validate.py
+++ /dev/null
@@ -1,6 +0,0 @@
-class CredentialsValidateFailedError(ValueError):
-    """
-    Credentials validate failed error
-    """
-
-    pass
diff --git a/api/graphon/model_runtime/memory/__init__.py b/api/graphon/model_runtime/memory/__init__.py
deleted file mode 100644
index 2d954486c30..00000000000
--- a/api/graphon/model_runtime/memory/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory
-
-__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]
diff --git a/api/graphon/model_runtime/memory/prompt_message_memory.py b/api/graphon/model_runtime/memory/prompt_message_memory.py
deleted file mode 100644
index 03e26e9ff59..00000000000
--- a/api/graphon/model_runtime/memory/prompt_message_memory.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from __future__ import annotations
-
-from collections.abc import Sequence
-from typing import Protocol
-
-from graphon.model_runtime.entities import PromptMessage
-
-DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000
-
-
-class PromptMessageMemory(Protocol):
-    """Port for loading memory as prompt messages."""
-
-    def get_history_prompt_messages(
-        self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
-    ) -> Sequence[PromptMessage]:
-        """Return historical prompt messages constrained by token/message limits."""
-        ...
diff --git a/api/graphon/model_runtime/model_providers/__base/__init__.py b/api/graphon/model_runtime/model_providers/__base/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/graphon/model_runtime/model_providers/__base/ai_model.py b/api/graphon/model_runtime/model_providers/__base/ai_model.py deleted file mode 100644 index 1700ec97402..00000000000 --- a/api/graphon/model_runtime/model_providers/__base/ai_model.py +++ /dev/null @@ -1,247 +0,0 @@ -import decimal - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from graphon.model_runtime.entities.model_entities import ( - AIModelEntity, - DefaultParameterName, - ModelType, - PriceConfig, - PriceInfo, - PriceType, -) -from graphon.model_runtime.entities.provider_entities import ProviderEntity -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.runtime import ModelRuntime - - -class AIModel: - """ - Runtime-facing base class for all model providers. - - This stays a regular Python class because instances hold live collaborators - such as the provider schema and runtime adapter rather than user input that - benefits from Pydantic validation. Subclasses must pin ``model_type`` via a - class attribute; the base class is not meant to be instantiated directly. - """ - - model_type: ModelType - provider_schema: ProviderEntity - model_runtime: ModelRuntime - started_at: float - - def __init__( - self, - provider_schema: ProviderEntity, - model_runtime: ModelRuntime, - *, - started_at: float = 0, - ) -> None: - if getattr(type(self), "model_type", None) is None: - raise TypeError("AIModel subclasses must define model_type as a class attribute") - - self.model_type = type(self).model_type - self.provider_schema = provider_schema - self.model_runtime = model_runtime - self.started_at = started_at - - @property - def provider(self) -> str: - return self.provider_schema.provider - - @property - def provider_display_name(self) -> str: - return self.provider_schema.label.en_US - - @property - def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]: - """ - Map model invoke error to unified error. - - The key is the error type thrown to the caller, and the value contains - runtime-facing exception types that should be normalized to it. - """ - return { - InvokeConnectionError: [InvokeConnectionError], - InvokeServerUnavailableError: [InvokeServerUnavailableError], - InvokeRateLimitError: [InvokeRateLimitError], - InvokeAuthorizationError: [InvokeAuthorizationError], - InvokeBadRequestError: [InvokeBadRequestError], - ValueError: [ValueError], - } - - def _transform_invoke_error(self, error: Exception) -> Exception: - """ - Transform invoke error to unified error - - :param error: model invoke error - :return: unified error - """ - for invoke_error, model_errors in self._invoke_error_mapping.items(): - if isinstance(error, tuple(model_errors)): - if invoke_error == InvokeAuthorizationError: - return InvokeAuthorizationError( - description=( - f"[{self.provider_display_name}] Incorrect model credentials provided, " - "please check and try again." 
- ) - ) - elif issubclass(invoke_error, InvokeError): - return InvokeError( - description=f"[{self.provider_display_name}] {invoke_error.description}, {str(error)}" - ) - else: - return error - - return InvokeError(description=f"[{self.provider_display_name}] Error: {str(error)}") - - def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo: - """ - Get price for given model and tokens - - :param model: model name - :param credentials: model credentials - :param price_type: price type - :param tokens: number of tokens - :return: price info - """ - # get model schema - model_schema = self.get_model_schema(model, credentials) - - # get price info from predefined model schema - price_config: PriceConfig | None = None - if model_schema and model_schema.pricing: - price_config = model_schema.pricing - - # get unit price - unit_price = None - if price_config: - if price_type == PriceType.INPUT: - unit_price = price_config.input - elif price_type == PriceType.OUTPUT and price_config.output is not None: - unit_price = price_config.output - - if unit_price is None: - return PriceInfo( - unit_price=decimal.Decimal("0.0"), - unit=decimal.Decimal("0.0"), - total_amount=decimal.Decimal("0.0"), - currency="USD", - ) - - # calculate total amount - if not price_config: - raise ValueError(f"Price config not found for model {model}") - total_amount = tokens * unit_price * price_config.unit - total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP) - - return PriceInfo( - unit_price=unit_price, - unit=price_config.unit, - total_amount=total_amount, - currency=price_config.currency, - ) - - def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None: - """ - Get model schema by model name and credentials - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - return self.model_runtime.get_model_schema( - provider=self.provider, - model_type=self.model_type, - model=model, - credentials=credentials or {}, - ) - - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: - """ - Get customizable model schema from credentials - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - - # get customizable model schema - schema = self.get_customizable_model_schema(model, credentials) - if not schema: - return None - - # fill in the template - new_parameter_rules = [] - for parameter_rule in schema.parameter_rules: - if parameter_rule.use_template: - try: - default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) - default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) - if not parameter_rule.max and "max" in default_parameter_rule: - parameter_rule.max = default_parameter_rule["max"] - if not parameter_rule.min and "min" in default_parameter_rule: - parameter_rule.min = default_parameter_rule["min"] - if not parameter_rule.default and "default" in default_parameter_rule: - parameter_rule.default = default_parameter_rule["default"] - if not parameter_rule.precision and "precision" in default_parameter_rule: - parameter_rule.precision = default_parameter_rule["precision"] - if not parameter_rule.required and "required" in default_parameter_rule: - parameter_rule.required = default_parameter_rule["required"] - if not parameter_rule.help and "help" in default_parameter_rule: -
parameter_rule.help = I18nObject( - en_US=default_parameter_rule["help"]["en_US"], - ) - if ( - parameter_rule.help - and not parameter_rule.help.en_US - and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"]) - ): - parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"] - if ( - parameter_rule.help - and not parameter_rule.help.zh_Hans - and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"]) - ): - parameter_rule.help.zh_Hans = default_parameter_rule["help"].get( - "zh_Hans", default_parameter_rule["help"]["en_US"] - ) - except ValueError: - pass - - new_parameter_rules.append(parameter_rule) - - schema.parameter_rules = new_parameter_rules - - return schema - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: - """ - Get customizable model schema - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - return None - - def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName): - """ - Get default parameter rule for given name - - :param name: parameter name - :return: parameter rule - """ - default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name) - - if not default_parameter_rule: - raise Exception(f"Invalid model parameter rule name {name}") - - return default_parameter_rule diff --git a/api/graphon/model_runtime/model_providers/__base/large_language_model.py b/api/graphon/model_runtime/model_providers/__base/large_language_model.py deleted file mode 100644 index 0f909646a12..00000000000 --- a/api/graphon/model_runtime/model_providers/__base/large_language_model.py +++ /dev/null @@ -1,638 +0,0 @@ -import logging -import time -import uuid -from collections.abc import Callable, Generator, Iterator, Mapping, Sequence -from typing import Union - -from graphon.model_runtime.callbacks.base_callback import Callback -from graphon.model_runtime.callbacks.logging_callback import LoggingCallback -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageContentUnionTypes, - PromptMessageTool, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ( - ModelType, - PriceType, -) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -def _gen_tool_call_id() -> str: - return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" - - -def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None: - if not callbacks: - return - - for callback in callbacks: - try: - invoke(callback) - except Exception as e: - if callback.raise_error: - raise - logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e) - - -def _get_or_create_tool_call( - existing_tools_calls: list[AssistantPromptMessage.ToolCall], - tool_call_id: str, -) -> AssistantPromptMessage.ToolCall: - """ - Get or create a tool call by ID. - - If `tool_call_id` is empty, returns the most recently created tool call. 
- """ - if not tool_call_id: - if not existing_tools_calls: - raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta") - return existing_tools_calls[-1] - - tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None) - if tool_call is None: - tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), - ) - existing_tools_calls.append(tool_call) - - return tool_call - - -def _merge_tool_call_delta( - tool_call: AssistantPromptMessage.ToolCall, - delta: AssistantPromptMessage.ToolCall, -) -> None: - if delta.id: - tool_call.id = delta.id - if delta.type: - tool_call.type = delta.type - if delta.function.name: - tool_call.function.name = delta.function.name - if delta.function.arguments: - tool_call.function.arguments += delta.function.arguments - - -def _build_llm_result_from_chunks( - model: str, - prompt_messages: Sequence[PromptMessage], - chunks: Iterator[LLMResultChunk], -) -> LLMResult: - """ - Build a single `LLMResult` by accumulating all returned chunks. - - Some models only support streaming output (e.g. Qwen3 open-source edition) - and the plugin side may still implement the response via a chunked stream, - so all chunks must be consumed and concatenated into a single ``LLMResult``. - - The ``usage`` is taken from the last chunk that carries it, which is the - typical convention for streaming responses (the final chunk contains the - aggregated token counts). - """ - content = "" - content_list: list[PromptMessageContentUnionTypes] = [] - usage = LLMUsage.empty_usage() - system_fingerprint: str | None = None - tools_calls: list[AssistantPromptMessage.ToolCall] = [] - - try: - for chunk in chunks: - if isinstance(chunk.delta.message.content, str): - content += chunk.delta.message.content - elif isinstance(chunk.delta.message.content, list): - content_list.extend(chunk.delta.message.content) - - if chunk.delta.message.tool_calls: - _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) - - if chunk.delta.usage: - usage = chunk.delta.usage - if chunk.system_fingerprint: - system_fingerprint = chunk.system_fingerprint - except Exception: - logger.exception("Error while consuming non-stream plugin chunk iterator.") - raise - finally: - # Drain any remaining chunks to release underlying streaming resources (e.g. HTTP connections). 
- close = getattr(chunks, "close", None) - if callable(close): - close() - - return LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=content or content_list, - tool_calls=tools_calls, - ), - usage=usage, - system_fingerprint=system_fingerprint, - ) - - -def _invoke_llm_via_runtime( - *, - llm_model: "LargeLanguageModel", - provider: str, - model: str, - credentials: dict, - model_parameters: dict, - prompt_messages: Sequence[PromptMessage], - tools: list[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: bool, -) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: - return llm_model.model_runtime.invoke_llm( - provider=provider, - model=model, - credentials=credentials, - model_parameters=model_parameters, - prompt_messages=list(prompt_messages), - tools=tools, - stop=stop, - stream=stream, - ) - - -def _normalize_non_stream_runtime_result( - model: str, - prompt_messages: Sequence[PromptMessage], - result: Union[LLMResult, Iterator[LLMResultChunk]], -) -> LLMResult: - if isinstance(result, LLMResult): - return result - return _build_llm_result_from_chunks(model=model, prompt_messages=prompt_messages, chunks=result) - - -def _increase_tool_call( - new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall] -): - """ - Merge incremental tool call updates into existing tool calls. - - :param new_tool_calls: List of new tool call deltas to be merged. - :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE. - """ - - for new_tool_call in new_tool_calls: - # generate ID for tool calls with function name but no ID to track them - if new_tool_call.function.name and not new_tool_call.id: - new_tool_call.id = _gen_tool_call_id() - - tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id) - _merge_tool_call_delta(tool_call, new_tool_call) - - -class LargeLanguageModel(AIModel): - """ - Model class for large language model. 
- """ - - model_type: ModelType = ModelType.LLM - - def invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict | None = None, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, - stream: bool = True, - callbacks: list[Callback] | None = None, - ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: - """ - Invoke large language model - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param callbacks: callbacks - :return: full response or stream response chunk generator result - """ - # validate and filter model parameters - if model_parameters is None: - model_parameters = {} - - self.started_at = time.perf_counter() - - callbacks = callbacks or [] - - if logger.isEnabledFor(logging.DEBUG): - callbacks.append(LoggingCallback()) - - # trigger before invoke callbacks - self._trigger_before_invoke_callbacks( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - callbacks=callbacks, - ) - - result: Union[LLMResult, Generator[LLMResultChunk, None, None]] - - try: - result = _invoke_llm_via_runtime( - llm_model=self, - provider=self.provider, - model=model, - credentials=credentials, - model_parameters=model_parameters, - prompt_messages=prompt_messages, - tools=tools, - stop=stop, - stream=stream, - ) - - if not stream: - result = _normalize_non_stream_runtime_result( - model=model, prompt_messages=prompt_messages, result=result - ) - except Exception as e: - self._trigger_invoke_error_callbacks( - model=model, - ex=e, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - callbacks=callbacks, - ) - - # TODO - raise self._transform_invoke_error(e) - - if stream and not isinstance(result, LLMResult): - return self._invoke_result_generator( - model=model, - result=result, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - callbacks=callbacks, - ) - elif isinstance(result, LLMResult): - self._trigger_after_invoke_callbacks( - model=model, - result=result, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - callbacks=callbacks, - ) - # Following https://github.com/langgenius/dify/issues/17799, - # we removed the prompt_messages from the chunk on the plugin daemon side. - # To ensure compatibility, we add the prompt_messages back here. 
- result.prompt_messages = prompt_messages - return result - raise NotImplementedError("unsupported invoke result type", type(result)) - - def _invoke_result_generator( - self, - model: str, - result: Generator[LLMResultChunk, None, None], - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = None, - ) -> Generator[LLMResultChunk, None, None]: - """ - Invoke result generator - - :param result: result generator - :return: result generator - """ - callbacks = callbacks or [] - message_content: list[PromptMessageContentUnionTypes] = [] - usage = None - system_fingerprint = None - real_model = model - - def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None): - if not content: - return - if isinstance(content, list): - message_content.extend(content) - return - if isinstance(content, str): - message_content.append(TextPromptMessageContent(data=content)) - return - - try: - for chunk in result: - # Following https://github.com/langgenius/dify/issues/17799, - # we removed the prompt_messages from the chunk on the plugin daemon side. - # To ensure compatibility, we add the prompt_messages back here. - chunk.prompt_messages = prompt_messages - yield chunk - - self._trigger_new_chunk_callbacks( - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - callbacks=callbacks, - ) - - _update_message_content(chunk.delta.message.content) - - real_model = chunk.model - if chunk.delta.usage: - usage = chunk.delta.usage - - if chunk.system_fingerprint: - system_fingerprint = chunk.system_fingerprint - except Exception as e: - raise self._transform_invoke_error(e) - - assistant_message = AssistantPromptMessage(content=message_content) - self._trigger_after_invoke_callbacks( - model=model, - result=LLMResult( - model=real_model, - prompt_messages=prompt_messages, - message=assistant_message, - usage=usage or LLMUsage.empty_usage(), - system_fingerprint=system_fingerprint, - ), - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - callbacks=callbacks, - ) - - def get_num_tokens( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None, - ) -> int: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param tools: tools for tool calling - :return: - """ - return self.model_runtime.get_llm_num_tokens( - provider=self.provider, - model_type=self.model_type, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, - ) - - def calc_response_usage( - self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int - ) -> LLMUsage: - """ - Calculate response usage - - :param model: model name - :param credentials: model credentials - :param prompt_tokens: prompt tokens - :param completion_tokens: completion tokens - :return: usage - """ - # get prompt price info - prompt_price_info = self.get_price( - model=model, - 
credentials=credentials, - price_type=PriceType.INPUT, - tokens=prompt_tokens, - ) - - # get completion price info - completion_price_info = self.get_price( - model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens - ) - - # transform usage - usage = LLMUsage( - prompt_tokens=prompt_tokens, - prompt_unit_price=prompt_price_info.unit_price, - prompt_price_unit=prompt_price_info.unit, - prompt_price=prompt_price_info.total_amount, - completion_tokens=completion_tokens, - completion_unit_price=completion_price_info.unit_price, - completion_price_unit=completion_price_info.unit, - completion_price=completion_price_info.total_amount, - total_tokens=prompt_tokens + completion_tokens, - total_price=prompt_price_info.total_amount + completion_price_info.total_amount, - currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at, - ) - - return usage - - def _trigger_before_invoke_callbacks( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger before invoke callbacks - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_before_invoke", - invoke=lambda callback: callback.on_before_invoke( - llm_instance=self, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - ), - ) - - def _trigger_new_chunk_callbacks( - self, - chunk: LLMResultChunk, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger new chunk callbacks - - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - """ - _run_callbacks( - callbacks, - event="on_new_chunk", - invoke=lambda callback: callback.on_new_chunk( - llm_instance=self, - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - ), - ) - - def _trigger_after_invoke_callbacks( - self, - model: str, - result: LLMResult, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = 
None, - ): - """ - Trigger after invoke callbacks - - :param model: model name - :param result: result - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_after_invoke", - invoke=lambda callback: callback.on_after_invoke( - llm_instance=self, - result=result, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - ), - ) - - def _trigger_invoke_error_callbacks( - self, - model: str, - ex: Exception, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger invoke error callbacks - - :param model: model name - :param ex: exception - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_invoke_error", - invoke=lambda callback: callback.on_invoke_error( - llm_instance=self, - ex=ex, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - ), - ) diff --git a/api/graphon/model_runtime/model_providers/__base/moderation_model.py b/api/graphon/model_runtime/model_providers/__base/moderation_model.py deleted file mode 100644 index 01f68429983..00000000000 --- a/api/graphon/model_runtime/model_providers/__base/moderation_model.py +++ /dev/null @@ -1,33 +0,0 @@ -import time - -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class ModerationModel(AIModel): - """ - Model class for moderation model. 
- """ - - model_type: ModelType = ModelType.MODERATION - - def invoke(self, model: str, credentials: dict, text: str) -> bool: - """ - Invoke moderation model - - :param model: model name - :param credentials: model credentials - :param text: text to moderate - :return: false if text is safe, true otherwise - """ - self.started_at = time.perf_counter() - - try: - return self.model_runtime.invoke_moderation( - provider=self.provider, - model=model, - credentials=credentials, - text=text, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/graphon/model_runtime/model_providers/__base/rerank_model.py b/api/graphon/model_runtime/model_providers/__base/rerank_model.py deleted file mode 100644 index 94b2b5a4fba..00000000000 --- a/api/graphon/model_runtime/model_providers/__base/rerank_model.py +++ /dev/null @@ -1,76 +0,0 @@ -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class RerankModel(AIModel): - """ - Base Model class for rerank model. - """ - - model_type: ModelType = ModelType.RERANK - - def invoke( - self, - model: str, - credentials: dict, - query: str, - docs: list[str], - score_threshold: float | None = None, - top_n: int | None = None, - ) -> RerankResult: - """ - Invoke rerank model - - :param model: model name - :param credentials: model credentials - :param query: search query - :param docs: docs for reranking - :param score_threshold: score threshold - :param top_n: top n - :return: rerank result - """ - try: - return self.model_runtime.invoke_rerank( - provider=self.provider, - model=model, - credentials=credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ) - except Exception as e: - raise self._transform_invoke_error(e) - - def invoke_multimodal_rerank( - self, - model: str, - credentials: dict, - query: MultimodalRerankInput, - docs: list[MultimodalRerankInput], - score_threshold: float | None = None, - top_n: int | None = None, - ) -> RerankResult: - """ - Invoke multimodal rerank model - :param model: model name - :param credentials: model credentials - :param query: search query - :param docs: docs for reranking - :param score_threshold: score threshold - :param top_n: top n - :return: rerank result - """ - try: - return self.model_runtime.invoke_multimodal_rerank( - provider=self.provider, - model=model, - credentials=credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/graphon/model_runtime/model_providers/__base/speech2text_model.py b/api/graphon/model_runtime/model_providers/__base/speech2text_model.py deleted file mode 100644 index 4f5d648639d..00000000000 --- a/api/graphon/model_runtime/model_providers/__base/speech2text_model.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import IO - -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class Speech2TextModel(AIModel): - """ - Model class for speech2text model. 
- """ - - model_type: ModelType = ModelType.SPEECH2TEXT - - def invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str: - """ - Invoke speech to text model - - :param model: model name - :param credentials: model credentials - :param file: audio file - :return: text for given audio file - """ - try: - return self.model_runtime.invoke_speech_to_text( - provider=self.provider, - model=model, - credentials=credentials, - file=file, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/graphon/model_runtime/model_providers/__base/text_embedding_model.py b/api/graphon/model_runtime/model_providers/__base/text_embedding_model.py deleted file mode 100644 index c8b4a0a6afd..00000000000 --- a/api/graphon/model_runtime/model_providers/__base/text_embedding_model.py +++ /dev/null @@ -1,98 +0,0 @@ -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class TextEmbeddingModel(AIModel): - """ - Model class for text embedding model. - """ - - model_type: ModelType = ModelType.TEXT_EMBEDDING - - def invoke( - self, - model: str, - credentials: dict, - texts: list[str] | None = None, - multimodel_documents: list[dict] | None = None, - input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, - ) -> EmbeddingResult: - """ - Invoke text embedding model - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :param files: files to embed - :param input_type: input type - :return: embeddings result - """ - try: - if texts: - return self.model_runtime.invoke_text_embedding( - provider=self.provider, - model=model, - credentials=credentials, - texts=texts, - input_type=input_type, - ) - if multimodel_documents: - return self.model_runtime.invoke_multimodal_embedding( - provider=self.provider, - model=model, - credentials=credentials, - documents=multimodel_documents, - input_type=input_type, - ) - raise ValueError("No texts or files provided") - except Exception as e: - raise self._transform_invoke_error(e) - - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :return: - """ - return self.model_runtime.get_text_embedding_num_tokens( - provider=self.provider, - model=model, - credentials=credentials, - texts=texts, - ) - - def _get_context_size(self, model: str, credentials: dict) -> int: - """ - Get context size for given embedding model - - :param model: model name - :param credentials: model credentials - :return: context size - """ - model_schema = self.get_model_schema(model, credentials) - - if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties: - content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] - return content_size - - return 1000 - - def _get_max_chunks(self, model: str, credentials: dict) -> int: - """ - Get max chunks for given embedding model - - :param model: model name - :param credentials: model credentials - :return: max chunks - """ - model_schema = self.get_model_schema(model, credentials) - - if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties: - max_chunks: int = 
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] - return max_chunks - - return 1 diff --git a/api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py deleted file mode 100644 index 3967acf07ba..00000000000 --- a/api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py +++ /dev/null @@ -1,53 +0,0 @@ -import logging -from threading import Lock -from typing import Any - -logger = logging.getLogger(__name__) - -_tokenizer: Any | None = None -_lock = Lock() - - -class GPT2Tokenizer: - @staticmethod - def _get_num_tokens_by_gpt2(text: str) -> int: - """ - use gpt2 tokenizer to get num tokens - """ - _tokenizer = GPT2Tokenizer.get_encoder() - tokens = _tokenizer.encode(text) # type: ignore - return len(tokens) - - @staticmethod - def get_num_tokens(text: str) -> int: - # Because this process needs more cpu resource, we turn this back before we find a better way to handle it. - # - # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text) - # result = future.result() - # return cast(int, result) - return GPT2Tokenizer._get_num_tokens_by_gpt2(text) - - @staticmethod - def get_encoder(): - global _tokenizer, _lock - if _tokenizer is not None: - return _tokenizer - with _lock: - if _tokenizer is None: - # Try to use tiktoken to get the tokenizer because it is faster - # - try: - import tiktoken - - _tokenizer = tiktoken.get_encoding("gpt2") - except Exception: - from os.path import abspath, dirname, join - - from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer - - base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), "gpt2") - _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) - logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken") - - return _tokenizer diff --git a/api/graphon/model_runtime/model_providers/__base/tts_model.py b/api/graphon/model_runtime/model_providers/__base/tts_model.py deleted file mode 100644 index 6846f3c4038..00000000000 --- a/api/graphon/model_runtime/model_providers/__base/tts_model.py +++ /dev/null @@ -1,58 +0,0 @@ -import logging -from collections.abc import Iterable - -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -class TTSModel(AIModel): - """ - Model class for TTS model. - """ - - model_type: ModelType = ModelType.TTS - - def invoke( - self, - model: str, - credentials: dict, - content_text: str, - voice: str, - ) -> Iterable[bytes]: - """ - Invoke large language model - - :param model: model name - :param credentials: model credentials - :param voice: model timbre - :param content_text: text content to be translated - :return: translated audio file - """ - try: - return self.model_runtime.invoke_tts( - provider=self.provider, - model=model, - credentials=credentials, - content_text=content_text, - voice=voice, - ) - except Exception as e: - raise self._transform_invoke_error(e) - - def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None): - """ - Retrieves the list of voices supported by a given text-to-speech (TTS) model. - - :param language: The language for which the voices are requested. - :param model: The name of the TTS model. - :param credentials: The credentials required to access the TTS model. - :return: A list of voices supported by the TTS model. 
- """ - return self.model_runtime.get_tts_model_voices( - provider=self.provider, - model=model, - credentials=credentials, - language=language, - ) diff --git a/api/graphon/model_runtime/model_providers/__init__.py b/api/graphon/model_runtime/model_providers/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/graphon/model_runtime/model_providers/_position.yaml b/api/graphon/model_runtime/model_providers/_position.yaml deleted file mode 100644 index fb02de3a67c..00000000000 --- a/api/graphon/model_runtime/model_providers/_position.yaml +++ /dev/null @@ -1,43 +0,0 @@ -- openai -- deepseek -- anthropic -- azure_openai -- google -- vertex_ai -- nvidia -- nvidia_nim -- cohere -- upstage -- bedrock -- togetherai -- openrouter -- ollama -- mistralai -- groq -- replicate -- huggingface_hub -- xinference -- triton_inference_server -- zhipuai -- baichuan -- spark -- minimax -- tongyi -- wenxin -- moonshot -- tencent -- jina -- chatglm -- yi -- openllm -- localai -- volcengine_maas -- openai_api_compatible -- hunyuan -- siliconflow -- perfxcloud -- zhinao -- fireworks -- mixedbread -- nomic -- voyage diff --git a/api/graphon/model_runtime/model_providers/model_provider_factory.py b/api/graphon/model_runtime/model_providers/model_provider_factory.py deleted file mode 100644 index 1ea30c71209..00000000000 --- a/api/graphon/model_runtime/model_providers/model_provider_factory.py +++ /dev/null @@ -1,173 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence - -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity -from graphon.model_runtime.model_providers.__base.ai_model import AIModel -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel -from graphon.model_runtime.runtime import ModelRuntime -from graphon.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator -from graphon.model_runtime.schema_validators.provider_credential_schema_validator import ( - ProviderCredentialSchemaValidator, -) - - -class ModelProviderFactory: - """Factory for provider schemas and model-type instances backed by a runtime adapter.""" - - def __init__(self, model_runtime: ModelRuntime): - if model_runtime is None: - raise ValueError("model_runtime is required.") - self.model_runtime = model_runtime - - def get_providers(self) -> Sequence[ProviderEntity]: - """ - Get all providers. - """ - return list(self.get_model_providers()) - - def get_model_providers(self) -> Sequence[ProviderEntity]: - """ - Get all model providers exposed by the runtime adapter. - """ - return self.model_runtime.fetch_model_providers() - - def get_provider_schema(self, provider: str) -> ProviderEntity: - """ - Get provider schema. - """ - return self.get_model_provider(provider=provider) - - def get_model_provider(self, provider: str) -> ProviderEntity: - """ - Get provider schema. 
- """ - provider_entity = self._resolve_provider(provider) - if provider_entity is None: - raise ValueError(f"Invalid provider: {provider}") - - return provider_entity - - def provider_credentials_validate(self, *, provider: str, credentials: dict): - """ - Validate provider credentials. - """ - provider_entity = self.get_model_provider(provider=provider) - - provider_credential_schema = provider_entity.provider_credential_schema - if not provider_credential_schema: - raise ValueError(f"Provider {provider} does not have provider_credential_schema") - - validator = ProviderCredentialSchemaValidator(provider_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - self.model_runtime.validate_provider_credentials( - provider=provider_entity.provider, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): - """ - Validate model credentials. - """ - provider_entity = self.get_model_provider(provider=provider) - - model_credential_schema = provider_entity.model_credential_schema - if not model_credential_schema: - raise ValueError(f"Provider {provider} does not have model_credential_schema") - - validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - self.model_runtime.validate_model_credentials( - provider=provider_entity.provider, - model_type=model_type, - model=model, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def get_model_schema( - self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None - ) -> AIModelEntity | None: - """ - Get model schema. - """ - provider_entity = self.get_model_provider(provider) - return self.model_runtime.get_model_schema( - provider=provider_entity.provider, - model_type=model_type, - model=model, - credentials=credentials or {}, - ) - - def get_models( - self, - *, - provider: str | None = None, - model_type: ModelType | None = None, - provider_configs: list[ProviderConfig] | None = None, - ) -> list[SimpleProviderEntity]: - """ - Get all models for given model type. - """ - providers = [] - for provider_entity in self.get_model_providers(): - if provider and not self._matches_provider(provider_entity, provider): - continue - - if model_type and model_type not in provider_entity.supported_model_types: - continue - - simple_provider_schema = provider_entity.to_simple_provider() - if model_type is not None: - simple_provider_schema.models = [ - model_schema for model_schema in provider_entity.models if model_schema.model_type == model_type - ] - providers.append(simple_provider_schema) - - return providers - - def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel: - """ - Get model type instance by provider name and model type. 
- """ - provider_schema = self.get_model_provider(provider) - - if model_type == ModelType.LLM: - return LargeLanguageModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.TEXT_EMBEDDING: - return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.RERANK: - return RerankModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.SPEECH2TEXT: - return Speech2TextModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.MODERATION: - return ModerationModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.TTS: - return TTSModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - - raise ValueError(f"Unsupported model type: {model_type}") - - def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: - """ - Get provider icon. - """ - provider_entity = self.get_model_provider(provider) - return self.model_runtime.get_provider_icon(provider=provider_entity.provider, icon_type=icon_type, lang=lang) - - def _resolve_provider(self, provider: str) -> ProviderEntity | None: - return next((item for item in self.get_model_providers() if self._matches_provider(item, provider)), None) - - @staticmethod - def _matches_provider(provider_entity: ProviderEntity, provider: str) -> bool: - return provider in (provider_entity.provider, provider_entity.provider_name) diff --git a/api/graphon/model_runtime/runtime.py b/api/graphon/model_runtime/runtime.py deleted file mode 100644 index 79862bab8bb..00000000000 --- a/api/graphon/model_runtime/runtime.py +++ /dev/null @@ -1,159 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator, Iterable, Sequence -from typing import IO, Any, Protocol, Union, runtime_checkable - -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.model_runtime.entities.provider_entities import ProviderEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult - - -@runtime_checkable -class ModelRuntime(Protocol): - """Port for provider discovery, schema lookup, and model execution. - - `provider` is the model runtime's canonical provider identifier. Adapters may - derive transport-specific details from it, but those details stay outside - this boundary. - """ - - def fetch_model_providers(self) -> Sequence[ProviderEntity]: ... - - def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: ... - - def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: ... - - def validate_model_credentials( - self, - *, - provider: str, - model_type: ModelType, - model: str, - credentials: dict[str, Any], - ) -> None: ... - - def get_model_schema( - self, - *, - provider: str, - model_type: ModelType, - model: str, - credentials: dict[str, Any], - ) -> AIModelEntity | None: ... 
- - def invoke_llm( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - model_parameters: dict[str, Any], - prompt_messages: Sequence[PromptMessage], - tools: list[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: bool, - ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: ... - - def get_llm_num_tokens( - self, - *, - provider: str, - model_type: ModelType, - model: str, - credentials: dict[str, Any], - prompt_messages: Sequence[PromptMessage], - tools: Sequence[PromptMessageTool] | None, - ) -> int: ... - - def invoke_text_embedding( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - texts: list[str], - input_type: EmbeddingInputType, - ) -> EmbeddingResult: ... - - def invoke_multimodal_embedding( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - documents: list[dict[str, Any]], - input_type: EmbeddingInputType, - ) -> EmbeddingResult: ... - - def get_text_embedding_num_tokens( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - texts: list[str], - ) -> list[int]: ... - - def invoke_rerank( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - query: str, - docs: list[str], - score_threshold: float | None, - top_n: int | None, - ) -> RerankResult: ... - - def invoke_multimodal_rerank( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - query: MultimodalRerankInput, - docs: list[MultimodalRerankInput], - score_threshold: float | None, - top_n: int | None, - ) -> RerankResult: ... - - def invoke_tts( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - content_text: str, - voice: str, - ) -> Iterable[bytes]: ... - - def get_tts_model_voices( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - language: str | None, - ) -> Any: ... - - def invoke_speech_to_text( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - file: IO[bytes], - ) -> str: ... - - def invoke_moderation( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - text: str, - ) -> bool: ... 
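Because the ModelRuntime port above is decorated with @runtime_checkable, isinstance() can verify that an adapter exposes every port method, though only by attribute name, not by signature or return type. A minimal sketch of that structural check against a reduced, single-method port; _ModerationPort and StubRuntime are illustrative names, not part of this codebase:

from typing import Protocol, runtime_checkable


@runtime_checkable
class _ModerationPort(Protocol):
    """Reduced stand-in for the full ModelRuntime port, for illustration."""

    def invoke_moderation(self, *, provider: str, model: str, credentials: dict, text: str) -> bool: ...


class StubRuntime:
    """Test double: flags text containing a canned marker instead of calling a backend."""

    def invoke_moderation(self, *, provider: str, model: str, credentials: dict, text: str) -> bool:
        return "unsafe" in text


# Structural check passes because the method name is present; signatures are not verified.
assert isinstance(StubRuntime(), _ModerationPort)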
diff --git a/api/graphon/model_runtime/schema_validators/__init__.py b/api/graphon/model_runtime/schema_validators/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/graphon/model_runtime/schema_validators/common_validator.py b/api/graphon/model_runtime/schema_validators/common_validator.py deleted file mode 100644 index 984507081b2..00000000000 --- a/api/graphon/model_runtime/schema_validators/common_validator.py +++ /dev/null @@ -1,92 +0,0 @@ -from typing import Union, cast - -from graphon.model_runtime.entities.provider_entities import CredentialFormSchema, FormType - - -class CommonValidator: - def _validate_and_filter_credential_form_schemas( - self, credential_form_schemas: list[CredentialFormSchema], credentials: dict - ): - need_validate_credential_form_schema_map = {} - for credential_form_schema in credential_form_schemas: - if not credential_form_schema.show_on: - need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema - continue - - all_show_on_match = True - for show_on_object in credential_form_schema.show_on: - if show_on_object.variable not in credentials: - all_show_on_match = False - break - - if credentials[show_on_object.variable] != show_on_object.value: - all_show_on_match = False - break - - if all_show_on_match: - need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema - - # Iterate over the remaining credential_form_schemas, verify each credential_form_schema - validated_credentials = {} - for credential_form_schema in need_validate_credential_form_schema_map.values(): - # add the value of the credential_form_schema corresponding to it to validated_credentials - result = self._validate_credential_form_schema(credential_form_schema, credentials) - if result: - validated_credentials[credential_form_schema.variable] = result - - return validated_credentials - - def _validate_credential_form_schema( - self, credential_form_schema: CredentialFormSchema, credentials: dict - ) -> Union[str, bool, None]: - """ - Validate credential form schema - - :param credential_form_schema: credential form schema - :param credentials: credentials - :return: validated credential form schema value - """ - # If the variable does not exist in credentials - value: Union[str, bool, None] = None - if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: - # If required is True, an exception is thrown - if credential_form_schema.required: - raise ValueError(f"Variable {credential_form_schema.variable} is required") - else: - # Get the value of default - if credential_form_schema.default: - # If it exists, add it to validated_credentials - return credential_form_schema.default - else: - # If default does not exist, skip - return None - - # Get the value corresponding to the variable from credentials - value = cast(str, credentials[credential_form_schema.variable]) - - # If max_length=0, no validation is performed - if credential_form_schema.max_length: - if len(value) > credential_form_schema.max_length: - raise ValueError( - f"Variable {credential_form_schema.variable} length should not be" - f" greater than {credential_form_schema.max_length}" - ) - - # check the type of value - if not isinstance(value, str): - raise ValueError(f"Variable {credential_form_schema.variable} should be string") - - if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}: - # If the value is in options, no validation is performed - if 
credential_form_schema.options: - if value not in [option.value for option in credential_form_schema.options]: - raise ValueError(f"Variable {credential_form_schema.variable} is not in options") - - if credential_form_schema.type == FormType.SWITCH: - # If the value is not in ['true', 'false'], an exception is thrown - if value.lower() not in {"true", "false"}: - raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") - - value = value.lower() == "true" - - return value diff --git a/api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py b/api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py deleted file mode 100644 index 9e4830c1b79..00000000000 --- a/api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py +++ /dev/null @@ -1,27 +0,0 @@ -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ModelCredentialSchema -from graphon.model_runtime.schema_validators.common_validator import CommonValidator - - -class ModelCredentialSchemaValidator(CommonValidator): - def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema): - self.model_type = model_type - self.model_credential_schema = model_credential_schema - - def validate_and_filter(self, credentials: dict): - """ - Validate model credentials - - :param credentials: model credentials - :return: filtered credentials - """ - - if self.model_credential_schema is None: - raise ValueError("Model credential schema is None") - - # get the credential_form_schemas in provider_credential_schema - credential_form_schemas = self.model_credential_schema.credential_form_schemas - - credentials["__model_type"] = self.model_type.value - - return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py deleted file mode 100644 index 05fd3ce142a..00000000000 --- a/api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py +++ /dev/null @@ -1,19 +0,0 @@ -from graphon.model_runtime.entities.provider_entities import ProviderCredentialSchema -from graphon.model_runtime.schema_validators.common_validator import CommonValidator - - -class ProviderCredentialSchemaValidator(CommonValidator): - def __init__(self, provider_credential_schema: ProviderCredentialSchema): - self.provider_credential_schema = provider_credential_schema - - def validate_and_filter(self, credentials: dict): - """ - Validate provider credentials - - :param credentials: provider credentials - :return: validated provider credentials - """ - # get the credential_form_schemas in provider_credential_schema - credential_form_schemas = self.provider_credential_schema.credential_form_schemas - - return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/api/graphon/model_runtime/utils/__init__.py b/api/graphon/model_runtime/utils/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/graphon/model_runtime/utils/encoders.py b/api/graphon/model_runtime/utils/encoders.py deleted file mode 100644 index 13abf74767e..00000000000 --- a/api/graphon/model_runtime/utils/encoders.py +++ /dev/null @@ -1,218 +0,0 @@ -import dataclasses -import datetime -from collections import defaultdict, 
deque -from collections.abc import Callable, Sequence -from decimal import Decimal -from enum import Enum -from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network -from pathlib import Path, PurePath -from re import Pattern -from types import GeneratorType -from typing import Any, Literal, Union -from uuid import UUID - -from pydantic import BaseModel -from pydantic.networks import AnyUrl, NameEmail -from pydantic.types import SecretBytes, SecretStr -from pydantic_core import Url -from pydantic_extra_types.color import Color - - -def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: - return model.model_dump(mode=mode, **kwargs) - - -# Taken from Pydantic v1 as is -def isoformat(o: Union[datetime.date, datetime.time]) -> str: - return o.isoformat() - - -# Taken from Pydantic v1 as is -# TODO: pv2 should this return strings instead? -def decimal_encoder(dec_value: Decimal) -> Union[int, float]: - """ - Encodes a Decimal as int of there's no exponent, otherwise float - - This is useful when we use ConstrainedDecimal to represent Numeric(x,0) - where a integer (but not int typed) is used. Encoding this as a float - results in failed round-tripping between encode and parse. - Our Id type is a prime example of this. - - >>> decimal_encoder(Decimal("1.0")) - 1.0 - - >>> decimal_encoder(Decimal("1")) - 1 - """ - if dec_value.as_tuple().exponent >= 0: # type: ignore[operator] - return int(dec_value) - else: - return float(dec_value) - - -ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { - bytes: lambda o: o.decode(), - Color: str, - datetime.date: isoformat, - datetime.datetime: isoformat, - datetime.time: isoformat, - datetime.timedelta: lambda td: td.total_seconds(), - Decimal: decimal_encoder, - Enum: lambda o: o.value, - frozenset: list, - deque: list, - GeneratorType: list, - IPv4Address: str, - IPv4Interface: str, - IPv4Network: str, - IPv6Address: str, - IPv6Interface: str, - IPv6Network: str, - NameEmail: str, - Path: str, - Pattern: lambda o: o.pattern, - SecretBytes: str, - SecretStr: str, - set: list, - UUID: str, - Url: str, - AnyUrl: str, -} - - -def generate_encoders_by_class_tuples( - type_encoder_map: dict[Any, Callable[[Any], Any]], -) -> dict[Callable[[Any], Any], tuple[Any, ...]]: - encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple) - for type_, encoder in type_encoder_map.items(): - encoders_by_class_tuples[encoder] += (type_,) - return encoders_by_class_tuples - - -encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) - - -def jsonable_encoder( - obj: Any, - by_alias: bool = True, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - custom_encoder: dict[Any, Callable[[Any], Any]] | None = None, - excluded_key_prefixes: Sequence[str] = (), -) -> Any: - custom_encoder = custom_encoder or {} - if custom_encoder: - if type(obj) in custom_encoder: - return custom_encoder[type(obj)](obj) - else: - for encoder_type, encoder_instance in custom_encoder.items(): - if isinstance(obj, encoder_type): - return encoder_instance(obj) - if isinstance(obj, BaseModel): - obj_dict = _model_dump( - obj, - mode="json", - include=None, - exclude=None, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - exclude_defaults=exclude_defaults, - ) - if "__root__" in obj_dict: - obj_dict = obj_dict["__root__"] - return jsonable_encoder( - obj_dict, - 
exclude_none=exclude_none, - exclude_defaults=exclude_defaults, - excluded_key_prefixes=excluded_key_prefixes, - ) - if dataclasses.is_dataclass(obj): - # Ensure obj is a dataclass instance, not a dataclass type - if not isinstance(obj, type): - obj_dict = dataclasses.asdict(obj) - return jsonable_encoder( - obj_dict, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) - if isinstance(obj, Enum): - return obj.value - if isinstance(obj, PurePath): - return str(obj) - if isinstance(obj, str | int | float | type(None)): - return obj - if isinstance(obj, Decimal): - return format(obj, "f") - if isinstance(obj, dict): - encoded_dict = {} - for key, value in obj.items(): - if isinstance(key, str) and any(key.startswith(prefix) for prefix in excluded_key_prefixes): - continue - if value is None and exclude_none: - continue - - encoded_key = jsonable_encoder( - key, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) - encoded_value = jsonable_encoder( - value, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) - encoded_dict[encoded_key] = encoded_value - return encoded_dict - if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): - encoded_list = [] - for item in obj: - encoded_list.append( - jsonable_encoder( - item, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) - ) - return encoded_list - - if type(obj) in ENCODERS_BY_TYPE: - return ENCODERS_BY_TYPE[type(obj)](obj) - for encoder, classes_tuple in encoders_by_class_tuples.items(): - if isinstance(obj, classes_tuple): - return encoder(obj) - - try: - data = dict(obj) # type: ignore - except Exception as e: - errors: list[Exception] = [] - errors.append(e) - try: - data = vars(obj) # type: ignore - except Exception as e: - errors.append(e) - raise ValueError(str(errors)) from e - return jsonable_encoder( - data, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) diff --git a/api/graphon/node_events/__init__.py b/api/graphon/node_events/__init__.py deleted file mode 100644 index a2bbf9f1765..00000000000 --- a/api/graphon/node_events/__init__.py +++ /dev/null @@ -1,48 +0,0 @@ -from .agent import AgentLogEvent -from .base import NodeEventBase, NodeRunResult -from .iteration import ( - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, -) -from .loop import ( - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, -) -from .node import ( - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - ModelInvokeCompletedEvent, - PauseRequestedEvent, - RunRetrieverResourceEvent, - RunRetryEvent, - StreamChunkEvent, - StreamCompletedEvent, - VariableUpdatedEvent, -) - -__all__ = [ - "AgentLogEvent", - "HumanInputFormFilledEvent", - "HumanInputFormTimeoutEvent", - "IterationFailedEvent", - "IterationNextEvent", - "IterationStartedEvent", - "IterationSucceededEvent", - "LoopFailedEvent", - 
"LoopNextEvent", - "LoopStartedEvent", - "LoopSucceededEvent", - "ModelInvokeCompletedEvent", - "NodeEventBase", - "NodeRunResult", - "PauseRequestedEvent", - "RunRetrieverResourceEvent", - "RunRetryEvent", - "StreamChunkEvent", - "StreamCompletedEvent", - "VariableUpdatedEvent", -] diff --git a/api/graphon/node_events/agent.py b/api/graphon/node_events/agent.py deleted file mode 100644 index bf295ec7742..00000000000 --- a/api/graphon/node_events/agent.py +++ /dev/null @@ -1,18 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class AgentLogEvent(NodeEventBase): - message_id: str = Field(..., description="id") - label: str = Field(..., description="label") - node_execution_id: str = Field(..., description="node execution id") - parent_id: str | None = Field(..., description="parent id") - error: str | None = Field(..., description="error") - status: str = Field(..., description="status") - data: Mapping[str, Any] = Field(..., description="data") - metadata: Mapping[str, Any] = Field(default_factory=dict, description="metadata") - node_id: str = Field(..., description="node id") diff --git a/api/graphon/node_events/base.py b/api/graphon/node_events/base.py deleted file mode 100644 index dcd1672428c..00000000000 --- a/api/graphon/node_events/base.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import BaseModel, Field - -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMUsage - - -class NodeEventBase(BaseModel): - """Base class for all node events""" - - pass - - -def _default_metadata(): - v: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} - return v - - -class NodeRunResult(BaseModel): - """ - Node Run Result. 
- """ - - status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.PENDING - - inputs: Mapping[str, Any] = Field(default_factory=dict) - process_data: Mapping[str, Any] = Field(default_factory=dict) - outputs: Mapping[str, Any] = Field(default_factory=dict) - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=_default_metadata) - llm_usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage) - - edge_source_handle: str = "source" # source handle id of node with multiple branches - - error: str = "" - error_type: str = "" - - # single step node run retry - retry_index: int = 0 diff --git a/api/graphon/node_events/iteration.py b/api/graphon/node_events/iteration.py deleted file mode 100644 index 744ddea628b..00000000000 --- a/api/graphon/node_events/iteration.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class IterationStartedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class IterationNextEvent(NodeEventBase): - index: int = Field(..., description="index") - pre_iteration_output: Any = None - - -class IterationSucceededEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class IterationFailedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/graphon/node_events/loop.py b/api/graphon/node_events/loop.py deleted file mode 100644 index 3ae230f9f66..00000000000 --- a/api/graphon/node_events/loop.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class LoopStartedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class LoopNextEvent(NodeEventBase): - index: int = Field(..., description="index") - pre_loop_output: Any = None - - -class LoopSucceededEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class LoopFailedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/graphon/node_events/node.py 
b/api/graphon/node_events/node.py deleted file mode 100644 index 17f1494cf24..00000000000 --- a/api/graphon/node_events/node.py +++ /dev/null @@ -1,72 +0,0 @@ -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any - -from pydantic import Field - -from graphon.entities.pause_reason import PauseReason -from graphon.file import File -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import NodeRunResult -from graphon.variables.variables import Variable - -from .base import NodeEventBase - - -class RunRetrieverResourceEvent(NodeEventBase): - retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources") - context: str = Field(..., description="context") - context_files: list[File] | None = Field(default=None, description="context files") - - -class ModelInvokeCompletedEvent(NodeEventBase): - text: str - usage: LLMUsage - finish_reason: str | None = None - reasoning_content: str | None = None - structured_output: dict | None = None - - -class RunRetryEvent(NodeEventBase): - error: str = Field(..., description="error") - retry_index: int = Field(..., description="Retry attempt number") - start_at: datetime = Field(..., description="Retry start time") - - -class StreamChunkEvent(NodeEventBase): - # Spec-compliant fields - selector: Sequence[str] = Field( - ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" - ) - chunk: str = Field(..., description="the actual chunk content") - is_final: bool = Field(default=False, description="indicates if this is the last chunk") - - -class StreamCompletedEvent(NodeEventBase): - node_run_result: NodeRunResult = Field(..., description="run result") - - -class VariableUpdatedEvent(NodeEventBase): - """Notify the engine that a single variable should be applied to the shared pool.""" - - variable: Variable = Field(..., description="Updated variable payload to apply.") - - -class PauseRequestedEvent(NodeEventBase): - reason: PauseReason = Field(..., description="pause reason") - - -class HumanInputFormFilledEvent(NodeEventBase): - """Event emitted when a human input form is submitted.""" - - node_title: str - rendered_content: str - action_id: str - action_text: str - - -class HumanInputFormTimeoutEvent(NodeEventBase): - """Event emitted when a human input form times out.""" - - node_title: str - expiration_time: datetime diff --git a/api/graphon/nodes/__init__.py b/api/graphon/nodes/__init__.py deleted file mode 100644 index 2d376d104da..00000000000 --- a/api/graphon/nodes/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from graphon.enums import BuiltinNodeTypes - -__all__ = ["BuiltinNodeTypes"] diff --git a/api/graphon/nodes/answer/__init__.py b/api/graphon/nodes/answer/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/graphon/nodes/answer/answer_node.py b/api/graphon/nodes/answer/answer_node.py deleted file mode 100644 index c5261a79393..00000000000 --- a/api/graphon/nodes/answer/answer_node.py +++ /dev/null @@ -1,70 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.answer.entities import AnswerNodeData -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template -from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from 
graphon.variables import ArrayFileSegment, FileSegment, Segment - - -class AnswerNode(Node[AnswerNodeData]): - node_type = BuiltinNodeTypes.ANSWER - execution_type = NodeExecutionType.RESPONSE - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - segments = self.graph_runtime_state.variable_pool.convert_template(self.node_data.answer) - files = self._extract_files_from_segments(segments.value) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)}, - ) - - def _extract_files_from_segments(self, segments: Sequence[Segment]): - """Extract all files from segments containing FileSegment or ArrayFileSegment instances. - - FileSegment contains a single file, while ArrayFileSegment contains multiple files. - This method flattens all files into a single list. - """ - files = [] - for segment in segments: - if isinstance(segment, FileSegment): - # Single file - wrap in list for consistency - files.append(segment.value) - elif isinstance(segment, ArrayFileSegment): - # Multiple files - extend the list - files.extend(segment.value) - return files - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: AnswerNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - variable_template_parser = VariableTemplateParser(template=node_data.answer) - variable_selectors = variable_template_parser.extract_variable_selectors() - - variable_mapping = {} - for variable_selector in variable_selectors: - variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector - - return variable_mapping - - def get_streaming_template(self) -> Template: - """ - Get the template for streaming. - - Returns: - Template instance for this Answer node - """ - return Template.from_answer_template(self.node_data.answer) diff --git a/api/graphon/nodes/answer/entities.py b/api/graphon/nodes/answer/entities.py deleted file mode 100644 index c49f1f38957..00000000000 --- a/api/graphon/nodes/answer/entities.py +++ /dev/null @@ -1,67 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum, auto - -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class AnswerNodeData(BaseNodeData): - """ - Answer Node Data. - """ - - type: NodeType = BuiltinNodeTypes.ANSWER - answer: str = Field(..., description="answer template string") - - -class GenerateRouteChunk(BaseModel): - """ - Generate Route Chunk. - """ - - class ChunkType(StrEnum): - VAR = auto() - TEXT = auto() - - type: ChunkType = Field(..., description="generate route chunk type") - - -class VarGenerateRouteChunk(GenerateRouteChunk): - """ - Var Generate Route Chunk. - """ - - type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR - """generate route chunk type""" - value_selector: Sequence[str] = Field(..., description="value selector") - - -class TextGenerateRouteChunk(GenerateRouteChunk): - """ - Text Generate Route Chunk. 
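The flatten in `_extract_files_from_segments` above is a plain append/extend walk. A runnable sketch with stand-in dataclasses — the real `FileSegment`/`ArrayFileSegment` live in `graphon.variables` and wrap `File` objects rather than the strings used here:

from dataclasses import dataclass

@dataclass
class FileSegment:
    value: str  # stand-in for a single File

@dataclass
class ArrayFileSegment:
    value: list[str]  # stand-in for a list of Files

def extract_files(segments: list[object]) -> list[str]:
    files: list[str] = []
    for segment in segments:
        if isinstance(segment, FileSegment):
            files.append(segment.value)  # single file: append
        elif isinstance(segment, ArrayFileSegment):
            files.extend(segment.value)  # file array: extend
    return files

assert extract_files([FileSegment("a.png"), ArrayFileSegment(["b.pdf", "c.csv"])]) == ["a.png", "b.pdf", "c.csv"]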
- """ - - type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT - """generate route chunk type""" - text: str = Field(..., description="text") - - -class AnswerNodeDoubleLink(BaseModel): - node_id: str = Field(..., description="node id") - source_node_ids: list[str] = Field(..., description="source node ids") - target_node_ids: list[str] = Field(..., description="target node ids") - - -class AnswerStreamGenerateRoute(BaseModel): - """ - AnswerStreamGenerateRoute entity - """ - - answer_dependencies: dict[str, list[str]] = Field( - ..., description="answer dependencies (answer node id -> dependent answer node ids)" - ) - answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( - ..., description="answer generate route (answer node id -> generate route chunks)" - ) diff --git a/api/graphon/nodes/base/__init__.py b/api/graphon/nodes/base/__init__.py deleted file mode 100644 index 036e25895d2..00000000000 --- a/api/graphon/nodes/base/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState -from .usage_tracking_mixin import LLMUsageTrackingMixin - -__all__ = [ - "BaseIterationNodeData", - "BaseIterationState", - "BaseLoopNodeData", - "BaseLoopState", - "LLMUsageTrackingMixin", -] diff --git a/api/graphon/nodes/base/entities.py b/api/graphon/nodes/base/entities.py deleted file mode 100644 index 94b88c097d6..00000000000 --- a/api/graphon/nodes/base/entities.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from enum import StrEnum -from typing import Any - -from pydantic import BaseModel, field_validator - -from graphon.entities.base_node_data import BaseNodeData - - -class VariableSelector(BaseModel): - """ - Variable Selector. - """ - - variable: str - value_selector: Sequence[str] - - -class OutputVariableType(StrEnum): - STRING = "string" - NUMBER = "number" - INTEGER = "integer" - SECRET = "secret" - BOOLEAN = "boolean" - OBJECT = "object" - FILE = "file" - ARRAY = "array" - ARRAY_STRING = "array[string]" - ARRAY_NUMBER = "array[number]" - ARRAY_OBJECT = "array[object]" - ARRAY_BOOLEAN = "array[boolean]" - ARRAY_FILE = "array[file]" - ANY = "any" - ARRAY_ANY = "array[any]" - - -class OutputVariableEntity(BaseModel): - """ - Output Variable Entity. - """ - - variable: str - value_type: OutputVariableType = OutputVariableType.ANY - value_selector: Sequence[str] - - @field_validator("value_type", mode="before") - @classmethod - def normalize_value_type(cls, v: Any) -> Any: - """ - Normalize value_type to handle case-insensitive array types. - Converts 'Array[...]' to 'array[...]' for backward compatibility. 
- """ - if isinstance(v, str) and v.startswith("Array["): - return v.lower() - return v - - -class BaseIterationNodeData(BaseNodeData): - start_node_id: str | None = None - - -class BaseIterationState(BaseModel): - iteration_node_id: str - index: int - inputs: dict - - class MetaData(BaseModel): - pass - - metadata: MetaData - - -class BaseLoopNodeData(BaseNodeData): - start_node_id: str | None = None - - -class BaseLoopState(BaseModel): - loop_node_id: str - index: int - inputs: dict - - class MetaData(BaseModel): - pass - - metadata: MetaData diff --git a/api/graphon/nodes/base/node.py b/api/graphon/nodes/base/node.py deleted file mode 100644 index 613ff4f0372..00000000000 --- a/api/graphon/nodes/base/node.py +++ /dev/null @@ -1,787 +0,0 @@ -from __future__ import annotations - -import logging -import operator -from abc import abstractmethod -from collections.abc import Generator, Mapping, Sequence -from datetime import UTC, datetime -from functools import singledispatchmethod -from types import MappingProxyType -from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin -from uuid import uuid4 - -from graphon.entities import GraphInitParams -from graphon.entities.base_node_data import BaseNodeData, RetryConfig -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - ErrorStrategy, - NodeExecutionType, - NodeState, - NodeType, - WorkflowNodeExecutionStatus, -) -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunAgentLogEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) -from graphon.node_events import ( - AgentLogEvent, - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, - NodeEventBase, - NodeRunResult, - PauseRequestedEvent, - RunRetrieverResourceEvent, - StreamChunkEvent, - StreamCompletedEvent, - VariableUpdatedEvent, -) -from graphon.runtime import GraphRuntimeState - -NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) -_MISSING_RUN_CONTEXT_VALUE = object() - -logger = logging.getLogger(__name__) - - -class Node(Generic[NodeDataT]): - """BaseNode serves as the foundational class for all node implementations. - - Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output` - attribute to track files generated by the LLM). However, these states are not persisted - when the workflow is suspended or resumed. If a node needs its state to be preserved - across workflow suspension and resumption, it should include the relevant state data - in its output. - """ - - node_type: ClassVar[NodeType] - execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE - _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData - - def __init_subclass__(cls, **kwargs: Any) -> None: - """ - Automatically extract and validate the node data type from the generic parameter. - - When a subclass is defined as `class MyNode(Node[MyNodeData])`, this method: - 1. 
Inspects `__orig_bases__` to find the `Node[T]` parameterization - 2. Extracts `T` (e.g., `MyNodeData`) from the generic argument - 3. Validates that `T` is a proper `BaseNodeData` subclass - 4. Stores it in `_node_data_type` for automatic hydration in `__init__` - - This eliminates the need for subclasses to manually implement boilerplate - accessor methods like `_get_title()`, `_get_error_strategy()`, etc. - - How it works: - :: - - class CodeNode(Node[CodeNodeData]): - โ”‚ โ”‚ - โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” - โ”‚ โ”‚ - โ–ผ โ–ผ - โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” - โ”‚ __orig_bases__ = ( โ”‚ โ”‚ CodeNodeData(BaseNodeData) โ”‚ - โ”‚ Node[CodeNodeData], โ”‚ โ”‚ title: str โ”‚ - โ”‚ ) โ”‚ โ”‚ desc: str | None โ”‚ - โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ ... โ”‚ - โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ - โ–ผ โ–ฒ - โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ - โ”‚ get_origin(base) -> Node โ”‚ โ”‚ - โ”‚ get_args(base) -> ( โ”‚ โ”‚ - โ”‚ CodeNodeData, โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ - โ”‚ ) โ”‚ - โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ - โ”‚ - โ–ผ - โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” - โ”‚ Validate: โ”‚ - โ”‚ - Is it a type? โ”‚ - โ”‚ - Is it a BaseNodeData โ”‚ - โ”‚ subclass? โ”‚ - โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ - โ”‚ - โ–ผ - โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” - โ”‚ cls._node_data_type = โ”‚ - โ”‚ CodeNodeData โ”‚ - โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ - - Later, in __init__: - :: - - config["data"] โ”€โ”€โ–บ _node_data_type.model_validate(..., from_attributes=True) - โ”‚ - โ–ผ - CodeNodeData instance - (stored in self._node_data) - - Example: - class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted - node_type = BuiltinNodeTypes.CODE - # No need to implement _get_title, _get_error_strategy, etc. - """ - super().__init_subclass__(**kwargs) - - if cls is Node: - return - - node_data_type = cls._extract_node_data_type_from_generic() - - if node_data_type is None: - raise TypeError(f"{cls.__name__} must inherit from Node[T] with a BaseNodeData subtype") - - cls._node_data_type = node_data_type - - # Skip base class itself - if cls is Node: - return - # Only treat nodes from the base graphon package as production - # registrations. Higher-layer packages may still register subclasses, - # but graphon itself should not know their module identities. - # This prevents test helper subclasses from polluting the global registry and - # accidentally overriding real node types (e.g., a test Answer node). 
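The extraction diagrammed above boils down to a few lines of `typing` introspection. A self-contained sketch of just that step, with simplified stand-in names (`BaseData` for `BaseNodeData`; the registry bookkeeping that follows is omitted):

from typing import Generic, TypeVar, get_args, get_origin

T = TypeVar("T")

class BaseData:  # stand-in for BaseNodeData
    pass

class MyData(BaseData):
    pass

class Node(Generic[T]):
    _node_data_type: type | None = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # __orig_bases__ preserves the un-erased Node[MyData] parameterization.
        for base in getattr(cls, "__orig_bases__", ()):
            if get_origin(base) is Node:
                (arg,) = get_args(base)  # e.g. MyData
                if isinstance(arg, type) and issubclass(arg, BaseData):
                    cls._node_data_type = arg

class MyNode(Node[MyData]):
    pass

assert MyNode._node_data_type is MyData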
- module_name = getattr(cls, "__module__", "") - # Only register concrete subclasses that define node_type and version() - node_type = cls.node_type - version = cls.version() - bucket = Node._registry.setdefault(node_type, {}) - if module_name.startswith("graphon.nodes."): - # Production node definitions take precedence and may override - bucket[version] = cls # type: ignore[index] - else: - # External/test subclasses may register but must not override production - bucket.setdefault(version, cls) # type: ignore[index] - # Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic - version_keys = [v for v in bucket if v != "latest"] - numeric_pairs: list[tuple[str, int]] = [] - for v in version_keys: - numeric_pairs.append((v, int(v))) - if numeric_pairs: - latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0] - else: - latest_key = max(version_keys) if version_keys else version - bucket["latest"] = bucket[latest_key] - Node._registry_version += 1 - - @classmethod - def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None: - """ - Extract the node data type from the generic parameter `Node[T]`. - - Inspects `__orig_bases__` to find the `Node[T]` parameterization and extracts `T`. - - Returns: - The extracted BaseNodeData subtype, or None if not found. - - Raises: - TypeError: If the generic argument is invalid (not exactly one argument, - or not a BaseNodeData subtype). - """ - # __orig_bases__ contains the original generic bases before type erasure. - # For `class CodeNode(Node[CodeNodeData])`, this would be `(Node[CodeNodeData],)`. - for base in getattr(cls, "__orig_bases__", ()): # type: ignore[attr-defined] - origin = get_origin(base) # Returns `Node` for `Node[CodeNodeData]` - if origin is Node: - args = get_args(base) # Returns `(CodeNodeData,)` for `Node[CodeNodeData]` - if len(args) != 1: - raise TypeError(f"{cls.__name__} must specify exactly one node data generic argument") - - candidate = args[0] - if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData): - raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype") - - return candidate - - return None - - # Global registry populated via __init_subclass__ - _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {} - _registry_version: ClassVar[int] = 0 - - @classmethod - def get_registry_version(cls) -> int: - return cls._registry_version - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - ) -> None: - self._graph_init_params = graph_init_params - self._run_context = MappingProxyType(dict(graph_init_params.run_context)) - self.id = id - self.workflow_id = graph_init_params.workflow_id - self.graph_config = graph_init_params.graph_config - self.workflow_call_depth = graph_init_params.call_depth - self.graph_runtime_state = graph_runtime_state - self.state: NodeState = NodeState.UNKNOWN # node execution state - - node_id = config["id"] - - self._node_id = node_id - self._node_execution_id: str = "" - self._start_at = datetime.now(UTC).replace(tzinfo=None) - - self._node_data = self.validate_node_data(config["data"]) - - self.post_init() - - @classmethod - def validate_node_data(cls, node_data: BaseNodeData | Mapping[str, Any]) -> NodeDataT: - """Validate shared graph node payloads against the subclass-declared NodeData model. 
- - Re-validate from a dumped payload instead of `from_attributes=True` so compatibility - extras stored on `BaseNodeData` survive the handoff to the concrete node data model. - Human Input delivery methods are one such extra field until graphon owns that schema. - """ - if isinstance(node_data, BaseNodeData): - payload = node_data.model_dump(mode="python") - else: - payload = dict(node_data) - return cast(NodeDataT, cls._node_data_type.model_validate(payload)) - - def init_node_data(self, data: BaseNodeData | Mapping[str, Any]) -> None: - """Hydrate `_node_data` for legacy callers that bypass `__init__`.""" - self._node_data = self.validate_node_data(cast(BaseNodeData, data)) - - def post_init(self) -> None: - """Optional hook for subclasses requiring extra initialization.""" - return - - @property - def graph_init_params(self) -> GraphInitParams: - return self._graph_init_params - - @property - def run_context(self) -> Mapping[str, Any]: - return self._run_context - - def get_run_context_value(self, key: str, default: Any = None) -> Any: - return self._run_context.get(key, default) - - def require_run_context_value(self, key: str) -> Any: - value = self.get_run_context_value(key, _MISSING_RUN_CONTEXT_VALUE) - if value is _MISSING_RUN_CONTEXT_VALUE: - raise ValueError(f"run_context missing required key: {key}") - return value - - @property - def execution_id(self) -> str: - return self._node_execution_id - - def ensure_execution_id(self) -> str: - if self._node_execution_id: - return self._node_execution_id - - resumed_execution_id = self._restore_execution_id_from_runtime_state() - if resumed_execution_id: - self._node_execution_id = resumed_execution_id - return self._node_execution_id - - self._node_execution_id = str(uuid4()) - return self._node_execution_id - - def _restore_execution_id_from_runtime_state(self) -> str | None: - graph_execution = self.graph_runtime_state.graph_execution - try: - node_executions = graph_execution.node_executions - except AttributeError: - return None - if not isinstance(node_executions, dict): - return None - node_execution = node_executions.get(self._node_id) - if node_execution is None: - return None - execution_id = node_execution.execution_id - if not execution_id: - return None - return str(execution_id) - - @abstractmethod - def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: - """ - Run node - :return: - """ - raise NotImplementedError - - def populate_start_event(self, event: NodeRunStartedEvent) -> None: - """Allow subclasses to enrich the started event without cross-node imports in the base class.""" - _ = event - - def run(self) -> Generator[GraphNodeEventBase, None, None]: - execution_id = self.ensure_execution_id() - self._start_at = datetime.now(UTC).replace(tzinfo=None) - - # Create and push start event with required fields - start_event = NodeRunStartedEvent( - id=execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.title, - in_iteration_id=None, - start_at=self._start_at, - ) - try: - self.populate_start_event(start_event) - except Exception: - logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True) - yield start_event - - try: - result = self._run() - - # Handle NodeRunResult - if isinstance(result, NodeRunResult): - yield self._convert_node_run_result_to_graph_node_event(result) - return - - # Handle event stream - for event in result: - # NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase - if isinstance(event, 
NodeEventBase):  # pyright: ignore[reportUnnecessaryIsInstance]
-                    yield self._dispatch(event)
-                elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id:  # pyright: ignore[reportUnnecessaryIsInstance]
-                    event.id = self.execution_id
-                    yield event
-                else:
-                    yield event
-        except Exception as e:
-            logger.exception("Node %s failed to run", self._node_id)
-            result = NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                error=str(e),
-                error_type="WorkflowNodeError",
-            )
-            finished_at = datetime.now(UTC).replace(tzinfo=None)
-            yield NodeRunFailedEvent(
-                id=self.execution_id,
-                node_id=self._node_id,
-                node_type=self.node_type,
-                start_at=self._start_at,
-                finished_at=finished_at,
-                node_run_result=result,
-                error=str(e),
-            )
-
-    @classmethod
-    def extract_variable_selector_to_variable_mapping(
-        cls,
-        *,
-        graph_config: Mapping[str, Any],
-        config: NodeConfigDict,
-    ) -> Mapping[str, Sequence[str]]:
-        """Extracts referenced variable selectors from the node configuration.
-
-        The `config` parameter represents the configuration for a specific node type and corresponds
-        to the `data` field in the node definition object.
-
-        The returned mapping has the following structure:
-
-            {'1747829548239.#1747829667553.result#': ['1747829667553', 'result']}
-
-        For loop and iteration nodes, the mapping may look like this:
-
-            {
-                "1748332301644.input_selector": ["1748332363630", "result"],
-                "1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"],
-            }
-
-        where `1748332301644` is the ID of the loop / iteration node,
-        and `1748332325079` is the ID of the node inside the loop or iteration node.
-
-        Here, the key consists of two parts: the current node ID (provided as the `node_id`
-        parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector,
-        enclosed in `#` symbols. These two parts are separated by a dot (`.`).
-
-        The value is a list of strings representing the variable selector, where the first element is the node ID
-        of the referenced variable, and the second element is the variable name within that node.
-
-        The meaning of the above response is:
-
-        The node with ID `1747829548239` references the variable `result` from the node with
-        ID `1747829667553`. For example, if `1747829548239` is an LLM node, its prompt may contain a
-        reference to the `result` output variable of node `1747829667553`.
-
-        :param graph_config: graph config
-        :param config: node config
-        :return:
-        """
-        node_id = config["id"]
-        node_data = cls.validate_node_data(config["data"])
-        data = cls._extract_variable_selector_to_variable_mapping(
-            graph_config=graph_config,
-            node_id=node_id,
-            node_data=node_data,
-        )
-        return data
-
-    @classmethod
-    def _extract_variable_selector_to_variable_mapping(
-        cls,
-        *,
-        graph_config: Mapping[str, Any],
-        node_id: str,
-        node_data: NodeDataT,
-    ) -> Mapping[str, Sequence[str]]:
-        return {}
-
-    def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
-        """
-        Check if this node blocks the output of specific variables.
-
-        This method is used to determine if a node must complete execution before
-        the specified variables can be used in streaming output.
- - :param variable_selectors: Set of variable selectors, each as a tuple (e.g., ('conversation', 'str')) - :return: True if this node blocks output of any of the specified variables, False otherwise - """ - return False - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return {} - - @classmethod - @abstractmethod - def version(cls) -> str: - """`node_version` returns the version of current node type.""" - # NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so - # registry lookups can resolve numeric versions and `latest`. - raise NotImplementedError("subclasses of BaseNode must implement `version` method.") - - @classmethod - def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]: - """Return a read-only view of the currently registered node classes. - - This accessor intentionally performs no imports. The embedding layer that - owns bootstrap (for example `core.workflow.node_factory`) must import any - extension node packages before calling it so their subclasses register via - `__init_subclass__`. - """ - return {node_type: MappingProxyType(version_map) for node_type, version_map in cls._registry.items()} - - @property - def retry(self) -> bool: - return False - - def _get_error_strategy(self) -> ErrorStrategy | None: - """Get the error strategy for this node.""" - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - """Get the retry configuration for this node.""" - return self._node_data.retry_config - - def _get_title(self) -> str: - """Get the node title.""" - return self._node_data.title - - def _get_description(self) -> str | None: - """Get the node description.""" - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - """Get the default values dictionary for this node.""" - return self._node_data.default_value_dict - - # Public interface properties that delegate to abstract methods - @property - def error_strategy(self) -> ErrorStrategy | None: - """Get the error strategy for this node.""" - return self._get_error_strategy() - - @property - def retry_config(self) -> RetryConfig: - """Get the retry configuration for this node.""" - return self._get_retry_config() - - @property - def title(self) -> str: - """Get the node title.""" - return self._get_title() - - @property - def description(self) -> str | None: - """Get the node description.""" - return self._get_description() - - @property - def default_value_dict(self) -> dict[str, Any]: - """Get the default values dictionary for this node.""" - return self._get_default_value_dict() - - @property - def node_data(self) -> NodeDataT: - """Typed access to this node's configuration data.""" - return self._node_data - - def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase: - finished_at = datetime.now(UTC).replace(tzinfo=None) - match result.status: - case WorkflowNodeExecutionStatus.FAILED: - return NodeRunFailedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - error=result.error, - ) - case WorkflowNodeExecutionStatus.SUCCEEDED: - return NodeRunSucceededEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - ) - case _: - raise Exception(f"result status {result.status} not supported") - - 
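The `_dispatch` machinery that follows is built on `functools.singledispatchmethod`, which routes on the runtime type of the first non-self argument. A minimal sketch of the pattern with toy event types (not the graphon ones):

from functools import singledispatchmethod

class Ping: pass
class Pong: pass

class Dispatcher:
    @singledispatchmethod
    def handle(self, event: object) -> str:
        # Fallback when no overload is registered for type(event).
        raise NotImplementedError(f"unsupported event: {type(event).__name__}")

    @handle.register
    def _(self, event: Ping) -> str:  # registered via the annotation
        return "converted Ping"

    @handle.register
    def _(self, event: Pong) -> str:
        return "converted Pong"

d = Dispatcher()
assert d.handle(Ping()) == "converted Ping"
assert d.handle(Pong()) == "converted Pong"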
@singledispatchmethod - def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase: - raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}") - - @_dispatch.register - def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: - return NodeRunStreamChunkEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - selector=event.selector, - chunk=event.chunk, - is_final=event.is_final, - ) - - @_dispatch.register - def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: - finished_at = datetime.now(UTC).replace(tzinfo=None) - match event.node_run_result.status: - case WorkflowNodeExecutionStatus.SUCCEEDED: - return NodeRunSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=event.node_run_result, - ) - case WorkflowNodeExecutionStatus.FAILED: - return NodeRunFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=event.node_run_result, - error=event.node_run_result.error, - ) - case _: - raise NotImplementedError( - f"Node {self._node_id} does not support status {event.node_run_result.status}" - ) - - @_dispatch.register - def _(self, event: VariableUpdatedEvent) -> NodeRunVariableUpdatedEvent: - return NodeRunVariableUpdatedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - variable=event.variable, - ) - - @_dispatch.register - def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent: - return NodeRunPauseRequestedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED), - reason=event.reason, - ) - - @_dispatch.register - def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent: - return NodeRunAgentLogEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - message_id=event.message_id, - label=event.label, - node_execution_id=event.node_execution_id, - parent_id=event.parent_id, - error=event.error, - status=event.status, - data=event.data, - metadata=event.metadata, - ) - - @_dispatch.register - def _(self, event: HumanInputFormFilledEvent): - return NodeRunHumanInputFormFilledEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - rendered_content=event.rendered_content, - action_id=event.action_id, - action_text=event.action_text, - ) - - @_dispatch.register - def _(self, event: HumanInputFormTimeoutEvent): - return NodeRunHumanInputFormTimeoutEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - expiration_time=event.expiration_time, - ) - - @_dispatch.register - def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: - return NodeRunLoopStartedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - metadata=event.metadata, - predecessor_node_id=event.predecessor_node_id, - ) - - @_dispatch.register - def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent: - return NodeRunLoopNextEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - index=event.index, - 
pre_loop_output=event.pre_loop_output, - ) - - @_dispatch.register - def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent: - return NodeRunLoopSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - ) - - @_dispatch.register - def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent: - return NodeRunLoopFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - error=event.error, - ) - - @_dispatch.register - def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent: - return NodeRunIterationStartedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - metadata=event.metadata, - predecessor_node_id=event.predecessor_node_id, - ) - - @_dispatch.register - def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent: - return NodeRunIterationNextEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - index=event.index, - pre_iteration_output=event.pre_iteration_output, - ) - - @_dispatch.register - def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent: - return NodeRunIterationSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - ) - - @_dispatch.register - def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent: - return NodeRunIterationFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - error=event.error, - ) - - @_dispatch.register - def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: - return NodeRunRetrieverResourceEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - retriever_resources=event.retriever_resources, - context=event.context, - node_version=self.version(), - ) diff --git a/api/graphon/nodes/base/template.py b/api/graphon/nodes/base/template.py deleted file mode 100644 index 311de4a6ea6..00000000000 --- a/api/graphon/nodes/base/template.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Template structures for Response nodes (Answer and End). - -This module provides a unified template structure for both Answer and End nodes, -similar to SegmentGroup but focused on template representation without values. 
-""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Any, Union - -from graphon.nodes.base.variable_template_parser import VariableTemplateParser - - -@dataclass(frozen=True) -class TemplateSegment(ABC): - """Base class for template segments.""" - - @abstractmethod - def __str__(self) -> str: - """String representation of the segment.""" - pass - - -@dataclass(frozen=True) -class TextSegment(TemplateSegment): - """A text segment in a template.""" - - text: str - - def __str__(self) -> str: - return self.text - - -@dataclass(frozen=True) -class VariableSegment(TemplateSegment): - """A variable reference segment in a template.""" - - selector: Sequence[str] - variable_name: str | None = None # Optional variable name for End nodes - - def __str__(self) -> str: - return "{{#" + ".".join(self.selector) + "#}}" - - -# Type alias for segments -TemplateSegmentUnion = Union[TextSegment, VariableSegment] - - -@dataclass(frozen=True) -class Template: - """Unified template structure for Response nodes. - - Similar to SegmentGroup, but represents the template structure - without variable values - only marking variable selectors. - """ - - segments: list[TemplateSegmentUnion] - - @classmethod - def from_answer_template(cls, template_str: str) -> Template: - """Create a Template from an Answer node template string. - - Example: - "Hello, {{#node1.name#}}" -> [TextSegment("Hello, "), VariableSegment(["node1", "name"])] - - Args: - template_str: The answer template string - - Returns: - Template instance - """ - parser = VariableTemplateParser(template_str) - segments: list[TemplateSegmentUnion] = [] - - # Extract variable selectors to find all variables - variable_selectors = parser.extract_variable_selectors() - var_map = {var.variable: var.value_selector for var in variable_selectors} - - # Parse template to get ordered segments - # We need to split the template by variable placeholders while preserving order - import re - - # Create a regex pattern that matches variable placeholders - pattern = r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}" - - # Split template while keeping the delimiters (variable placeholders) - parts = re.split(pattern, template_str) - - for i, part in enumerate(parts): - if not part: - continue - - # Check if this part is a variable reference (odd indices after split) - if i % 2 == 1: # Odd indices are variable keys - # Remove the # symbols from the variable key - var_key = part - if var_key in var_map: - segments.append(VariableSegment(selector=list(var_map[var_key]))) - else: - # This shouldn't happen with valid templates - segments.append(TextSegment(text="{{" + part + "}}")) - else: - # Even indices are text segments - segments.append(TextSegment(text=part)) - - return cls(segments=segments) - - @classmethod - def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template: - """Create a Template from an End node outputs configuration. - - End nodes are treated as templates of concatenated variables with newlines. 
-
-        Example:
-            [{"variable": "text", "value_selector": ["node1", "text"]},
-             {"variable": "result", "value_selector": ["node2", "result"]}]
-            ->
-            [VariableSegment(["node1", "text"]),
-             TextSegment("\n"),
-             VariableSegment(["node2", "result"])]
-
-        Args:
-            outputs_config: List of output configurations with variable and value_selector
-
-        Returns:
-            Template instance
-        """
-        segments: list[TemplateSegmentUnion] = []
-
-        for i, output in enumerate(outputs_config):
-            if i > 0:
-                # Add newline separator between variables
-                segments.append(TextSegment(text="\n"))
-
-            value_selector = output.get("value_selector", [])
-            variable_name = output.get("variable", "")
-            if value_selector:
-                segments.append(VariableSegment(selector=list(value_selector), variable_name=variable_name))
-
-        if len(segments) > 0 and isinstance(segments[-1], TextSegment):
-            segments = segments[:-1]
-
-        return cls(segments=segments)
-
-    def __str__(self) -> str:
-        """String representation of the template."""
-        return "".join(str(segment) for segment in self.segments)
diff --git a/api/graphon/nodes/base/usage_tracking_mixin.py b/api/graphon/nodes/base/usage_tracking_mixin.py
deleted file mode 100644
index 955bfe67267..00000000000
--- a/api/graphon/nodes/base/usage_tracking_mixin.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.runtime import GraphRuntimeState
-
-
-class LLMUsageTrackingMixin:
-    """Provides shared helpers for merging and recording LLM usage within workflow nodes."""
-
-    graph_runtime_state: GraphRuntimeState
-
-    @staticmethod
-    def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage:
-        """Return a combined usage snapshot, preserving zero-value inputs."""
-        if new_usage is None or new_usage.total_tokens <= 0:
-            return current
-        if current.total_tokens == 0:
-            return new_usage
-        return current.plus(new_usage)
-
-    def _accumulate_usage(self, usage: LLMUsage) -> None:
-        """Push usage into the graph runtime accumulator for downstream reporting."""
-        if usage.total_tokens <= 0:
-            return
-
-        current_usage = self.graph_runtime_state.llm_usage
-        if current_usage.total_tokens == 0:
-            self.graph_runtime_state.llm_usage = usage.model_copy()
-        else:
-            self.graph_runtime_state.llm_usage = current_usage.plus(usage)
diff --git a/api/graphon/nodes/base/variable_template_parser.py b/api/graphon/nodes/base/variable_template_parser.py
deleted file mode 100644
index de5e619e8c4..00000000000
--- a/api/graphon/nodes/base/variable_template_parser.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import re
-from collections.abc import Mapping, Sequence
-from typing import Any
-
-from .entities import VariableSelector
-
-REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
-
-SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
-
-
-def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]:
-    parts = SELECTOR_PATTERN.split(template)
-    selectors = []
-    for part in filter(lambda x: x, parts):
-        if "." in part and part[0] == "#" and part[-1] == "#":
-            selectors.append(VariableSelector(variable=f"{part}", value_selector=part[1:-1].split(".")))
-    return selectors
-
-
-class VariableTemplateParser:
-    """
-    !NOTE: Consider using the new `segments` module instead of this class.
-
-    A class for parsing and manipulating template variables in a string.
-
-    Rules:
-
-    1. Template variables must be enclosed in `{{}}`.
-    2. The template variable Key can only be: #node_id.var1.var2#.
-    3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
-
-    Example usage:
-
-    template = "Hello, {{#node_id.query.name#}}! Your age is {{#node_id.query.age#}}."
-    parser = VariableTemplateParser(template)
-
-    # Extract template variable keys
-    variable_keys = parser.extract()
-    print(variable_keys)
-    # Output: ['#node_id.query.name#', '#node_id.query.age#']
-
-    # Extract variable selectors
-    variable_selectors = parser.extract_variable_selectors()
-    print(variable_selectors)
-    # Output: [VariableSelector(variable='#node_id.query.name#', value_selector=['node_id', 'query', 'name']),
-    # VariableSelector(variable='#node_id.query.age#', value_selector=['node_id', 'query', 'age'])]
-
-    # Format the template string
-    inputs = {'#node_id.query.name#': 'John', '#node_id.query.age#': 25}
-    formatted_string = parser.format(inputs)
-    print(formatted_string)
-    # Output: "Hello, John! Your age is 25."
-    """
-
-    def __init__(self, template: str):
-        self.template = template
-        self.variable_keys = self.extract()
-
-    def extract(self):
-        """
-        Extracts all the template variable keys from the template string.
-
-        Returns:
-            A list of template variable keys.
-        """
-        # Regular expression to match the template rules
-        matches = re.findall(REGEX, self.template)
-
-        first_group_matches = [match[0] for match in matches]
-
-        return list(set(first_group_matches))
-
-    def extract_variable_selectors(self) -> list[VariableSelector]:
-        """
-        Extracts the variable selectors from the template variable keys.
-
-        Returns:
-            A list of VariableSelector objects representing the variable selectors.
-        """
-        variable_selectors = []
-        for variable_key in self.variable_keys:
-            remove_hash = variable_key.replace("#", "")
-            split_result = remove_hash.split(".")
-            if len(split_result) < 2:
-                continue
-
-            variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result))
-
-        return variable_selectors
-
-    def format(self, inputs: Mapping[str, Any]) -> str:
-        """
-        Formats the template string by replacing the template variables with their corresponding values.
-
-        Args:
-            inputs: A dictionary containing the values for the template variables.
-
-        Returns:
-            The formatted string with template variables replaced by their values.
-        """
-
-        def replacer(match):
-            key = match.group(1)
-            value = inputs.get(key, match.group(0))  # return original matched string if key not found
-
-            if value is None:
-                value = ""
-            # convert the value to string
-            if isinstance(value, list | dict | bool | int | float):
-                value = str(value)
-
-            # remove template variables if required
-            return VariableTemplateParser.remove_template_variables(value)
-
-        prompt = re.sub(REGEX, replacer, self.template)
-        return re.sub(r"<\|.*?\|>", "", prompt)
-
-    @classmethod
-    def remove_template_variables(cls, text: str):
-        """
-        Removes the template variables from the given text.
-
-        Args:
-            text: The text from which to remove the template variables.
-
-        Returns:
-            The text with template variables removed.
- """ - return re.sub(REGEX, r"{\1}", text) diff --git a/api/graphon/nodes/code/__init__.py b/api/graphon/nodes/code/__init__.py deleted file mode 100644 index 8c6dcc7fccb..00000000000 --- a/api/graphon/nodes/code/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .code_node import CodeNode - -__all__ = ["CodeNode"] diff --git a/api/graphon/nodes/code/code_node.py b/api/graphon/nodes/code/code_node.py deleted file mode 100644 index c2eea0bec1c..00000000000 --- a/api/graphon/nodes/code/code_node.py +++ /dev/null @@ -1,493 +0,0 @@ -from collections.abc import Mapping, Sequence -from decimal import Decimal -from textwrap import dedent -from typing import TYPE_CHECKING, Any, Protocol, cast - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.code.entities import CodeLanguage, CodeNodeData -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.variables.segments import ArrayFileSegment -from graphon.variables.types import SegmentType - -from .exc import ( - CodeNodeError, - DepthLimitError, - OutputValidationError, -) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -class WorkflowCodeExecutor(Protocol): - def execute( - self, - *, - language: CodeLanguage, - code: str, - inputs: Mapping[str, Any], - ) -> Mapping[str, Any]: ... - - def is_execution_error(self, error: Exception) -> bool: ... - - -def _build_default_config(*, language: CodeLanguage, code: str) -> Mapping[str, object]: - return { - "type": "code", - "config": { - "variables": [ - {"variable": "arg1", "value_selector": []}, - {"variable": "arg2", "value_selector": []}, - ], - "code_language": language, - "code": code, - "outputs": {"result": {"type": "string", "children": None}}, - }, - } - - -_DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = { - CodeLanguage.PYTHON3: dedent( - """ - def main(arg1: str, arg2: str): - return { - "result": arg1 + arg2, - } - """ - ), - CodeLanguage.JAVASCRIPT: dedent( - """ - function main({arg1, arg2}) { - return { - result: arg1 + arg2 - } - } - """ - ), -} - - -class CodeNode(Node[CodeNodeData]): - node_type = BuiltinNodeTypes.CODE - _limits: CodeNodeLimits - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - code_executor: WorkflowCodeExecutor, - code_limits: CodeNodeLimits, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._code_executor: WorkflowCodeExecutor = code_executor - self._limits = code_limits - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters. 
- :return: - """ - code_language = CodeLanguage.PYTHON3 - if filters: - code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3)) - - default_code = _DEFAULT_CODE_BY_LANGUAGE.get(code_language) - if default_code is None: - raise CodeNodeError(f"Unsupported code language: {code_language}") - return _build_default_config(language=code_language, code=default_code) - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get code language - code_language = self.node_data.code_language - code = self.node_data.code - - # Get variables - variables = {} - for variable_selector in self.node_data.variables: - variable_name = variable_selector.variable - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if isinstance(variable, ArrayFileSegment): - variables[variable_name] = [v.to_dict() for v in variable.value] if variable.value else None - else: - variables[variable_name] = variable.to_object() if variable else None - # Run code - try: - result = self._code_executor.execute( - language=code_language, - code=code, - inputs=variables, - ) - - # Transform result - result = self._transform_result(result=result, output_schema=self.node_data.outputs) - except CodeNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ - ) - except Exception as e: - if not self._code_executor.is_execution_error(e): - raise - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ - ) - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) - - def _check_string(self, value: str | None, variable: str) -> str | None: - """ - Check string - :param value: value - :param variable: variable - :return: - """ - if value is None: - return None - - if len(value) > self._limits.max_string_length: - raise OutputValidationError( - f"The length of output variable `{variable}` must be" - f" less than {self._limits.max_string_length} characters" - ) - - return value.replace("\x00", "") - - def _check_boolean(self, value: bool | None, variable: str) -> bool | None: - if value is None: - return None - - return value - - def _check_number(self, value: int | float | None, variable: str) -> int | float | None: - """ - Check number - :param value: value - :param variable: variable - :return: - """ - if value is None: - return None - - if value > self._limits.max_number or value < self._limits.min_number: - raise OutputValidationError( - f"Output variable `{variable}` is out of range," - f" it must be between {self._limits.min_number} and {self._limits.max_number}." - ) - - if isinstance(value, float): - decimal_value = Decimal(str(value)).normalize() - precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator] - # raise error if precision is too high - if precision > self._limits.max_precision: - raise OutputValidationError( - f"Output variable `{variable}` has too high precision," - f" it must be less than {self._limits.max_precision} digits." - ) - - return value - - def _transform_result( - self, - result: Mapping[str, Any], - output_schema: dict[str, CodeNodeData.Output] | None, - prefix: str = "", - depth: int = 1, - ): - # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. 
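The precision guard in `_check_number` above normalizes through `Decimal` and reads the exponent off the tuple representation. A standalone sketch of that computation:

from decimal import Decimal

def float_precision(value: float) -> int:
    # 3.14 -> Decimal("3.14"), exponent -2 -> 2 decimal places of precision
    exponent = Decimal(str(value)).normalize().as_tuple().exponent
    return -exponent if isinstance(exponent, int) and exponent < 0 else 0

assert float_precision(3.14) == 2
assert float_precision(100.0) == 0  # normalizes to Decimal("1E+2")
assert float_precision(0.001) == 3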
-        # Note that `_transform_result` may produce lists containing `None` values,
-        # which don't conform to the type requirements of `Array*Segment` classes.
-        if depth > self._limits.max_depth:
-            raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.")
-
-        transformed_result: dict[str, Any] = {}
-        if output_schema is None:
-            # validate output through instance type checks
-            for output_name, output_value in result.items():
-                if isinstance(output_value, dict):
-                    self._transform_result(
-                        result=output_value,
-                        output_schema=None,
-                        prefix=f"{prefix}.{output_name}" if prefix else output_name,
-                        depth=depth + 1,
-                    )
-                elif isinstance(output_value, bool):
-                    self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name)
-                elif isinstance(output_value, int | float):
-                    self._check_number(
-                        value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
-                    )
-                elif isinstance(output_value, str):
-                    self._check_string(
-                        value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
-                    )
-                elif isinstance(output_value, list):
-                    first_element = output_value[0] if len(output_value) > 0 else None
-                    if first_element is not None:
-                        if isinstance(first_element, int | float) and all(
-                            value is None or isinstance(value, int | float) for value in output_value
-                        ):
-                            for i, value in enumerate(output_value):
-                                self._check_number(
-                                    value=value,
-                                    variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
-                                )
-                        elif isinstance(first_element, str) and all(
-                            value is None or isinstance(value, str) for value in output_value
-                        ):
-                            for i, value in enumerate(output_value):
-                                self._check_string(
-                                    value=value,
-                                    variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
-                                )
-                        elif (
-                            isinstance(first_element, dict)
-                            and all(value is None or isinstance(value, dict) for value in output_value)
-                            or isinstance(first_element, list)
-                            and all(value is None or isinstance(value, list) for value in output_value)
-                        ):
-                            for i, value in enumerate(output_value):
-                                if value is not None:
-                                    self._transform_result(
-                                        result=value,
-                                        output_schema=None,
-                                        prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
-                                        depth=depth + 1,
-                                    )
-                        else:
-                            raise OutputValidationError(
-                                f"Output {prefix}.{output_name} is not a valid array."
-                                f" Make sure all elements are of the same type."
-                            )
-                elif output_value is None:
-                    pass
-                else:
-                    raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.")
-
-            return result
-
-        parameters_validated = {}
-        for output_name, output_config in output_schema.items():
-            dot = "." if prefix else ""
-            if output_name not in result:
-                raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
-
-            if output_config.type == SegmentType.OBJECT:
-                # check if output is object
-                if not isinstance(result.get(output_name), dict):
-                    if result[output_name] is None:
-                        transformed_result[output_name] = None
-                    else:
-                        raise OutputValidationError(
-                            f"Output {prefix}{dot}{output_name} is not an object,"
-                            f" got {type(result.get(output_name))} instead." 
- ) - else: - transformed_result[output_name] = self._transform_result( - result=result[output_name], - output_schema=output_config.children, - prefix=f"{prefix}.{output_name}", - depth=depth + 1, - ) - elif output_config.type == SegmentType.NUMBER: - # check if number available - value = result.get(output_name) - if value is not None and not isinstance(value, (int, float)): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not a number," - f" got {type(result.get(output_name))} instead." - ) - checked = self._check_number(value=value, variable=f"{prefix}{dot}{output_name}") - # If the output is a boolean and the output schema specifies a NUMBER type, - # convert the boolean value to an integer. - # - # This ensures compatibility with existing workflows that may use - # `True` and `False` as values for NUMBER type outputs. - transformed_result[output_name] = self._convert_boolean_to_int(checked) - - elif output_config.type == SegmentType.STRING: - # check if string available - value = result.get(output_name) - if value is not None and not isinstance(value, str): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} must be a string, got {type(value).__name__} instead" - ) - transformed_result[output_name] = self._check_string( - value=value, - variable=f"{prefix}{dot}{output_name}", - ) - elif output_config.type == SegmentType.BOOLEAN: - transformed_result[output_name] = self._check_boolean( - value=result[output_name], - variable=f"{prefix}{dot}{output_name}", - ) - elif output_config.type == SegmentType.ARRAY_NUMBER: - # check if array of number available - value = result[output_name] - if not isinstance(value, list): - if value is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead." - ) - else: - if len(value) > self._limits.max_number_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_number_array_length} elements." - ) - - for i, inner_value in enumerate(value): - if not isinstance(inner_value, (int, float)): - raise OutputValidationError( - f"The element at index {i} of output variable `{prefix}{dot}{output_name}` must be" - f" a number." - ) - _ = self._check_number(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]") - transformed_result[output_name] = [ - # If the element is a boolean and the output schema specifies a `array[number]` type, - # convert the boolean value to an integer. - # - # This ensures compatibility with existing workflows that may use - # `True` and `False` as values for NUMBER type outputs. - self._convert_boolean_to_int(v) - for v in value - ] - elif output_config.type == SegmentType.ARRAY_STRING: - # check if array of string available - if not isinstance(result[output_name], list): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - if len(result[output_name]) > self._limits.max_string_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_string_array_length} elements." 
- ) - - transformed_result[output_name] = [ - self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") - for i, value in enumerate(result[output_name]) - ] - elif output_config.type == SegmentType.ARRAY_OBJECT: - # check if array of object available - if not isinstance(result[output_name], list): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - if len(result[output_name]) > self._limits.max_object_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_object_array_length} elements." - ) - - for i, value in enumerate(result[output_name]): - if not isinstance(value, dict): - if value is None: - pass - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name}[{i}] is not an object," - f" got {type(value)} instead at index {i}." - ) - - transformed_result[output_name] = [ - None - if value is None - else self._transform_result( - result=value, - output_schema=output_config.children, - prefix=f"{prefix}{dot}{output_name}[{i}]", - depth=depth + 1, - ) - for i, value in enumerate(result[output_name]) - ] - elif output_config.type == SegmentType.ARRAY_BOOLEAN: - # check if array of object available - value = result[output_name] - if not isinstance(value, list): - if value is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - for i, inner_value in enumerate(value): - if inner_value is not None and not isinstance(inner_value, bool): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name}[{i}] is not a boolean," - f" got {type(inner_value)} instead." - ) - _ = self._check_boolean(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]") - transformed_result[output_name] = value - - else: - raise OutputValidationError(f"Output type {output_config.type} is not supported.") - - parameters_validated[output_name] = True - - # check if all output parameters are validated - if len(parameters_validated) != len(result): - raise CodeNodeError("Not all output parameters are validated.") - - return transformed_result - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: CodeNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - return { - node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables - } - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled - - @staticmethod - def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None: - """This function convert boolean to integers when the output schema specifies a NUMBER type. - - This ensures compatibility with existing workflows that may use - `True` and `False` as values for NUMBER type outputs. 
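-
-        For example, a code node returning {"count": True} under a NUMBER output schema is normalized to {"count": 1}.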
- """ - if value is None: - return None - if isinstance(value, bool): - return int(value) - return value diff --git a/api/graphon/nodes/code/entities.py b/api/graphon/nodes/code/entities.py deleted file mode 100644 index dc89d64495e..00000000000 --- a/api/graphon/nodes/code/entities.py +++ /dev/null @@ -1,57 +0,0 @@ -from enum import StrEnum -from typing import Annotated, Literal - -from pydantic import AfterValidator, BaseModel - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base.entities import VariableSelector -from graphon.variables.types import SegmentType - - -class CodeLanguage(StrEnum): - PYTHON3 = "python3" - JINJA2 = "jinja2" - JAVASCRIPT = "javascript" - - -_ALLOWED_OUTPUT_FROM_CODE = frozenset( - [ - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.OBJECT, - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - ] -) - - -def _validate_type(segment_type: SegmentType) -> SegmentType: - if segment_type not in _ALLOWED_OUTPUT_FROM_CODE: - raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}") - return segment_type - - -class CodeNodeData(BaseNodeData): - """ - Code Node Data. - """ - - type: NodeType = BuiltinNodeTypes.CODE - - class Output(BaseModel): - type: Annotated[SegmentType, AfterValidator(_validate_type)] - children: dict[str, "CodeNodeData.Output"] | None = None - - class Dependency(BaseModel): - name: str - version: str - - variables: list[VariableSelector] - code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] - code: str - outputs: dict[str, Output] - dependencies: list[Dependency] | None = None diff --git a/api/graphon/nodes/code/exc.py b/api/graphon/nodes/code/exc.py deleted file mode 100644 index d6334fd554c..00000000000 --- a/api/graphon/nodes/code/exc.py +++ /dev/null @@ -1,16 +0,0 @@ -class CodeNodeError(ValueError): - """Base class for code node errors.""" - - pass - - -class OutputValidationError(CodeNodeError): - """Raised when there is an output validation error.""" - - pass - - -class DepthLimitError(CodeNodeError): - """Raised when the depth limit is reached.""" - - pass diff --git a/api/graphon/nodes/code/limits.py b/api/graphon/nodes/code/limits.py deleted file mode 100644 index a6b9e9e68ee..00000000000 --- a/api/graphon/nodes/code/limits.py +++ /dev/null @@ -1,13 +0,0 @@ -from dataclasses import dataclass - - -@dataclass(frozen=True) -class CodeNodeLimits: - max_string_length: int - max_number: int | float - min_number: int | float - max_precision: int - max_depth: int - max_number_array_length: int - max_string_array_length: int - max_object_array_length: int diff --git a/api/graphon/nodes/document_extractor/__init__.py b/api/graphon/nodes/document_extractor/__init__.py deleted file mode 100644 index 9922e3949da..00000000000 --- a/api/graphon/nodes/document_extractor/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .entities import DocumentExtractorNodeData, UnstructuredApiConfig -from .node import DocumentExtractorNode - -__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData", "UnstructuredApiConfig"] diff --git a/api/graphon/nodes/document_extractor/entities.py b/api/graphon/nodes/document_extractor/entities.py deleted file mode 100644 index 026a0cd2248..00000000000 --- a/api/graphon/nodes/document_extractor/entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from collections.abc import Sequence -from 
dataclasses import dataclass - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class DocumentExtractorNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.DOCUMENT_EXTRACTOR - variable_selector: Sequence[str] - - -@dataclass(frozen=True) -class UnstructuredApiConfig: - api_url: str | None = None - api_key: str = "" diff --git a/api/graphon/nodes/document_extractor/exc.py b/api/graphon/nodes/document_extractor/exc.py deleted file mode 100644 index 5caf00ebc5f..00000000000 --- a/api/graphon/nodes/document_extractor/exc.py +++ /dev/null @@ -1,14 +0,0 @@ -class DocumentExtractorError(ValueError): - """Base exception for errors related to the DocumentExtractorNode.""" - - -class FileDownloadError(DocumentExtractorError): - """Exception raised when there's an error downloading a file.""" - - -class UnsupportedFileTypeError(DocumentExtractorError): - """Exception raised when trying to extract text from an unsupported file type.""" - - -class TextExtractionError(DocumentExtractorError): - """Exception raised when there's an error during text extraction from a file.""" diff --git a/api/graphon/nodes/document_extractor/node.py b/api/graphon/nodes/document_extractor/node.py deleted file mode 100644 index be46481e7dd..00000000000 --- a/api/graphon/nodes/document_extractor/node.py +++ /dev/null @@ -1,782 +0,0 @@ -import csv -import io -import json -import logging -import os -import tempfile -import zipfile -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -import charset_normalizer -import docx -import pandas as pd -import pypandoc -import pypdfium2 -import webvtt -import yaml -from docx.document import Document -from docx.oxml.table import CT_Tbl -from docx.oxml.text.paragraph import CT_P -from docx.table import Table -from docx.text.paragraph import Paragraph - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, file_manager -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.protocols import HttpClientProtocol -from graphon.variables import ArrayFileSegment -from graphon.variables.segments import ArrayStringSegment, FileSegment - -from .entities import DocumentExtractorNodeData, UnstructuredApiConfig -from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -class DocumentExtractorNode(Node[DocumentExtractorNodeData]): - """ - Extracts text content from various file types. - Supports plain text, PDF, and DOC/DOCX files. 
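-    Also covers CSV, Excel, PPT/PPTX, EPUB, EML/MSG, JSON, YAML, VTT and
-    .properties files; see _extract_text_by_mime_type and
-    _extract_text_by_file_extension below for the full mapping.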
- """ - - node_type = BuiltinNodeTypes.DOCUMENT_EXTRACTOR - - @classmethod - def version(cls) -> str: - return "1" - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - unstructured_api_config: UnstructuredApiConfig | None = None, - http_client: HttpClientProtocol, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig() - self._http_client = http_client - - def _run(self): - variable_selector = self.node_data.variable_selector - variable = self.graph_runtime_state.variable_pool.get(variable_selector) - - if variable is None: - error_message = f"File variable not found for selector: {variable_selector}" - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) - if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment): - error_message = f"Variable {variable_selector} is not an ArrayFileSegment" - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) - - value = variable.value - inputs = {"variable_selector": variable_selector} - if isinstance(value, list): - value = list(filter(lambda x: x, value)) - process_data = {"documents": value if isinstance(value, list) else [value]} - - if not value: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": ArrayStringSegment(value=[])}, - ) - - try: - if isinstance(value, list): - extracted_text_list = [ - _extract_text_from_file( - self._http_client, file, unstructured_api_config=self._unstructured_api_config - ) - for file in value - ] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": ArrayStringSegment(value=extracted_text_list)}, - ) - elif isinstance(value, File): - extracted_text = _extract_text_from_file( - self._http_client, value, unstructured_api_config=self._unstructured_api_config - ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": extracted_text}, - ) - else: - raise DocumentExtractorError(f"Unsupported variable type: {type(value)}") - except DocumentExtractorError as e: - logger.warning(e, exc_info=True) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=inputs, - process_data=process_data, - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: DocumentExtractorNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - return {node_id + ".files": node_data.variable_selector} - - -def _extract_text_by_mime_type( - *, - file_content: bytes, - mime_type: str, - unstructured_api_config: UnstructuredApiConfig, -) -> str: - """Extract text from a file based on its MIME type.""" - match mime_type: - case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml": - return _extract_text_from_plain_text(file_content) - case "application/pdf": - return _extract_text_from_pdf(file_content) - case "application/msword": - return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) - case 
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": - return _extract_text_from_docx(file_content) - case "text/csv": - return _extract_text_from_csv(file_content) - case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel": - return _extract_text_from_excel(file_content) - case "application/vnd.ms-powerpoint": - return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config) - case "application/vnd.openxmlformats-officedocument.presentationml.presentation": - return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) - case "application/epub+zip": - return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) - case "message/rfc822": - return _extract_text_from_eml(file_content) - case "application/vnd.ms-outlook": - return _extract_text_from_msg(file_content) - case "application/json": - return _extract_text_from_json(file_content) - case "application/x-yaml" | "text/yaml": - return _extract_text_from_yaml(file_content) - case "text/vtt": - return _extract_text_from_vtt(file_content) - case "text/properties": - return _extract_text_from_properties(file_content) - case _: - raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") - - -def _extract_text_by_file_extension( - *, - file_content: bytes, - file_extension: str, - unstructured_api_config: UnstructuredApiConfig, -) -> str: - """Extract text from a file based on its file extension.""" - match file_extension: - case ( - ".txt" - | ".markdown" - | ".md" - | ".mdx" - | ".html" - | ".htm" - | ".xml" - | ".c" - | ".h" - | ".cpp" - | ".hpp" - | ".cc" - | ".cxx" - | ".c++" - | ".py" - | ".js" - | ".ts" - | ".jsx" - | ".tsx" - | ".java" - | ".php" - | ".rb" - | ".go" - | ".rs" - | ".swift" - | ".kt" - | ".scala" - | ".sh" - | ".bash" - | ".bat" - | ".ps1" - | ".sql" - | ".r" - | ".m" - | ".pl" - | ".lua" - | ".vim" - | ".asm" - | ".s" - | ".css" - | ".scss" - | ".less" - | ".sass" - | ".ini" - | ".cfg" - | ".conf" - | ".toml" - | ".env" - | ".log" - | ".vtt" - ): - return _extract_text_from_plain_text(file_content) - case ".json": - return _extract_text_from_json(file_content) - case ".yaml" | ".yml": - return _extract_text_from_yaml(file_content) - case ".pdf": - return _extract_text_from_pdf(file_content) - case ".doc": - return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) - case ".docx": - return _extract_text_from_docx(file_content) - case ".csv": - return _extract_text_from_csv(file_content) - case ".xls" | ".xlsx": - return _extract_text_from_excel(file_content) - case ".ppt": - return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config) - case ".pptx": - return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) - case ".epub": - return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) - case ".eml": - return _extract_text_from_eml(file_content) - case ".msg": - return _extract_text_from_msg(file_content) - case ".properties": - return _extract_text_from_properties(file_content) - case _: - raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}") - - -def _extract_text_from_plain_text(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content, cp_isolation=["utf_8", "latin_1", "cp1252"]).best() - if result: - encoding 
= result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - return file_content.decode(encoding, errors="ignore") - except (UnicodeDecodeError, LookupError) as e: - # If decoding fails, try with utf-8 as last resort - try: - return file_content.decode("utf-8", errors="ignore") - except UnicodeDecodeError: - raise TextExtractionError(f"Failed to decode plain text file: {e}") from e - - -def _extract_text_from_json(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - json_data = json.loads(file_content.decode(encoding, errors="ignore")) - return json.dumps(json_data, indent=2, ensure_ascii=False) - except (UnicodeDecodeError, LookupError, json.JSONDecodeError) as e: - # If decoding fails, try with utf-8 as last resort - try: - json_data = json.loads(file_content.decode("utf-8", errors="ignore")) - return json.dumps(json_data, indent=2, ensure_ascii=False) - except (UnicodeDecodeError, json.JSONDecodeError): - raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e - - -def _extract_text_from_yaml(file_content: bytes) -> str: - """Extract the content from yaml file""" - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore")) - return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) - except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e: - # If decoding fails, try with utf-8 as last resort - try: - yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore")) - return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) - except (UnicodeDecodeError, yaml.YAMLError): - raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e - - -def _extract_text_from_pdf(file_content: bytes) -> str: - try: - pdf_file = io.BytesIO(file_content) - pdf_document = pypdfium2.PdfDocument(pdf_file, autoclose=True) - text = "" - for page in pdf_document: - text_page = page.get_textpage() - text += text_page.get_text_range() - text_page.close() - page.close() - return text - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e - - -def _extract_text_from_doc(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - """ - Extract text from a DOC file. 
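-
-    Parsing is delegated to the Unstructured API; a configured api_url is
-    required, otherwise a TextExtractionError is raised.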
- """ - from unstructured.partition.api import partition_via_api - - if not unstructured_api_config.api_url: - raise TextExtractionError("Unstructured API URL is not configured for DOC file processing.") - api_key = unstructured_api_config.api_key or "" - - try: - with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - return "\n".join([getattr(element, "text", "") for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from DOC: {str(e)}") from e - - -def parser_docx_part(block, doc: Document, content_items, i): - if isinstance(block, CT_P): - content_items.append((i, "paragraph", Paragraph(block, doc))) - elif isinstance(block, CT_Tbl): - content_items.append((i, "table", Table(block, doc))) - - -def _normalize_docx_zip(file_content: bytes) -> bytes: - """ - Some DOCX files (e.g. exported by Evernote on Windows) are malformed: - ZIP entry names use backslash (\\) as path separator instead of the forward - slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry - "word\\document.xml" is never found when python-docx looks for - "word/document.xml", which triggers a KeyError about a missing relationship. - - This function rewrites the ZIP in-memory, normalizing all entry names to - use forward slashes without touching any actual document content. - """ - try: - with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin: - out_buf = io.BytesIO() - with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout: - for item in zin.infolist(): - data = zin.read(item.filename) - # Normalize backslash path separators to forward slash - item.filename = item.filename.replace("\\", "/") - zout.writestr(item, data) - return out_buf.getvalue() - except zipfile.BadZipFile: - # Not a valid zip โ€” return as-is and let python-docx report the real error - return file_content - - -def _extract_text_from_docx(file_content: bytes) -> str: - """ - Extract text from a DOCX file. - For now support only paragraph and table add more if needed - """ - try: - doc_file = io.BytesIO(file_content) - try: - doc = docx.Document(doc_file) - except Exception as e: - logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e) - # Some DOCX files exported by tools like Evernote on Windows use - # backslash path separators in ZIP entries and/or single-quoted XML - # attributes, both of which break python-docx on Linux. Normalize and retry. 
-            file_content = _normalize_docx_zip(file_content)
-            doc = docx.Document(io.BytesIO(file_content))
-        text = []
-
-        # Keep track of paragraph and table positions
-        content_items: list[tuple[int, str, Table | Paragraph]] = []
-
-        it = iter(doc.element.body)
-        part = next(it, None)
-        i = 0
-        while part is not None:
-            parser_docx_part(part, doc, content_items, i)
-            i = i + 1
-            part = next(it, None)
-
-        # Process sorted content
-        for _, item_type, item in content_items:
-            if item_type == "paragraph":
-                if isinstance(item, Table):
-                    continue
-                text.append(item.text)
-            elif item_type == "table":
-                # Process tables
-                if not isinstance(item, Table):
-                    continue
-                try:
-                    # Check if any cell in the table has text
-                    has_content = False
-                    for row in item.rows:
-                        if any(cell.text.strip() for cell in row.cells):
-                            has_content = True
-                            break
-
-                    if has_content:
-                        cell_texts = [cell.text.replace("\n", "<br/>") for cell in item.rows[0].cells]
-                        markdown_table = f"| {' | '.join(cell_texts)} |\n"
-                        markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n"
-
-                        for row in item.rows[1:]:
-                            # Replace newlines with <br/> in each cell
-                            row_cells = [cell.text.replace("\n", "<br/>") for cell in row.cells]
-                            markdown_table += "| " + " | ".join(row_cells) + " |\n"
-
-                        text.append(markdown_table)
-                except Exception as e:
-                    logger.warning("Failed to extract table from DOCX: %s", e)
-                    continue
-
-        return "\n".join(text)
-
-    except Exception as e:
-        raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e
-
-
-def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes:
-    """Download the content of a file based on its transfer method."""
-    try:
-        if file.transfer_method == FileTransferMethod.REMOTE_URL:
-            if file.remote_url is None:
-                raise FileDownloadError("Missing URL for remote file")
-            response = http_client.get(file.remote_url)
-            response.raise_for_status()
-            return response.content
-        else:
-            return file_manager.download(file)
-    except Exception as e:
-        raise FileDownloadError(f"Error downloading file: {str(e)}") from e
-
-
-def _extract_text_from_file(
-    http_client: HttpClientProtocol, file: File, *, unstructured_api_config: UnstructuredApiConfig
-) -> str:
-    file_content = _download_file_content(http_client, file)
-    if file.extension:
-        extracted_text = _extract_text_by_file_extension(
-            file_content=file_content,
-            file_extension=file.extension,
-            unstructured_api_config=unstructured_api_config,
-        )
-    elif file.mime_type:
-        extracted_text = _extract_text_by_mime_type(
-            file_content=file_content,
-            mime_type=file.mime_type,
-            unstructured_api_config=unstructured_api_config,
-        )
-    else:
-        raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing")
-    return extracted_text
-
-
-def _extract_text_from_csv(file_content: bytes) -> str:
-    try:
-        # Detect encoding using charset_normalizer
-        result = charset_normalizer.from_bytes(file_content).best()
-        if result:
-            encoding = result.encoding
-        else:
-            encoding = "utf-8"
-
-        # Fallback to utf-8 if detection fails
-        if not encoding:
-            encoding = "utf-8"
-
-        try:
-            csv_file = io.StringIO(file_content.decode(encoding, errors="ignore"))
-        except (UnicodeDecodeError, LookupError):
-            # If decoding fails, try with utf-8 as last resort
-            csv_file = io.StringIO(file_content.decode("utf-8", errors="ignore"))
-
-        csv_reader = csv.reader(csv_file)
-        rows = list(csv_reader)
-
-        if not rows:
-            return ""
-
-        # Combine multi-line text in the header row
-        header_row = [cell.replace("\n", " ").replace("\r", "") for cell in rows[0]]
-
-        # Create Markdown table
-        markdown_table = "| " + " | ".join(header_row) + " |\n"
-        markdown_table += "| " + " | ".join(["-" * len(col) for col in rows[0]]) + " |\n"
-
-        # Process each data row and combine multi-line text in each cell
-        for row in rows[1:]:
-            processed_row = [cell.replace("\n", " ").replace("\r", "") for cell in row]
-            markdown_table += "| " + " | ".join(processed_row) + " |\n"
-
-        return markdown_table
-    except Exception as e:
-        raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e
-
-
-def _extract_text_from_excel(file_content: bytes) -> str:
-    """Extract text from an Excel file using pandas."""
-
-    def _construct_markdown_table(df: pd.DataFrame) -> str:
-        """Manually construct a Markdown table from a DataFrame."""
-        # Construct the header row
-        header_row = "| " + " | ".join(df.columns) + " |"
-
-        # Construct the separator row
-        separator_row = "| " + " | ".join(["-" * len(col) for col in df.columns]) + " |"
-
-        # Construct the data rows
-        data_rows = []
-        for _, row in df.iterrows():
-            data_row = "| " + " | ".join(map(str, row)) + " |"
-            data_rows.append(data_row)
-
-        # 
Combine all rows into a single string - markdown_table = "\n".join([header_row, separator_row] + data_rows) - return markdown_table - - try: - excel_file = pd.ExcelFile(io.BytesIO(file_content)) - markdown_table = "" - for sheet_name in excel_file.sheet_names: - try: - df = excel_file.parse(sheet_name=sheet_name) - df.dropna(how="all", inplace=True) - - # Combine multi-line text in each cell into a single line - df = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) - - # Combine multi-line text in column names into a single line - df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns]) - - # Manually construct the Markdown table - markdown_table += _construct_markdown_table(df) + "\n\n" - except Exception: - continue - return markdown_table - except Exception as e: - raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e - - -def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.ppt import partition_ppt - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - with io.BytesIO(file_content) as file: - elements = partition_ppt(file=file) - return "\n".join([getattr(element, "text", "") for element in elements]) - - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e - - -def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.pptx import partition_pptx - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - with io.BytesIO(file_content) as file: - elements = partition_pptx(file=file) - return "\n".join([getattr(element, "text", "") for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e - - -def _extract_text_from_epub(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.epub import partition_epub - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - pypandoc.download_pandoc() - with io.BytesIO(file_content) as 
file: - elements = partition_epub(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e - - -def _extract_text_from_eml(file_content: bytes) -> str: - from unstructured.partition.email import partition_email - - try: - with io.BytesIO(file_content) as file: - elements = partition_email(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e - - -def _extract_text_from_msg(file_content: bytes) -> str: - from unstructured.partition.msg import partition_msg - - try: - with io.BytesIO(file_content) as file: - elements = partition_msg(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e - - -def _extract_text_from_vtt(vtt_bytes: bytes) -> str: - text = _extract_text_from_plain_text(vtt_bytes) - - # remove bom - text = text.lstrip("\ufeff") - - raw_results = [] - for caption in webvtt.from_string(text): - raw_results.append((caption.voice, caption.text)) - - # Merge consecutive utterances by the same speaker - merged_results = [] - if raw_results: - current_speaker, current_text = raw_results[0] - - for i in range(1, len(raw_results)): - spk, txt = raw_results[i] - if spk is None: - merged_results.append((None, current_text)) - continue - - if spk == current_speaker: - # If it is the same speaker, merge the utterances (joined by space) - current_text += " " + txt - else: - # If the speaker changes, register the utterance so far and move on - merged_results.append((current_speaker, current_text)) - current_speaker, current_text = spk, txt - - # Add the last element - merged_results.append((current_speaker, current_text)) - else: - merged_results = raw_results - - # Return the result in the specified format: Speaker "text" style - formatted = [f'{spk or ""} "{txt}"' for spk, txt in merged_results] - return "\n".join(formatted) - - -def _extract_text_from_properties(file_content: bytes) -> str: - try: - text = _extract_text_from_plain_text(file_content) - lines = text.splitlines() - result = [] - for line in lines: - line = line.strip() - # Preserve comments and empty lines - if not line or line.startswith("#") or line.startswith("!"): - result.append(line) - continue - - if "=" in line: - key, value = line.split("=", 1) - elif ":" in line: - key, value = line.split(":", 1) - else: - key, value = line, "" - - result.append(f"{key.strip()}: {value.strip()}") - - return "\n".join(result) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from properties file: {str(e)}") from e diff --git a/api/graphon/nodes/end/__init__.py b/api/graphon/nodes/end/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/graphon/nodes/end/end_node.py b/api/graphon/nodes/end/end_node.py deleted file mode 100644 index 11b9e586442..00000000000 --- a/api/graphon/nodes/end/end_node.py +++ /dev/null @@ -1,47 +0,0 @@ -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template -from graphon.nodes.end.entities import EndNodeData - - -class EndNode(Node[EndNodeData]): - node_type = BuiltinNodeTypes.END - execution_type = NodeExecutionType.RESPONSE 
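-    # RESPONSE marks End as an output-producing node: its result is what the
-    # workflow returns, and it can stream via get_streaming_template() below.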
- - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run node - collect all outputs at once. - - This method runs after streaming is complete (if streaming was enabled). - It collects all output variables and returns them. - """ - output_variables = self.node_data.outputs - - outputs = {} - for variable_selector in output_variables: - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - value = variable.to_object() if variable is not None else None - outputs[variable_selector.variable] = value - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=outputs, - outputs=outputs, - ) - - def get_streaming_template(self) -> Template: - """ - Get the template for streaming. - - Returns: - Template instance for this End node - """ - outputs_config = [ - {"variable": output.variable, "value_selector": output.value_selector} for output in self.node_data.outputs - ] - return Template.from_end_outputs(outputs_config) diff --git a/api/graphon/nodes/end/entities.py b/api/graphon/nodes/end/entities.py deleted file mode 100644 index 839aed7e4b7..00000000000 --- a/api/graphon/nodes/end/entities.py +++ /dev/null @@ -1,27 +0,0 @@ -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base.entities import OutputVariableEntity - - -class EndNodeData(BaseNodeData): - """ - END Node Data. - """ - - type: NodeType = BuiltinNodeTypes.END - outputs: list[OutputVariableEntity] - - -class EndStreamParam(BaseModel): - """ - EndStreamParam entity - """ - - end_dependencies: dict[str, list[str]] = Field( - ..., description="end dependencies (end node id -> dependent node ids)" - ) - end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( - ..., description="end stream variable selector mapping (end node id -> stream variable selectors)" - ) diff --git a/api/graphon/nodes/http_request/__init__.py b/api/graphon/nodes/http_request/__init__.py deleted file mode 100644 index b29099db230..00000000000 --- a/api/graphon/nodes/http_request/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .config import build_http_request_config, resolve_http_request_config -from .entities import ( - HTTP_REQUEST_CONFIG_FILTER_KEY, - BodyData, - HttpRequestNodeAuthorization, - HttpRequestNodeBody, - HttpRequestNodeConfig, - HttpRequestNodeData, -) -from .node import HttpRequestNode - -__all__ = [ - "HTTP_REQUEST_CONFIG_FILTER_KEY", - "BodyData", - "HttpRequestNode", - "HttpRequestNodeAuthorization", - "HttpRequestNodeBody", - "HttpRequestNodeConfig", - "HttpRequestNodeData", - "build_http_request_config", - "resolve_http_request_config", -] diff --git a/api/graphon/nodes/http_request/config.py b/api/graphon/nodes/http_request/config.py deleted file mode 100644 index 53bf6c7ae4c..00000000000 --- a/api/graphon/nodes/http_request/config.py +++ /dev/null @@ -1,33 +0,0 @@ -from collections.abc import Mapping - -from .entities import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNodeConfig - - -def build_http_request_config( - *, - max_connect_timeout: int = 10, - max_read_timeout: int = 600, - max_write_timeout: int = 600, - max_binary_size: int = 10 * 1024 * 1024, - max_text_size: int = 1 * 1024 * 1024, - ssl_verify: bool = True, - ssrf_default_max_retries: int = 3, -) -> HttpRequestNodeConfig: - return HttpRequestNodeConfig( - max_connect_timeout=max_connect_timeout, - max_read_timeout=max_read_timeout, - 
max_write_timeout=max_write_timeout,
-        max_binary_size=max_binary_size,
-        max_text_size=max_text_size,
-        ssl_verify=ssl_verify,
-        ssrf_default_max_retries=ssrf_default_max_retries,
-    )
-
-
-def resolve_http_request_config(filters: Mapping[str, object] | None) -> HttpRequestNodeConfig:
-    if not filters:
-        raise ValueError("http_request_config is required to build HTTP request default config")
-    config = filters.get(HTTP_REQUEST_CONFIG_FILTER_KEY)
-    if not isinstance(config, HttpRequestNodeConfig):
-        raise ValueError("http_request_config must be an HttpRequestNodeConfig instance")
-    return config
diff --git a/api/graphon/nodes/http_request/entities.py b/api/graphon/nodes/http_request/entities.py
deleted file mode 100644
index 6fa067bdd17..00000000000
--- a/api/graphon/nodes/http_request/entities.py
+++ /dev/null
@@ -1,241 +0,0 @@
-import mimetypes
-from collections.abc import Sequence
-from dataclasses import dataclass
-from email.message import Message
-from typing import Any, Literal
-
-import charset_normalizer
-import httpx
-from pydantic import BaseModel, Field, ValidationInfo, field_validator
-
-from graphon.entities.base_node_data import BaseNodeData
-from graphon.enums import BuiltinNodeTypes, NodeType
-
-HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config"
-
-
-class HttpRequestNodeAuthorizationConfig(BaseModel):
-    type: Literal["basic", "bearer", "custom"]
-    api_key: str
-    header: str = ""
-
-
-class HttpRequestNodeAuthorization(BaseModel):
-    type: Literal["no-auth", "api-key"]
-    config: HttpRequestNodeAuthorizationConfig | None = None
-
-    @field_validator("config", mode="before")
-    @classmethod
-    def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo):
-        """
-        Check config: if type is no-auth, config should be None; otherwise it should be a dict.
-        """
-        if values.data["type"] == "no-auth":
-            return None
-        else:
-            if not v or not isinstance(v, dict):
-                raise ValueError("config should be a dict")
-
-            return v
-
-
-class BodyData(BaseModel):
-    key: str = ""
-    type: Literal["file", "text"]
-    value: str = ""
-    file: Sequence[str] = Field(default_factory=list)
-
-
-class HttpRequestNodeBody(BaseModel):
-    type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json", "binary"]
-    data: Sequence[BodyData] = Field(default_factory=list)
-
-    @field_validator("data", mode="before")
-    @classmethod
-    def check_data(cls, v: Any):
-        """For compatibility, if body is not set, return empty list."""
-        if not v:
-            return []
-        if isinstance(v, str):
-            return [BodyData(key="", type="text", value=v)]
-        return v
-
-
-class HttpRequestNodeTimeout(BaseModel):
-    connect: int | None = None
-    read: int | None = None
-    write: int | None = None
-
-
-@dataclass(frozen=True, slots=True)
-class HttpRequestNodeConfig:
-    max_connect_timeout: int
-    max_read_timeout: int
-    max_write_timeout: int
-    max_binary_size: int
-    max_text_size: int
-    ssl_verify: bool
-    ssrf_default_max_retries: int
-
-    def default_timeout(self) -> "HttpRequestNodeTimeout":
-        return HttpRequestNodeTimeout(
-            connect=self.max_connect_timeout,
-            read=self.max_read_timeout,
-            write=self.max_write_timeout,
-        )
-
-
-class HttpRequestNodeData(BaseNodeData):
-    """
-    HTTP Request Node Data. 
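-
-    Accepts both lower- and upper-case HTTP method literals; body, timeout and
-    ssl_verify are optional and fall back to executor-level defaults.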
- """ - - type: NodeType = BuiltinNodeTypes.HTTP_REQUEST - method: Literal[ - "get", - "post", - "put", - "patch", - "delete", - "head", - "options", - "GET", - "POST", - "PUT", - "PATCH", - "DELETE", - "HEAD", - "OPTIONS", - ] - url: str - authorization: HttpRequestNodeAuthorization - headers: str - params: str - body: HttpRequestNodeBody | None = None - timeout: HttpRequestNodeTimeout | None = None - ssl_verify: bool | None = None - - -class Response: - headers: dict[str, str] - response: httpx.Response - _cached_text: str | None - - def __init__(self, response: httpx.Response): - self.response = response - self.headers = dict(response.headers) - self._cached_text = None - - @property - def is_file(self): - """ - Determine if the response contains a file by checking: - 1. Content-Disposition header (RFC 6266) - 2. Content characteristics - 3. MIME type analysis - """ - content_type = self.content_type.split(";")[0].strip().lower() - parsed_content_disposition = self.parsed_content_disposition - - # Check if it's explicitly marked as an attachment - if parsed_content_disposition: - disp_type = parsed_content_disposition.get_content_disposition() # Returns 'attachment', 'inline', or None - filename = parsed_content_disposition.get_filename() # Returns filename if present, None otherwise - if disp_type == "attachment" or filename is not None: - return True - - # For 'text/' types, only 'csv' should be downloaded as file - if content_type.startswith("text/") and "csv" not in content_type: - return False - - # For application types, try to detect if it's a text-based format - if content_type.startswith("application/"): - # Common text-based application types - if any( - text_type in content_type - for text_type in ("json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql") - ): - return False - - # Try to detect if content is text-based by sampling first few bytes - try: - # Sample first 1024 bytes for text detection - content_sample = self.response.content[:1024] - content_sample.decode("utf-8") - # If we can decode as UTF-8 and find common text patterns, likely not a file - text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ") - if any(marker in content_sample for marker in text_markers): - return False - except UnicodeDecodeError: - # If we can't decode as UTF-8, likely a binary file - return True - - # For other types, use MIME type analysis - main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or "")) - if main_type: - return main_type.split("/")[0] in ("application", "image", "audio", "video") - - # For unknown types, check if it's a media type - return any(media_type in content_type for media_type in ("image/", "audio/", "video/")) - - @property - def content_type(self) -> str: - return self.headers.get("content-type", "") - - @property - def text(self) -> str: - """ - Get response text with robust encoding detection. - - Uses charset_normalizer for better encoding detection than httpx's default, - which helps handle Chinese and other non-ASCII characters properly. 
- """ - # Check cache first - if hasattr(self, "_cached_text") and self._cached_text is not None: - return self._cached_text - - # Try charset_normalizer for robust encoding detection first - detected_encoding = charset_normalizer.from_bytes(self.response.content).best() - if detected_encoding and detected_encoding.encoding: - try: - text = self.response.content.decode(detected_encoding.encoding) - self._cached_text = text - return text - except (UnicodeDecodeError, TypeError, LookupError): - # Fallback to httpx's encoding detection if charset_normalizer fails - pass - - # Fallback to httpx's built-in encoding detection - text = self.response.text - self._cached_text = text - return text - - @property - def content(self) -> bytes: - return self.response.content - - @property - def status_code(self) -> int: - return self.response.status_code - - @property - def size(self) -> int: - return len(self.content) - - @property - def readable_size(self) -> str: - if self.size < 1024: - return f"{self.size} bytes" - elif self.size < 1024 * 1024: - return f"{(self.size / 1024):.2f} KB" - else: - return f"{(self.size / 1024 / 1024):.2f} MB" - - @property - def parsed_content_disposition(self) -> Message | None: - content_disposition = self.headers.get("content-disposition", "") - if content_disposition: - msg = Message() - msg["content-disposition"] = content_disposition - return msg - return None diff --git a/api/graphon/nodes/http_request/exc.py b/api/graphon/nodes/http_request/exc.py deleted file mode 100644 index 46613c9e861..00000000000 --- a/api/graphon/nodes/http_request/exc.py +++ /dev/null @@ -1,26 +0,0 @@ -class HttpRequestNodeError(ValueError): - """Custom error for HTTP request node.""" - - -class AuthorizationConfigError(HttpRequestNodeError): - """Raised when authorization config is missing or invalid.""" - - -class FileFetchError(HttpRequestNodeError): - """Raised when a file cannot be fetched.""" - - -class InvalidHttpMethodError(HttpRequestNodeError): - """Raised when an invalid HTTP method is used.""" - - -class ResponseSizeError(HttpRequestNodeError): - """Raised when the response size exceeds the allowed threshold.""" - - -class RequestBodyError(HttpRequestNodeError): - """Raised when the request body is invalid.""" - - -class InvalidURLError(HttpRequestNodeError): - """Raised when the URL is invalid.""" diff --git a/api/graphon/nodes/http_request/executor.py b/api/graphon/nodes/http_request/executor.py deleted file mode 100644 index 0c6f4ecd3a6..00000000000 --- a/api/graphon/nodes/http_request/executor.py +++ /dev/null @@ -1,488 +0,0 @@ -import base64 -import json -import secrets -import string -from collections.abc import Callable, Mapping -from copy import deepcopy -from typing import Any, Literal -from urllib.parse import urlencode, urlparse - -import httpx -from json_repair import repair_json - -from graphon.file.enums import FileTransferMethod -from graphon.runtime import VariablePool -from graphon.variables.segments import ArrayFileSegment, FileSegment - -from ..protocols import FileManagerProtocol, HttpClientProtocol -from .entities import ( - HttpRequestNodeAuthorization, - HttpRequestNodeConfig, - HttpRequestNodeData, - HttpRequestNodeTimeout, - Response, -) -from .exc import ( - AuthorizationConfigError, - FileFetchError, - HttpRequestNodeError, - InvalidHttpMethodError, - InvalidURLError, - RequestBodyError, - ResponseSizeError, -) - -BODY_TYPE_TO_CONTENT_TYPE = { - "json": "application/json", - "x-www-form-urlencoded": "application/x-www-form-urlencoded", - 
"form-data": "multipart/form-data", - "raw-text": "text/plain", -} - - -class Executor: - method: Literal[ - "get", - "head", - "post", - "put", - "delete", - "patch", - "options", - "GET", - "POST", - "PUT", - "PATCH", - "DELETE", - "HEAD", - "OPTIONS", - ] - url: str - params: list[tuple[str, str]] | None - content: str | bytes | None - data: Mapping[str, Any] | None - files: list[tuple[str, tuple[str | None, bytes, str]]] | None - json: Any - headers: dict[str, str] - auth: HttpRequestNodeAuthorization - timeout: HttpRequestNodeTimeout - max_retries: int - - boundary: str - - def __init__( - self, - *, - node_data: HttpRequestNodeData, - timeout: HttpRequestNodeTimeout, - variable_pool: VariablePool, - http_request_config: HttpRequestNodeConfig, - max_retries: int | None = None, - ssl_verify: bool | None = None, - http_client: HttpClientProtocol, - file_manager: FileManagerProtocol, - ): - self._http_request_config = http_request_config - # If authorization API key is present, convert the API key using the variable pool - if node_data.authorization.type == "api-key": - if node_data.authorization.config is None: - raise AuthorizationConfigError("authorization config is required") - node_data.authorization.config.api_key = variable_pool.convert_template( - node_data.authorization.config.api_key - ).text - # Validate that API key is not empty after template conversion - if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip(): - raise AuthorizationConfigError( - "API key is required for authorization but was empty. Please provide a valid API key." - ) - - self.url = node_data.url - self.method = node_data.method - self.auth = node_data.authorization - self.timeout = timeout - self.ssl_verify = ssl_verify if ssl_verify is not None else node_data.ssl_verify - if self.ssl_verify is None: - self.ssl_verify = self._http_request_config.ssl_verify - if not isinstance(self.ssl_verify, bool): - raise ValueError("ssl_verify must be a boolean") - self.params = None - self.headers = {} - self.content = None - self.files = None - self.data = None - self.json = None - self.max_retries = ( - max_retries if max_retries is not None else self._http_request_config.ssrf_default_max_retries - ) - self._http_client = http_client - self._file_manager = file_manager - - # init template - self.variable_pool = variable_pool - self.node_data = node_data - self._initialize() - - def _initialize(self): - self._init_url() - self._init_params() - self._init_headers() - self._init_body() - - def _init_url(self): - self.url = self.variable_pool.convert_template(self.node_data.url).text - - # check if url is a valid URL - if not self.url: - raise InvalidURLError("url is required") - if not self.url.startswith(("http://", "https://")): - raise InvalidURLError("url should start with http:// or https://") - - def _init_params(self): - """ - Almost same as _init_headers(), difference: - 1. response a list tuple to support same key, like 'aa=1&aa=2' - 2. param value may have '\n', we need to splitlines then extract the variable value. 
- """ - result = [] - for line in self.node_data.params.splitlines(): - if not (line := line.strip()): - continue - - key, *value = line.split(":", 1) - if not (key := key.strip()): - continue - - value_str = value[0].strip() if value else "" - result.append( - (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text) - ) - - if result: - self.params = result - - def _init_headers(self): - """ - Convert the header string of frontend to a dictionary. - - Each line in the header string represents a key-value pair. - Keys and values are separated by ':'. - Empty values are allowed. - - Examples: - 'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'} - 'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'} - 'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'} - - """ - headers = self.variable_pool.convert_template(self.node_data.headers).text - self.headers = { - key.strip(): (value[0].strip() if value else "") - for line in headers.splitlines() - if line.strip() - for key, *value in [line.split(":", 1)] - } - - def _init_body(self): - body = self.node_data.body - if body is not None: - data = body.data - match body.type: - case "none": - self.content = "" - case "raw-text": - if len(data) != 1: - raise RequestBodyError("raw-text body type should have exactly one item") - self.content = self.variable_pool.convert_template(data[0].value).text - case "json": - if len(data) != 1: - raise RequestBodyError("json body type should have exactly one item") - json_string = self.variable_pool.convert_template(data[0].value).text - try: - repaired = repair_json(json_string) - json_object = json.loads(repaired, strict=False) - except json.JSONDecodeError as e: - raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e - self.json = json_object - # self.json = self._parse_object_contains_variables(json_object) - case "binary": - if len(data) != 1: - raise RequestBodyError("binary body type should have exactly one item") - file_selector = data[0].file - file_variable = self.variable_pool.get_file(file_selector) - if file_variable is None: - raise FileFetchError(f"cannot fetch file with selector {file_selector}") - file = file_variable.value - self.content = self._file_manager.download(file) - case "x-www-form-urlencoded": - form_data = { - self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( - item.value - ).text - for item in data - } - self.data = form_data - case "form-data": - form_data = { - self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( - item.value - ).text - for item in filter(lambda item: item.type == "text", data) - } - file_selectors = { - self.variable_pool.convert_template(item.key).text: item.file - for item in filter(lambda item: item.type == "file", data) - } - - # get files from file_selectors, add support for array file variables - files_list = [] - for key, selector in file_selectors.items(): - segment = self.variable_pool.get(selector) - if isinstance(segment, FileSegment): - files_list.append((key, [segment.value])) - elif isinstance(segment, ArrayFileSegment): - files_list.append((key, list(segment.value))) - - # get files from file_manager - files: dict[str, list[tuple[str | None, bytes, str]]] = {} - for key, files_in_segment in files_list: - for file in files_in_segment: - if file.reference is not None or ( - file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None - ): - file_tuple = ( - file.filename, - self._file_manager.download(file), - file.mime_type 
or "application/octet-stream", - ) - if key not in files: - files[key] = [] - files[key].append(file_tuple) - - # convert files to list for httpx request - # If there are no actual files, we still need to force httpx to use `multipart/form-data`. - # This is achieved by inserting a harmless placeholder file that will be ignored by the server. - if not files: - self.files = [("__multipart_placeholder__", ("", b"", "application/octet-stream"))] - if files: - self.files = [] - for key, file_tuples in files.items(): - for file_tuple in file_tuples: - self.files.append((key, file_tuple)) - - self.data = form_data - - def _assembling_headers(self) -> dict[str, Any]: - authorization = deepcopy(self.auth) - headers = deepcopy(self.headers) or {} - if self.auth.type == "api-key": - if self.auth.config is None: - raise AuthorizationConfigError("self.authorization config is required") - if authorization.config is None: - raise AuthorizationConfigError("authorization config is required") - - if not authorization.config.header: - authorization.config.header = "Authorization" - - if self.auth.config.type == "bearer" and authorization.config.api_key: - headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" - elif self.auth.config.type == "basic" and authorization.config.api_key: - credentials = authorization.config.api_key - if ":" in credentials: - encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8") - else: - encoded_credentials = credentials - headers[authorization.config.header] = f"Basic {encoded_credentials}" - elif self.auth.config.type == "custom": - if authorization.config.header and authorization.config.api_key: - headers[authorization.config.header] = authorization.config.api_key - - # Handle Content-Type for multipart/form-data requests - # Fix for issue #23829: Missing boundary when using multipart/form-data - body = self.node_data.body - if body and body.type == "form-data": - # For multipart/form-data with files (including placeholder files), - # remove any manually set Content-Type header to let httpx handle - # For multipart/form-data, if any files are present (including placeholder files), - # we must remove any manually set Content-Type header. This is because httpx needs to - # automatically set the Content-Type and boundary for multipart encoding whenever files - # are included, even if they are placeholders, to avoid boundary issues and ensure correct - # file upload behaviour. Manually setting Content-Type can cause httpx to fail to set the - # boundary, resulting in invalid requests. 
- if self.files: - # Remove Content-Type if it was manually set to avoid boundary issues - headers = {k: v for k, v in headers.items() if k.lower() != "content-type"} - else: - # No files at all, set Content-Type manually - if "content-type" not in (k.lower() for k in headers): - headers["Content-Type"] = "multipart/form-data" - elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE: - # Set Content-Type for other body types - if "content-type" not in (k.lower() for k in headers): - headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] - - return headers - - def _validate_and_parse_response(self, response: httpx.Response) -> Response: - executor_response = Response(response) - - threshold_size = ( - self._http_request_config.max_binary_size - if executor_response.is_file - else self._http_request_config.max_text_size - ) - if executor_response.size > threshold_size: - raise ResponseSizeError( - f"{'File' if executor_response.is_file else 'Text'} size is too large," - f" max size is {threshold_size / 1024 / 1024:.2f} MB," - f" but current size is {executor_response.readable_size}." - ) - - return executor_response - - def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: - """ - do http request depending on api bundle - """ - _METHOD_MAP: dict[str, Callable[..., httpx.Response]] = { - "get": self._http_client.get, - "head": self._http_client.head, - "post": self._http_client.post, - "put": self._http_client.put, - "delete": self._http_client.delete, - "patch": self._http_client.patch, - } - method_lc = self.method.lower() - if method_lc not in _METHOD_MAP: - raise InvalidHttpMethodError(f"Invalid http method {self.method}") - - request_args: dict[str, Any] = { - "data": self.data, - "files": self.files, - "json": self.json, - "content": self.content, - "headers": headers, - "params": self.params, - "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), - "ssl_verify": self.ssl_verify, - "follow_redirects": True, - } - # request_args = {k: v for k, v in request_args.items() if v is not None} - try: - response = _METHOD_MAP[method_lc]( - url=self.url, - **request_args, - max_retries=self.max_retries, - ) - except self._http_client.max_retries_exceeded_error as e: - raise HttpRequestNodeError(f"Reached maximum retries for URL {self.url}") from e - except self._http_client.request_error as e: - raise HttpRequestNodeError(str(e)) from e - return response - - def invoke(self) -> Response: - # assemble headers - headers = self._assembling_headers() - # do http request - response = self._do_http_request(headers) - # validate response - return self._validate_and_parse_response(response) - - def to_log(self): - url_parts = urlparse(self.url) - path = url_parts.path or "/" - - # Add query parameters - if self.params: - query_string = urlencode(self.params) - path += f"?{query_string}" - elif url_parts.query: - path += f"?{url_parts.query}" - - raw = f"{self.method.upper()} {path} HTTP/1.1\r\n" - raw += f"Host: {url_parts.netloc}\r\n" - - headers = self._assembling_headers() - body = self.node_data.body - boundary = f"----WebKitFormBoundary{_generate_random_string(16)}" - if body: - if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE: - headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] - if body.type == "form-data": - headers["Content-Type"] = f"multipart/form-data; boundary={boundary}" - for k, v in headers.items(): - if self.auth.type == "api-key": - authorization_header = "Authorization" - if 
self.auth.config and self.auth.config.header:
-                    authorization_header = self.auth.config.header
-                if k.lower() == authorization_header.lower():
-                    raw += f"{k}: {'*' * len(v)}\r\n"
-                    continue
-            raw += f"{k}: {v}\r\n"
-
-        body_string = ""
-        # Only log actual files if present.
-        # '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file.
-        # This prevents logging meaningless placeholder entries.
-        if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files):
-            for file_entry in self.files:
-                # file_entry should be (key, (filename, content, mime_type)), but handle edge cases
-                if len(file_entry) != 2 or len(file_entry[1]) < 2:
-                    continue  # skip malformed entries
-                key = file_entry[0]
-                content = file_entry[1][1]
-                body_string += f"--{boundary}\r\n"
-                body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
-                # Do not decode binary content; use a placeholder with file metadata instead.
-                # Includes filename, size, and MIME type for better logging context.
-                body_string += (
-                    f"<file: {file_entry[1][0] or 'unnamed'}, size: {len(content)} bytes, "
-                    f"mime: {file_entry[1][2] if len(file_entry[1]) > 2 else 'unknown'}>\r\n"
-                )
-            body_string += f"--{boundary}--\r\n"
-        elif self.node_data.body:
-            if self.content:
-                # If content is bytes, do not decode it; show a placeholder with size.
-                # Provides content size information for binary data without exposing the raw bytes.
-                if isinstance(self.content, bytes):
-                    body_string = f"<binary content, {len(self.content)} bytes>"
-                else:
-                    body_string = self.content
-            elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
-                body_string = urlencode(self.data)
-            elif self.data and self.node_data.body.type == "form-data":
-                for key, value in self.data.items():
-                    body_string += f"--{boundary}\r\n"
-                    body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
-                    body_string += f"{value}\r\n"
-                body_string += f"--{boundary}--\r\n"
-            elif self.json:
-                body_string = json.dumps(self.json)
-            elif self.node_data.body.type == "raw-text":
-                if len(self.node_data.body.data) != 1:
-                    raise RequestBodyError("raw-text body type should have exactly one item")
-                body_string = self.node_data.body.data[0].value
-        if body_string:
-            raw += f"Content-Length: {len(body_string)}\r\n"
-        raw += "\r\n"  # Empty line between headers and body
-        raw += body_string
-
-        return raw
-
-
-def _generate_random_string(n: int) -> str:
-    """
-    Generate a random string of lowercase ASCII letters.
-
-    Args:
-        n (int): The length of the random string to generate.
-
-    Returns:
-        str: A random string of lowercase ASCII letters with length n.
- - Example: - >>> _generate_random_string(5) - 'abcde' - """ - return "".join(secrets.choice(string.ascii_lowercase) for _ in range(n)) diff --git a/api/graphon/nodes/http_request/node.py b/api/graphon/nodes/http_request/node.py deleted file mode 100644 index 3d74347a7fc..00000000000 --- a/api/graphon/nodes/http_request/node.py +++ /dev/null @@ -1,261 +0,0 @@ -import logging -import mimetypes -from collections.abc import Callable, Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod -from graphon.node_events import NodeRunResult -from graphon.nodes.base import variable_template_parser -from graphon.nodes.base.entities import VariableSelector -from graphon.nodes.base.node import Node -from graphon.nodes.http_request.executor import Executor -from graphon.nodes.protocols import ( - FileManagerProtocol, - FileReferenceFactoryProtocol, - HttpClientProtocol, - ToolFileManagerProtocol, -) -from graphon.variables.segments import ArrayFileSegment - -from .config import build_http_request_config, resolve_http_request_config -from .entities import ( - HTTP_REQUEST_CONFIG_FILTER_KEY, - HttpRequestNodeConfig, - HttpRequestNodeData, - HttpRequestNodeTimeout, - Response, -) -from .exc import HttpRequestNodeError, RequestBodyError - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -class HttpRequestNode(Node[HttpRequestNodeData]): - node_type = BuiltinNodeTypes.HTTP_REQUEST - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - http_request_config: HttpRequestNodeConfig, - http_client: HttpClientProtocol, - tool_file_manager_factory: Callable[[], ToolFileManagerProtocol], - file_manager: FileManagerProtocol, - file_reference_factory: FileReferenceFactoryProtocol, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - self._http_request_config = http_request_config - self._http_client = http_client - self._tool_file_manager_factory = tool_file_manager_factory - self._file_manager = file_manager - self._file_reference_factory = file_reference_factory - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - if not filters or HTTP_REQUEST_CONFIG_FILTER_KEY not in filters: - http_request_config = build_http_request_config() - else: - http_request_config = resolve_http_request_config(filters) - default_timeout = http_request_config.default_timeout() - return { - "type": "http-request", - "config": { - "method": "get", - "authorization": { - "type": "no-auth", - }, - "body": {"type": "none"}, - "timeout": { - **default_timeout.model_dump(), - "max_connect_timeout": http_request_config.max_connect_timeout, - "max_read_timeout": http_request_config.max_read_timeout, - "max_write_timeout": http_request_config.max_write_timeout, - }, - "ssl_verify": http_request_config.ssl_verify, - }, - "retry_config": { - "max_retries": http_request_config.ssrf_default_max_retries, - "retry_interval": 0.5 * (2**2), - "retry_enabled": True, - }, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - process_data = {} - try: - http_executor = 
Executor( - node_data=self.node_data, - timeout=self._get_request_timeout(self.node_data), - variable_pool=self.graph_runtime_state.variable_pool, - http_request_config=self._http_request_config, - # Must be 0 to disable executor-level retries, as the graph engine handles them. - # This is critical to prevent nested retries. - max_retries=0, - ssl_verify=self.node_data.ssl_verify, - http_client=self._http_client, - file_manager=self._file_manager, - ) - process_data["request"] = http_executor.to_log() - - response = http_executor.invoke() - files = self.extract_files(url=http_executor.url, response=response) - if not response.response.is_success and (self.error_strategy or self.retry): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - outputs={ - "status_code": response.status_code, - "body": response.text if not files.value else "", - "headers": response.headers, - "files": files, - }, - process_data={ - "request": http_executor.to_log(), - }, - error=f"Request failed with status code {response.status_code}", - error_type="HTTPResponseCodeError", - ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "status_code": response.status_code, - "body": response.text if not files.value else "", - "headers": response.headers, - "files": files, - }, - process_data={ - "request": http_executor.to_log(), - }, - ) - except HttpRequestNodeError as e: - logger.warning("http request node %s failed to run: %s", self._node_id, e) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - process_data=process_data, - error_type=type(e).__name__, - ) - - def _get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: - default_timeout = self._http_request_config.default_timeout() - timeout = node_data.timeout - if timeout is None: - return default_timeout - - return HttpRequestNodeTimeout( - connect=timeout.connect or default_timeout.connect, - read=timeout.read or default_timeout.read, - write=timeout.write or default_timeout.write, - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: HttpRequestNodeData, - ) -> Mapping[str, Sequence[str]]: - selectors: list[VariableSelector] = [] - selectors += variable_template_parser.extract_selectors_from_template(node_data.url) - selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) - selectors += variable_template_parser.extract_selectors_from_template(node_data.params) - if node_data.body: - body_type = node_data.body.type - data = node_data.body.data - match body_type: - case "none": - pass - case "binary": - if len(data) != 1: - raise RequestBodyError("invalid body data, should have only one item") - selector = data[0].file - selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector)) - case "json" | "raw-text": - if len(data) != 1: - raise RequestBodyError("invalid body data, should have only one item") - selectors += variable_template_parser.extract_selectors_from_template(data[0].key) - selectors += variable_template_parser.extract_selectors_from_template(data[0].value) - case "x-www-form-urlencoded": - for item in data: - selectors += variable_template_parser.extract_selectors_from_template(item.key) - selectors += variable_template_parser.extract_selectors_from_template(item.value) - case "form-data": - for item in data: - selectors += 
variable_template_parser.extract_selectors_from_template(item.key) - if item.type == "text": - selectors += variable_template_parser.extract_selectors_from_template(item.value) - elif item.type == "file": - selectors.append( - VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file) - ) - - mapping = {} - for selector_iter in selectors: - mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector - - return mapping - - def extract_files(self, url: str, response: Response) -> ArrayFileSegment: - """ - Extract files from response by checking both Content-Type header and URL - """ - files: list[File] = [] - is_file = response.is_file - content_type = response.content_type - content = response.content - parsed_content_disposition = response.parsed_content_disposition - content_disposition_type = None - - if not is_file: - return ArrayFileSegment(value=[]) - - if parsed_content_disposition: - content_disposition_filename = parsed_content_disposition.get_filename() - if content_disposition_filename: - # If filename is available from content-disposition, use it to guess the content type - content_disposition_type = mimetypes.guess_type(content_disposition_filename)[0] - - # Guess file extension from URL or Content-Type header - filename = url.split("?")[0].split("/")[-1] or "" - mime_type = ( - content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" - ) - tool_file_manager = self._tool_file_manager_factory() - - tool_file = tool_file_manager.create_file_by_raw( - file_binary=content, - mimetype=mime_type, - ) - - file = self._file_reference_factory.build_from_mapping( - mapping={ - "tool_file_id": tool_file.id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - ) - files.append(file) - - return ArrayFileSegment(value=files) - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled diff --git a/api/graphon/nodes/human_input/__init__.py b/api/graphon/nodes/human_input/__init__.py deleted file mode 100644 index 17896045779..00000000000 --- a/api/graphon/nodes/human_input/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Human Input node implementation. -""" diff --git a/api/graphon/nodes/human_input/entities.py b/api/graphon/nodes/human_input/entities.py deleted file mode 100644 index aa01bde1457..00000000000 --- a/api/graphon/nodes/human_input/entities.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Human Input node entities. - -The graph package owns the workflow-facing form schema and keeps it transportable -across runtimes. Dify-specific delivery surface and recipient translation stay -outside `graphon`. -""" - -import re -from collections.abc import Mapping, Sequence -from datetime import datetime, timedelta -from typing import Any, Self - -from pydantic import BaseModel, Field, field_validator, model_validator - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from graphon.variables.consts import SELECTORS_LENGTH - -from .enums import ButtonStyle, FormInputType, PlaceholderType, TimeoutUnit - -_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") - - -class FormInputDefault(BaseModel): - """Default configuration for form inputs.""" - - # NOTE: Ideally, a discriminated union would be used to model - # FormInputDefault. 
However, the UI requires preserving the previous - # value when switching between `VARIABLE` and `CONSTANT` types. This - # necessitates retaining all fields, making a discriminated union unsuitable. - - type: PlaceholderType - - # The selector of default variable, used when `type` is `VARIABLE`. - selector: Sequence[str] = Field(default_factory=tuple) # - - # The value of the default, used when `type` is `CONSTANT`. - # TODO: How should we express JSON values? - value: str = "" - - @model_validator(mode="after") - def _validate_selector(self) -> Self: - if self.type == PlaceholderType.CONSTANT: - return self - if len(self.selector) < SELECTORS_LENGTH: - raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}") - return self - - -class FormInput(BaseModel): - """Form input definition.""" - - type: FormInputType - output_variable_name: str - default: FormInputDefault | None = None - - -_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") - - -class UserAction(BaseModel): - """User action configuration.""" - - # id is the identifier for this action. - # It also serves as the identifiers of output handle. - # - # The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.) - id: str = Field(max_length=20) - title: str = Field(max_length=20) - button_style: ButtonStyle = ButtonStyle.DEFAULT - - @field_validator("id") - @classmethod - def _validate_id(cls, value: str) -> str: - if not _IDENTIFIER_PATTERN.match(value): - raise ValueError( - f"'{value}' is not a valid identifier. It must start with a letter or underscore, " - f"and contain only letters, numbers, or underscores." - ) - return value - - -class HumanInputNodeData(BaseNodeData): - """Human Input node data.""" - - type: NodeType = BuiltinNodeTypes.HUMAN_INPUT - form_content: str = "" - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - timeout: int = 36 - timeout_unit: TimeoutUnit = TimeoutUnit.HOUR - - @field_validator("inputs") - @classmethod - def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: - seen_names: set[str] = set() - for form_input in inputs: - name = form_input.output_variable_name - if name in seen_names: - raise ValueError(f"duplicated output_variable_name '{name}' in inputs") - seen_names.add(name) - return inputs - - @field_validator("user_actions") - @classmethod - def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]: - seen_ids: set[str] = set() - for action in user_actions: - action_id = action.id - if action_id in seen_ids: - raise ValueError(f"duplicated user action id '{action_id}'") - seen_ids.add(action_id) - return user_actions - - def expiration_time(self, start_time: datetime) -> datetime: - if self.timeout_unit == TimeoutUnit.HOUR: - return start_time + timedelta(hours=self.timeout) - elif self.timeout_unit == TimeoutUnit.DAY: - return start_time + timedelta(days=self.timeout) - else: - raise AssertionError("unknown timeout unit.") - - def outputs_field_names(self) -> Sequence[str]: - field_names = [] - for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content): - field_names.append(match.group("field_name")) - return field_names - - def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]: - variable_mappings: dict[str, Sequence[str]] = {} - - def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None: - for selector in selectors: - if len(selector) < 
SELECTORS_LENGTH: - continue - qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#" - variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH]) - - form_template_parser = VariableTemplateParser(template=self.form_content) - _add_variable_selectors( - [selector.value_selector for selector in form_template_parser.extract_variable_selectors()] - ) - - for input in self.inputs: - default_value = input.default - if default_value is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - default_value_key = ".".join(default_value.selector) - qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#" - variable_mappings[qualified_variable_mapping_key] = default_value.selector - - return variable_mappings - - def find_action_text(self, action_id: str) -> str: - """ - Resolve action display text by id. - """ - for action in self.user_actions: - if action.id == action_id: - return action.title - return action_id - - -class FormDefinition(BaseModel): - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - rendered_content: str - expiration_time: datetime - - # this is used to store the resolved default values - default_values: dict[str, Any] = Field(default_factory=dict) - - # node_title records the title of the HumanInput node. - node_title: str | None = None - - # display_in_ui controls whether the form should be displayed in UI surfaces. - display_in_ui: bool | None = None - - -class HumanInputSubmissionValidationError(ValueError): - pass - - -def validate_human_input_submission( - *, - inputs: Sequence[FormInput], - user_actions: Sequence[UserAction], - selected_action_id: str, - form_data: Mapping[str, Any], -) -> None: - available_actions = {action.id for action in user_actions} - if selected_action_id not in available_actions: - raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}") - - provided_inputs = set(form_data.keys()) - missing_inputs = [ - form_input.output_variable_name - for form_input in inputs - if form_input.output_variable_name not in provided_inputs - ] - - if missing_inputs: - missing_list = ", ".join(missing_inputs) - raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}") diff --git a/api/graphon/nodes/human_input/enums.py b/api/graphon/nodes/human_input/enums.py deleted file mode 100644 index 3fb0ab44995..00000000000 --- a/api/graphon/nodes/human_input/enums.py +++ /dev/null @@ -1,55 +0,0 @@ -import enum - - -class HumanInputFormStatus(enum.StrEnum): - """Status of a human input form.""" - - # Awaiting submission from any recipient. Forms stay in this state until - # submitted or a timeout rule applies. - WAITING = enum.auto() - # Global timeout reached. The workflow run is stopped and will not resume. - # This is distinct from node-level timeout. - EXPIRED = enum.auto() - # Submitted by a recipient; form data is available and execution resumes - # along the selected action edge. - SUBMITTED = enum.auto() - # Node-level timeout reached. The human input node should emit a timeout - # event and the workflow should resume along the timeout edge. - TIMEOUT = enum.auto() - - -class HumanInputFormKind(enum.StrEnum): - """Kind of a human input form.""" - - RUNTIME = enum.auto() # Form created during workflow execution. - DELIVERY_TEST = enum.auto() # Form created for delivery tests. 
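-
-# Note: because these enums subclass StrEnum, enum.auto() assigns the lowercased
-# member name as the value, e.g. HumanInputFormKind.DELIVERY_TEST == "delivery_test"
-# and str(HumanInputFormStatus.WAITING) == "waiting", so members compare equal to
-# and serialize as plain strings.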
- - -class ButtonStyle(enum.StrEnum): - """Button styles for user actions.""" - - PRIMARY = enum.auto() - DEFAULT = enum.auto() - ACCENT = enum.auto() - GHOST = enum.auto() - - -class TimeoutUnit(enum.StrEnum): - """Timeout unit for form expiration.""" - - HOUR = enum.auto() - DAY = enum.auto() - - -class FormInputType(enum.StrEnum): - """Form input types.""" - - TEXT_INPUT = enum.auto() - PARAGRAPH = enum.auto() - - -class PlaceholderType(enum.StrEnum): - """Default value types for form inputs.""" - - VARIABLE = enum.auto() - CONSTANT = enum.auto() diff --git a/api/graphon/nodes/human_input/human_input_node.py b/api/graphon/nodes/human_input/human_input_node.py deleted file mode 100644 index fe04022877b..00000000000 --- a/api/graphon/nodes/human_input/human_input_node.py +++ /dev/null @@ -1,299 +0,0 @@ -import json -import logging -from collections.abc import Generator, Mapping, Sequence -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, cast - -from graphon.entities.graph_config import NodeConfigDict -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import ( - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - NodeRunResult, - PauseRequestedEvent, -) -from graphon.node_events.base import NodeEventBase -from graphon.node_events.node import StreamCompletedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.runtime import HumanInputFormStateProtocol, HumanInputNodeRuntimeProtocol -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter - -from .entities import HumanInputNodeData -from .enums import HumanInputFormStatus, PlaceholderType - -if TYPE_CHECKING: - from graphon.entities.graph_init_params import GraphInitParams - from graphon.runtime.graph_runtime_state import GraphRuntimeState - - -_SELECTED_BRANCH_KEY = "selected_branch" - - -logger = logging.getLogger(__name__) - - -class HumanInputNode(Node[HumanInputNodeData]): - node_type = BuiltinNodeTypes.HUMAN_INPUT - execution_type = NodeExecutionType.BRANCH - - _BRANCH_SELECTION_KEYS: tuple[str, ...] 
= ( - "edge_source_handle", - "edgeSourceHandle", - "source_handle", - _SELECTED_BRANCH_KEY, - "selectedBranch", - "branch", - "branch_id", - "branchId", - "handle", - ) - - _node_data: HumanInputNodeData - _OUTPUT_FIELD_ACTION_ID = "__action_id" - _OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content" - _TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout" - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - runtime: HumanInputNodeRuntimeProtocol | None = None, - form_repository: object | None = None, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - resolved_runtime = runtime - if resolved_runtime is None: - raise ValueError("runtime is required") - if form_repository is not None: - with_form_repository = getattr(resolved_runtime, "with_form_repository", None) - if callable(with_form_repository): - resolved_runtime = cast(HumanInputNodeRuntimeProtocol, with_form_repository(form_repository)) - self._runtime: HumanInputNodeRuntimeProtocol = resolved_runtime - - @classmethod - def version(cls) -> str: - return "1" - - def _resolve_branch_selection(self) -> str | None: - """Determine the branch handle selected by human input if available.""" - - variable_pool = self.graph_runtime_state.variable_pool - - for key in self._BRANCH_SELECTION_KEYS: - handle = self._extract_branch_handle(variable_pool.get((self.id, key))) - if handle: - return handle - - default_values = self.node_data.default_value_dict - for key in self._BRANCH_SELECTION_KEYS: - handle = self._normalize_branch_value(default_values.get(key)) - if handle: - return handle - - return None - - @staticmethod - def _extract_branch_handle(segment: Any) -> str | None: - if segment is None: - return None - - candidate = getattr(segment, "to_object", None) - raw_value = candidate() if callable(candidate) else getattr(segment, "value", None) - if raw_value is None: - return None - - return HumanInputNode._normalize_branch_value(raw_value) - - @staticmethod - def _normalize_branch_value(value: Any) -> str | None: - if value is None: - return None - - if isinstance(value, str): - stripped = value.strip() - return stripped or None - - if isinstance(value, Mapping): - for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"): - candidate = value.get(key) - if isinstance(candidate, str) and candidate: - return candidate - - return None - - def _form_to_pause_event(self, form_entity: HumanInputFormStateProtocol): - required_event = self._human_input_required_event(form_entity) - pause_requested_event = PauseRequestedEvent(reason=required_event) - return pause_requested_event - - def resolve_default_values(self) -> Mapping[str, Any]: - variable_pool = self.graph_runtime_state.variable_pool - resolved_defaults: dict[str, Any] = {} - for input in self._node_data.inputs: - if (default_value := input.default) is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - resolved_value = variable_pool.get(default_value.selector) - if resolved_value is None: - # TODO: How should we handle this? 
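-                # Current behavior: an unresolvable selector is skipped, so the
-                # form input simply ends up with no resolved default value.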
- continue - resolved_defaults[input.output_variable_name] = ( - WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value) - ) - - return resolved_defaults - - def _human_input_required_event(self, form_entity: HumanInputFormStateProtocol) -> HumanInputRequired: - node_data = self._node_data - resolved_default_values = self.resolve_default_values() - return HumanInputRequired( - form_id=form_entity.id, - form_content=form_entity.rendered_content, - inputs=node_data.inputs, - actions=node_data.user_actions, - node_id=self.id, - node_title=node_data.title, - resolved_default_values=resolved_default_values, - ) - - def _run(self) -> Generator[NodeEventBase, None, None]: - """ - Execute the human input node. - - This method will: - 1. Generate a unique form ID - 2. Create form content with variable substitution - 3. Persist the form through the configured repository - 4. Send form via configured delivery methods - 5. Suspend workflow execution - 6. Wait for form submission to resume - """ - form = self._runtime.get_form(node_id=self.id) - if form is None: - form_entity = self._runtime.create_form( - node_id=self.id, - node_data=self._node_data, - rendered_content=self.render_form_content_before_submission(), - resolved_default_values=self.resolve_default_values(), - ) - - logger.info( - "Human Input node suspended workflow for form. node_id=%s, form_id=%s", - self.id, - form_entity.id, - ) - yield self._form_to_pause_event(form_entity) - return - - if form.status in { - HumanInputFormStatus.TIMEOUT, - HumanInputFormStatus.EXPIRED, - } or form.expiration_time <= datetime.now(UTC).replace(tzinfo=None): - yield HumanInputFormTimeoutEvent( - node_title=self._node_data.title, - expiration_time=form.expiration_time, - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={self._OUTPUT_FIELD_ACTION_ID: ""}, - edge_source_handle=self._TIMEOUT_HANDLE, - ) - ) - return - - if not form.submitted: - yield self._form_to_pause_event(form) - return - - selected_action_id = form.selected_action_id - if selected_action_id is None: - raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}") - submitted_data = form.submitted_data or {} - outputs: dict[str, Any] = dict(submitted_data) - outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id - rendered_content = self.render_form_content_with_outputs( - form.rendered_content, - outputs, - self._node_data.outputs_field_names(), - ) - outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content - - action_text = self._node_data.find_action_text(selected_action_id) - - yield HumanInputFormFilledEvent( - node_title=self._node_data.title, - rendered_content=rendered_content, - action_id=selected_action_id, - action_text=action_text, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - edge_source_handle=selected_action_id, - ) - ) - - def render_form_content_before_submission(self) -> str: - """ - Process form content by substituting variables. - - This method should: - 1. Parse the form_content markdown - 2. Substitute {{#node_name.var_name#}} with actual values - 3. 
Keep {{#$output.field_name#}} placeholders for form inputs - """ - rendered_form_content = self.graph_runtime_state.variable_pool.convert_template( - self._node_data.form_content, - ) - return rendered_form_content.markdown - - @staticmethod - def render_form_content_with_outputs( - form_content: str, - outputs: Mapping[str, Any], - field_names: Sequence[str], - ) -> str: - """ - Replace {{#$output.xxx#}} placeholders with submitted values. - """ - rendered_content = form_content - for field_name in field_names: - placeholder = "{{#$output." + field_name + "#}}" - value = outputs.get(field_name) - if value is None: - replacement = "" - elif isinstance(value, (dict, list)): - replacement = json.dumps(value, ensure_ascii=False) - else: - replacement = str(value) - rendered_content = rendered_content.replace(placeholder, replacement) - return rendered_content - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: HumanInputNodeData, - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selectors referenced in form content and input default values. - - This method should parse: - 1. Variables referenced in form_content ({{#node_name.var_name#}}) - 2. Variables referenced in input default values - """ - return node_data.extract_variable_selector_to_variable_mapping(node_id) diff --git a/api/graphon/nodes/if_else/__init__.py b/api/graphon/nodes/if_else/__init__.py deleted file mode 100644 index afa0e8112c5..00000000000 --- a/api/graphon/nodes/if_else/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .if_else_node import IfElseNode - -__all__ = ["IfElseNode"] diff --git a/api/graphon/nodes/if_else/entities.py b/api/graphon/nodes/if_else/entities.py deleted file mode 100644 index d59b782747d..00000000000 --- a/api/graphon/nodes/if_else/entities.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.utils.condition.entities import Condition - - -class IfElseNodeData(BaseNodeData): - """ - If Else Node Data. 
- """ - - type: NodeType = BuiltinNodeTypes.IF_ELSE - - class Case(BaseModel): - """ - Case entity representing a single logical condition group - """ - - case_id: str - logical_operator: Literal["and", "or"] - conditions: list[Condition] - - logical_operator: Literal["and", "or"] | None = "and" - conditions: list[Condition] | None = Field(default=None, deprecated=True) - - cases: list[Case] | None = None diff --git a/api/graphon/nodes/if_else/if_else_node.py b/api/graphon/nodes/if_else/if_else_node.py deleted file mode 100644 index 81e934971ab..00000000000 --- a/api/graphon/nodes/if_else/if_else_node.py +++ /dev/null @@ -1,124 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Literal - -from typing_extensions import deprecated - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.if_else.entities import IfElseNodeData -from graphon.runtime import VariablePool -from graphon.utils.condition.entities import Condition -from graphon.utils.condition.processor import ConditionProcessor - - -class IfElseNode(Node[IfElseNodeData]): - node_type = BuiltinNodeTypes.IF_ELSE - execution_type = NodeExecutionType.BRANCH - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run node - :return: - """ - node_inputs: dict[str, Sequence[Mapping[str, Any]]] = {"conditions": []} - - process_data: dict[str, list] = {"condition_results": []} - - input_conditions: Sequence[Mapping[str, Any]] = [] - final_result = False - selected_case_id = "false" - condition_processor = ConditionProcessor() - try: - # Check if the new cases structure is used - if self.node_data.cases: - for case in self.node_data.cases: - input_conditions, group_result, final_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=case.conditions, - operator=case.logical_operator, - ) - - process_data["condition_results"].append( - { - "group": case.model_dump(), - "results": group_result, - "final_result": final_result, - } - ) - - # Break if a case passes (logical short-circuit) - if final_result: - selected_case_id = case.case_id # Capture the ID of the passing case - break - - else: - # TODO: Remove this once all graph definitions use the `cases` structure. - # Fallback to the legacy node shape when `cases` are not defined. 
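-                # Legacy shape: a single top-level `logical_operator` plus a flat
-                # `conditions` list. The `cases` shape instead wraps each group as
-                # {case_id, logical_operator, conditions}, allowing multiple
-                # ELIF-style branches, each selectable as an edge source handle.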
- input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated] - condition_processor=condition_processor, - variable_pool=self.graph_runtime_state.variable_pool, - conditions=self.node_data.conditions or [], - operator=self.node_data.logical_operator or "and", - ) - - selected_case_id = "true" if final_result else "false" - - process_data["condition_results"].append( - {"group": "default", "results": group_result, "final_result": final_result} - ) - - node_inputs["conditions"] = input_conditions - - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_data, error=str(e) - ) - - outputs = {"result": final_result, "selected_case_id": selected_case_id} - - data = NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - edge_source_handle=selected_case_id or "false", # Use case ID or 'default' - outputs=outputs, - ) - - return data - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IfElseNodeData, - ) -> Mapping[str, Sequence[str]]: - var_mapping: dict[str, list[str]] = {} - _ = graph_config # Explicitly mark as unused - for case in node_data.cases or []: - for condition in case.conditions: - key = f"{node_id}.#{'.'.join(condition.variable_selector)}#" - var_mapping[key] = condition.variable_selector - - return var_mapping - - -@deprecated("This function is deprecated. You should use the new cases structure.") -def _should_not_use_old_function( - *, - condition_processor: ConditionProcessor, - variable_pool: VariablePool, - conditions: list[Condition], - operator: Literal["and", "or"], -): - return condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=conditions, - operator=operator, - ) diff --git a/api/graphon/nodes/iteration/__init__.py b/api/graphon/nodes/iteration/__init__.py deleted file mode 100644 index 5bb87aaffa9..00000000000 --- a/api/graphon/nodes/iteration/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .entities import IterationNodeData -from .iteration_node import IterationNode -from .iteration_start_node import IterationStartNode - -__all__ = ["IterationNode", "IterationNodeData", "IterationStartNode"] diff --git a/api/graphon/nodes/iteration/entities.py b/api/graphon/nodes/iteration/entities.py deleted file mode 100644 index 30b6e4bea8f..00000000000 --- a/api/graphon/nodes/iteration/entities.py +++ /dev/null @@ -1,67 +0,0 @@ -from enum import StrEnum -from typing import Any - -from pydantic import Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base import BaseIterationNodeData, BaseIterationState - - -class ErrorHandleMode(StrEnum): - TERMINATED = "terminated" - CONTINUE_ON_ERROR = "continue-on-error" - REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" - - -class IterationNodeData(BaseIterationNodeData): - """ - Iteration Node Data. 
- """ - - type: NodeType = BuiltinNodeTypes.ITERATION - parent_loop_id: str | None = None # redundant field, not used currently - iterator_selector: list[str] # variable selector - output_selector: list[str] # output selector - is_parallel: bool = False # open the parallel mode or not - parallel_nums: int = 10 # the numbers of parallel - error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error - flatten_output: bool = True # whether to flatten the output array if all elements are lists - - -class IterationStartNodeData(BaseNodeData): - """ - Iteration Start Node Data. - """ - - type: NodeType = BuiltinNodeTypes.ITERATION_START - - -class IterationState(BaseIterationState): - """ - Iteration State. - """ - - outputs: list[Any] = Field(default_factory=list) - current_output: Any = None - - class MetaData(BaseIterationState.MetaData): - """ - Data. - """ - - iterator_length: int - - def get_last_output(self) -> Any: - """ - Get last output. - """ - if self.outputs: - return self.outputs[-1] - return None - - def get_current_output(self) -> Any: - """ - Get current output. - """ - return self.current_output diff --git a/api/graphon/nodes/iteration/exc.py b/api/graphon/nodes/iteration/exc.py deleted file mode 100644 index 7b6af61b9db..00000000000 --- a/api/graphon/nodes/iteration/exc.py +++ /dev/null @@ -1,26 +0,0 @@ -class IterationNodeError(ValueError): - """Base class for iteration node errors.""" - - -class IteratorVariableNotFoundError(IterationNodeError): - """Raised when the iterator variable is not found.""" - - -class InvalidIteratorValueError(IterationNodeError): - """Raised when the iterator value is invalid.""" - - -class StartNodeIdNotFoundError(IterationNodeError): - """Raised when the start node ID is not found.""" - - -class IterationGraphNotFoundError(IterationNodeError): - """Raised when the iteration graph is not found.""" - - -class IterationIndexNotFoundError(IterationNodeError): - """Raised when the iteration index is not found.""" - - -class ChildGraphAbortedError(IterationNodeError): - """Raised when a child graph aborts and the container must stop immediately.""" diff --git a/api/graphon/nodes/iteration/iteration_node.py b/api/graphon/nodes/iteration/iteration_node.py deleted file mode 100644 index c0137396536..00000000000 --- a/api/graphon/nodes/iteration/iteration_node.py +++ /dev/null @@ -1,686 +0,0 @@ -import logging -from collections.abc import Generator, Mapping, Sequence -from concurrent.futures import Future, ThreadPoolExecutor, as_completed -from contextlib import suppress -from datetime import UTC, datetime -from threading import Lock -from typing import TYPE_CHECKING, Any, NewType, cast - -from typing_extensions import TypeIs - -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.graph_events import ( - GraphNodeEventBase, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunSucceededEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import ( - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, - NodeEventBase, - NodeRunResult, - StreamCompletedEvent, -) -from graphon.nodes.base import LLMUsageTrackingMixin -from graphon.nodes.base.node import Node -from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from 
graphon.runtime import VariablePool -from graphon.variables import IntegerVariable, NoneSegment -from graphon.variables.segments import ArrayAnySegment, ArraySegment - -from .exc import ( - ChildGraphAbortedError, - InvalidIteratorValueError, - IterationGraphNotFoundError, - IterationIndexNotFoundError, - IterationNodeError, - IteratorVariableNotFoundError, - StartNodeIdNotFoundError, -) - -if TYPE_CHECKING: - from graphon.graph_engine import GraphEngine - -logger = logging.getLogger(__name__) -_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" - -EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) - - -class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): - """ - Iteration Node. - """ - - node_type = BuiltinNodeTypes.ITERATION - execution_type = NodeExecutionType.CONTAINER - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "type": "iteration", - "config": { - "is_parallel": False, - "parallel_nums": 10, - "error_handle_mode": ErrorHandleMode.TERMINATED, - "flatten_output": True, - }, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # type: ignore - variable = self._get_iterator_variable() - - if self._is_empty_iteration(variable): - yield from self._handle_empty_iteration(variable) - return - - iterator_list_value = self._validate_and_get_iterator_list(variable) - inputs = {"iterator_selector": iterator_list_value} - - self._validate_start_node() - - started_at = datetime.now(UTC).replace(tzinfo=None) - iter_run_map: dict[str, float] = {} - outputs: list[object] = [] - usage_accumulator = [LLMUsage.empty_usage()] - - yield IterationStartedEvent( - start_at=started_at, - inputs=inputs, - metadata={"iteration_length": len(iterator_list_value)}, - ) - - try: - yield from self._execute_iterations( - iterator_list_value=iterator_list_value, - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - - self._accumulate_usage(usage_accumulator[0]) - yield from self._handle_iteration_success( - started_at=started_at, - inputs=inputs, - outputs=outputs, - iterator_list_value=iterator_list_value, - iter_run_map=iter_run_map, - usage=usage_accumulator[0], - ) - except IterationNodeError as e: - self._accumulate_usage(usage_accumulator[0]) - yield from self._handle_iteration_failure( - started_at=started_at, - inputs=inputs, - outputs=outputs, - iterator_list_value=iterator_list_value, - iter_run_map=iter_run_map, - usage=usage_accumulator[0], - error=e, - ) - - def _get_iterator_variable(self) -> ArraySegment | NoneSegment: - variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) - - if not variable: - raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") - - if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment): - raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") - - return variable - - def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]: - return isinstance(variable, NoneSegment) or len(variable.value) == 0 - - def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]: - # Try our best to preserve the type information. 
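-        # For example, an empty typed ArraySegment subclass is copied with
-        # value=[] so the concrete segment type survives; only a NoneSegment
-        # input falls back to the untyped ArrayAnySegment.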
- if isinstance(variable, ArraySegment): - output = variable.model_copy(update={"value": []}) - else: - output = ArrayAnySegment(value=[]) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - # TODO(QuantumGhost): is it possible to compute the type of `output` - # from graph definition? - outputs={"output": output}, - ) - ) - - def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]: - iterator_list_value = variable.to_object() - - if not isinstance(iterator_list_value, list): - raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") - - return cast(list[object], iterator_list_value) - - def _validate_start_node(self) -> None: - if not self.node_data.start_node_id: - raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") - - def _execute_iterations( - self, - iterator_list_value: Sequence[object], - outputs: list[object], - iter_run_map: dict[str, float], - usage_accumulator: list[LLMUsage], - ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: - if self.node_data.is_parallel: - # Parallel mode execution - yield from self._execute_parallel_iterations( - iterator_list_value=iterator_list_value, - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - else: - # Sequential mode execution - for index, item in enumerate(iterator_list_value): - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - yield IterationNextEvent(index=index) - - graph_engine = self._create_graph_engine(index, item) - - # Run the iteration - try: - yield from self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs, - graph_engine=graph_engine, - ) - finally: - self._merge_graph_engine_usage(usage_accumulator=usage_accumulator, graph_engine=graph_engine) - iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - - def _execute_parallel_iterations( - self, - iterator_list_value: Sequence[object], - outputs: list[object], - iter_run_map: dict[str, float], - usage_accumulator: list[LLMUsage], - ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: - # Initialize outputs list with None values to maintain order - outputs.extend([None] * len(iterator_list_value)) - - # Determine the number of parallel workers - max_workers = min(self.node_data.parallel_nums, len(iterator_list_value)) - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - # Submit all iteration tasks - started_child_engines: dict[int, GraphEngine] = {} - started_child_engines_lock = Lock() - merged_usage_indexes: set[int] = set() - future_to_index: dict[ - Future[ - tuple[ - float, - list[GraphNodeEventBase], - object | None, - LLMUsage, - ] - ], - int, - ] = {} - for index, item in enumerate(iterator_list_value): - yield IterationNextEvent(index=index) - future = executor.submit( - self._execute_tracked_iteration_parallel, - index=index, - item=item, - started_child_engines=started_child_engines, - started_child_engines_lock=started_child_engines_lock, - ) - future_to_index[future] = index - - # Process completed iterations as they finish - for future in as_completed(future_to_index): - index = future_to_index[future] - try: - result = future.result() - ( - iteration_duration, - events, - output_value, - iteration_usage, - ) = result - - # Update outputs at the correct index - outputs[index] = output_value - - # Yield 
all events from this iteration - yield from events - - # The worker computes duration before we replay buffered events here, - # so slow downstream consumers don't inflate per-iteration timing. - iter_run_map[str(index)] = iteration_duration - - usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) - merged_usage_indexes.add(index) - - except Exception as e: - if index not in merged_usage_indexes: - self._merge_graph_engine_usage( - usage_accumulator=usage_accumulator, - graph_engine=started_child_engines.get(index), - ) - merged_usage_indexes.add(index) - if isinstance(e, ChildGraphAbortedError): - self._abort_parallel_siblings( - future_to_index=future_to_index, - current_future=future, - started_child_engines=started_child_engines, - reason=str(e) or _DEFAULT_CHILD_ABORT_REASON, - ) - self._drain_parallel_siblings( - future_to_index=future_to_index, - current_future=future, - started_child_engines=started_child_engines, - usage_accumulator=usage_accumulator, - merged_usage_indexes=merged_usage_indexes, - ) - raise e - - # Handle errors based on error_handle_mode - match self.node_data.error_handle_mode: - case ErrorHandleMode.TERMINATED: - # Cancel remaining futures and re-raise - for f in future_to_index: - if f != future: - f.cancel() - raise IterationNodeError(str(e)) - case ErrorHandleMode.CONTINUE_ON_ERROR: - outputs[index] = None - case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - outputs[index] = None # Will be filtered later - - # Remove None values if in REMOVE_ABNORMAL_OUTPUT mode - if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - outputs[:] = [output for output in outputs if output is not None] - - @staticmethod - def _merge_graph_engine_usage( - *, - usage_accumulator: list[LLMUsage], - graph_engine: "GraphEngine | None", - ) -> None: - if graph_engine is None: - return - usage_accumulator[0] = IterationNode._merge_usage( - usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage - ) - - def _abort_parallel_siblings( - self, - *, - future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], - current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], - started_child_engines: Mapping[int, "GraphEngine"], - reason: str, - ) -> None: - for future, index in future_to_index.items(): - if future == current_future: - continue - - graph_engine = started_child_engines.get(index) - if graph_engine is not None: - graph_engine.request_abort(reason) - - future.cancel() - - def _drain_parallel_siblings( - self, - *, - future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], - current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], - started_child_engines: Mapping[int, "GraphEngine"], - usage_accumulator: list[LLMUsage], - merged_usage_indexes: set[int], - ) -> None: - for future, index in future_to_index.items(): - if future == current_future: - continue - if future.cancelled(): - continue - - with suppress(Exception): - future.result() - - if index in merged_usage_indexes: - continue - - self._merge_graph_engine_usage( - usage_accumulator=usage_accumulator, - graph_engine=started_child_engines.get(index), - ) - merged_usage_indexes.add(index) - - def _execute_tracked_iteration_parallel( - self, - *, - index: int, - item: object, - started_child_engines: dict[int, "GraphEngine"], - started_child_engines_lock: Lock, - ) -> tuple[float, list[GraphNodeEventBase], object | None, 
LLMUsage]: - graph_engine = self._create_graph_engine(index, item) - with started_child_engines_lock: - started_child_engines[index] = graph_engine - - return self._execute_parallel_iteration_with_graph_engine( - index=index, - graph_engine=graph_engine, - ) - - def _execute_single_iteration_parallel( - self, - index: int, - item: object, - ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: - """Execute a single iteration in parallel mode and return results.""" - graph_engine = self._create_graph_engine(index, item) - return self._execute_parallel_iteration_with_graph_engine(index=index, graph_engine=graph_engine) - - def _execute_parallel_iteration_with_graph_engine( - self, - *, - index: int, - graph_engine: "GraphEngine", - ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: - """Execute a prepared child engine in parallel mode and return results.""" - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - events: list[GraphNodeEventBase] = [] - outputs_temp: list[object] = [] - - # Collect events instead of yielding them directly - for event in self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs_temp, - graph_engine=graph_engine, - ): - events.append(event) - - # Get the output value from the temporary outputs list - output_value = outputs_temp[0] if outputs_temp else None - iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - - return ( - iteration_duration, - events, - output_value, - graph_engine.graph_runtime_state.llm_usage, - ) - - def _handle_iteration_success( - self, - started_at: datetime, - inputs: dict[str, Sequence[object]], - outputs: list[object], - iterator_list_value: Sequence[object], - iter_run_map: dict[str, float], - *, - usage: LLMUsage, - ) -> Generator[NodeEventBase, None, None]: - # Flatten the list of lists if all outputs are lists - flattened_outputs = self._flatten_outputs_if_needed(outputs) - - yield IterationSucceededEvent( - start_at=started_at, - inputs=inputs, - outputs={"output": flattened_outputs}, - steps=len(iterator_list_value), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, - }, - ) - - # Yield final success event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": flattened_outputs}, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - - def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]: - """ - Flatten the outputs list if all elements are lists. - This maintains backward compatibility with version 1.8.1 behavior. - - If flatten_output is False, returns outputs as-is (nested structure). - If flatten_output is True (default), flattens the list if all elements are lists. 
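-
-        For example (hypothetical values): with flatten_output enabled,
-        [[1, 2], [3], None] becomes [1, 2, 3], while a mixed list such as
-        [[1], "x"] is returned unchanged because not every element is a list.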
- """ - # If flatten_output is disabled, return outputs as-is - if not self.node_data.flatten_output: - return outputs - - if not outputs: - return outputs - - # Check if all non-None outputs are lists - non_none_outputs: list[object] = [output for output in outputs if output is not None] - if not non_none_outputs: - return outputs - - if all(isinstance(output, list) for output in non_none_outputs): - # Flatten the list of lists - flattened: list[Any] = [] - for output in outputs: - if isinstance(output, list): - flattened.extend(output) - elif output is not None: - # This shouldn't happen based on our check, but handle it gracefully - flattened.append(output) - return flattened - - return outputs - - def _handle_iteration_failure( - self, - started_at: datetime, - inputs: dict[str, Sequence[object]], - outputs: list[object], - iterator_list_value: Sequence[object], - iter_run_map: dict[str, float], - *, - usage: LLMUsage, - error: IterationNodeError, - ) -> Generator[NodeEventBase, None, None]: - # Flatten the list of lists if all outputs are lists (even in failure case) - flattened_outputs = self._flatten_outputs_if_needed(outputs) - - yield IterationFailedEvent( - start_at=started_at, - inputs=inputs, - outputs={"output": flattened_outputs}, - steps=len(iterator_list_value), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, - }, - error=str(error), - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(error), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IterationNodeData, - ) -> Mapping[str, Sequence[str]]: - variable_mapping: dict[str, Sequence[str]] = { - f"{node_id}.input_selector": node_data.iterator_selector, - } - iteration_node_ids = set() - - # Find all nodes that belong to this loop - nodes = graph_config.get("nodes", []) - for node in nodes: - node_config_data = node.get("data", {}) - if node_config_data.get("iteration_id") == node_id: - in_iteration_node_id = node.get("id") - if in_iteration_node_id: - iteration_node_ids.add(in_iteration_node_id) - - # Get node configs from graph_config instead of non-existent node_id_config_mapping - node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} - for sub_node_id, sub_node_config in node_configs.items(): - if sub_node_config.get("data", {}).get("iteration_id") != node_id: - continue - - # variable selector to variable mapping - try: - typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) - node_type = typed_sub_node_config["data"].type - node_mapping = Node.get_node_type_classes_mapping() - if node_type not in node_mapping: - continue - node_version = str(typed_sub_node_config["data"].version) - node_cls = node_mapping[node_type][node_version] - - sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=typed_sub_node_config - ) - sub_node_variable_mapping = cast(dict[str, 
Sequence[str]], sub_node_variable_mapping) - except NotImplementedError: - sub_node_variable_mapping = {} - - # remove iteration variables - sub_node_variable_mapping = { - sub_node_id + "." + key: value - for key, value in sub_node_variable_mapping.items() - if value[0] != node_id - } - - variable_mapping.update(sub_node_variable_mapping) - - # remove variable out from iteration - variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids} - - return variable_mapping - - def _append_iteration_info_to_event( - self, - event: GraphNodeEventBase, - iter_run_index: int, - ): - event.in_iteration_id = self._node_id - iter_metadata = { - WorkflowNodeExecutionMetadataKey.ITERATION_ID: self._node_id, - WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index, - } - - current_metadata = event.node_run_result.metadata - if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata: - event.node_run_result.metadata = {**current_metadata, **iter_metadata} - - def _run_single_iter( - self, - *, - variable_pool: VariablePool, - outputs: list[object], - graph_engine: "GraphEngine", - ) -> Generator[GraphNodeEventBase, None, None]: - rst = graph_engine.run() - # get current iteration index - index_variable = variable_pool.get([self._node_id, "index"]) - if not isinstance(index_variable, IntegerVariable): - raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found") - current_index = index_variable.value - for event in rst: - if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.ITERATION_START: - continue - - if isinstance(event, GraphNodeEventBase): - self._append_iteration_info_to_event(event=event, iter_run_index=current_index) - yield event - elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)): - result = variable_pool.get(self.node_data.output_selector) - if result is None: - outputs.append(None) - else: - outputs.append(result.to_object()) - return - elif isinstance(event, GraphRunAbortedEvent): - raise ChildGraphAbortedError(event.reason or _DEFAULT_CHILD_ABORT_REASON) - elif isinstance(event, GraphRunFailedEvent): - match self.node_data.error_handle_mode: - case ErrorHandleMode.TERMINATED: - raise IterationNodeError(event.error) - case ErrorHandleMode.CONTINUE_ON_ERROR: - outputs.append(None) - return - case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - return - - def _create_graph_engine(self, index: int, item: object): - from graphon.entities import GraphInitParams - from graphon.runtime import ChildGraphNotFoundError - - # Create GraphInitParams for child graph execution. 
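-        # Each iteration gets its own child engine plus a deep copy of the
-        # variable pool, so writes made inside one iteration (including the
-        # injected `index` and `item` values) cannot leak into sibling runs.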
- graph_init_params = GraphInitParams( - workflow_id=self.workflow_id, - graph_config=self.graph_config, - run_context=self.run_context, - call_depth=self.workflow_call_depth, - ) - # Create a deep copy of the variable pool for each iteration - variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True) - - # append iteration variable (item, index) to variable pool - variable_pool_copy.add([self._node_id, "index"], index) - variable_pool_copy.add([self._node_id, "item"], item) - root_node_id = self.node_data.start_node_id - if root_node_id is None: - raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") - - try: - return self.graph_runtime_state.create_child_engine( - workflow_id=self.workflow_id, - graph_init_params=graph_init_params, - root_node_id=root_node_id, - variable_pool=variable_pool_copy, - ) - except ChildGraphNotFoundError as exc: - raise IterationGraphNotFoundError("iteration graph not found") from exc diff --git a/api/graphon/nodes/iteration/iteration_start_node.py b/api/graphon/nodes/iteration/iteration_start_node.py deleted file mode 100644 index 3a44d3d81d7..00000000000 --- a/api/graphon/nodes/iteration/iteration_start_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.iteration.entities import IterationStartNodeData - - -class IterationStartNode(Node[IterationStartNodeData]): - """ - Iteration Start Node. - """ - - node_type = BuiltinNodeTypes.ITERATION_START - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/graphon/nodes/list_operator/__init__.py b/api/graphon/nodes/list_operator/__init__.py deleted file mode 100644 index 1877586ef41..00000000000 --- a/api/graphon/nodes/list_operator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import ListOperatorNode - -__all__ = ["ListOperatorNode"] diff --git a/api/graphon/nodes/list_operator/entities.py b/api/graphon/nodes/list_operator/entities.py deleted file mode 100644 index 0db1c75cddc..00000000000 --- a/api/graphon/nodes/list_operator/entities.py +++ /dev/null @@ -1,71 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum - -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class FilterOperator(StrEnum): - # string conditions - CONTAINS = "contains" - START_WITH = "start with" - END_WITH = "end with" - IS = "is" - IN = "in" - EMPTY = "empty" - NOT_CONTAINS = "not contains" - IS_NOT = "is not" - NOT_IN = "not in" - NOT_EMPTY = "not empty" - # number conditions - EQUAL = "=" - NOT_EQUAL = "โ‰ " - LESS_THAN = "<" - GREATER_THAN = ">" - GREATER_THAN_OR_EQUAL = "โ‰ฅ" - LESS_THAN_OR_EQUAL = "โ‰ค" - - -class Order(StrEnum): - ASC = "asc" - DESC = "desc" - - -class FilterCondition(BaseModel): - key: str = "" - comparison_operator: FilterOperator = FilterOperator.CONTAINS - # the value is bool if the filter operator is comparing with - # a boolean constant. 
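-    # e.g. value=True with the "is" operator, value=["a", "b"] with "in",
-    # or a plain str for the other string comparisons.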
- value: str | Sequence[str] | bool = "" - - -class FilterBy(BaseModel): - enabled: bool = False - conditions: Sequence[FilterCondition] = Field(default_factory=list) - - -class OrderByConfig(BaseModel): - enabled: bool = False - key: str = "" - value: Order = Order.ASC - - -class Limit(BaseModel): - enabled: bool = False - size: int = -1 - - -class ExtractConfig(BaseModel): - enabled: bool = False - serial: str = "1" - - -class ListOperatorNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.LIST_OPERATOR - variable: Sequence[str] = Field(default_factory=list) - filter_by: FilterBy - order_by: OrderByConfig - limit: Limit - extract_by: ExtractConfig = Field(default_factory=ExtractConfig) diff --git a/api/graphon/nodes/list_operator/exc.py b/api/graphon/nodes/list_operator/exc.py deleted file mode 100644 index f88aa0be29c..00000000000 --- a/api/graphon/nodes/list_operator/exc.py +++ /dev/null @@ -1,16 +0,0 @@ -class ListOperatorError(ValueError): - """Base class for all ListOperator errors.""" - - pass - - -class InvalidFilterValueError(ListOperatorError): - pass - - -class InvalidKeyError(ListOperatorError): - pass - - -class InvalidConditionError(ListOperatorError): - pass diff --git a/api/graphon/nodes/list_operator/node.py b/api/graphon/nodes/list_operator/node.py deleted file mode 100644 index dad17a8f4a0..00000000000 --- a/api/graphon/nodes/list_operator/node.py +++ /dev/null @@ -1,345 +0,0 @@ -from collections.abc import Callable, Sequence -from typing import Any, TypeAlias, TypeVar - -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.file import File -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment -from graphon.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment - -from .entities import FilterOperator, ListOperatorNodeData, Order -from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError - -_SUPPORTED_TYPES_TUPLE = ( - ArrayFileSegment, - ArrayNumberSegment, - ArrayStringSegment, - ArrayBooleanSegment, -) -_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment - - -_T = TypeVar("_T") - - -def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]: - """Returns the negation of a given filter function. If the original filter - returns `True` for a value, the negated filter will return `False`, and vice versa. 
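-
-    For example, _negation(_contains("a"))("bc") returns True, because "bc"
-    does not contain "a".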
- """ - - def wrapper(value: _T) -> bool: - return not filter_(value) - - return wrapper - - -class ListOperatorNode(Node[ListOperatorNodeData]): - node_type = BuiltinNodeTypes.LIST_OPERATOR - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self): - inputs: dict[str, Sequence[object]] = {} - process_data: dict[str, Sequence[object]] = {} - outputs: dict[str, Any] = {} - - variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) - if variable is None: - error_message = f"Variable not found for selector: {self.node_data.variable}" - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs - ) - if not variable.value: - inputs = {"variable": []} - process_data = {"variable": []} - if isinstance(variable, ArraySegment): - result = variable.model_copy(update={"value": []}) - else: - result = ArrayAnySegment(value=[]) - outputs = {"result": result, "first_record": None, "last_record": None} - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - if not isinstance(variable, _SUPPORTED_TYPES_TUPLE): - error_message = f"Variable {self.node_data.variable} is not an array type, actual type: {type(variable)}" - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs - ) - - if isinstance(variable, ArrayFileSegment): - inputs = {"variable": [item.to_dict() for item in variable.value]} - process_data["variable"] = [item.to_dict() for item in variable.value] - else: - inputs = {"variable": variable.value} - process_data["variable"] = variable.value - - try: - # Filter - if self.node_data.filter_by.enabled: - variable = self._apply_filter(variable) - - # Extract - if self.node_data.extract_by.enabled: - variable = self._extract_slice(variable) - - # Order - if self.node_data.order_by.enabled: - variable = self._apply_order(variable) - - # Slice - if self.node_data.limit.enabled: - variable = self._apply_slice(variable) - - outputs = { - "result": variable, - "first_record": variable.value[0] if variable.value else None, - "last_record": variable.value[-1] if variable.value else None, - } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - except ListOperatorError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - - def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - filter_func: Callable[[Any], bool] - result: list[Any] = [] - for condition in self.node_data.filter_by.conditions: - if isinstance(variable, ArrayStringSegment): - if not isinstance(condition.value, str): - raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayNumberSegment): - if not isinstance(condition.value, str): - raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = 
_get_number_filter_func(condition=condition.comparison_operator, value=float(value)) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayFileSegment): - if isinstance(condition.value, str): - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - elif isinstance(condition.value, bool): - raise ValueError(f"File filter expects a string value, got {type(condition.value)}") - else: - value = condition.value - filter_func = _get_file_filter_func( - key=condition.key, - condition=condition.comparison_operator, - value=value, - ) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - else: - if not isinstance(condition.value, bool): - raise ValueError(f"Boolean filter expects a boolean value, got {type(condition.value)}") - filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - return variable - - def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)): - result = sorted(variable.value, reverse=self.node_data.order_by.value == Order.DESC) - variable = variable.model_copy(update={"value": result}) - else: - result = _order_file( - order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value - ) - variable = variable.model_copy(update={"value": result}) - - return variable - - def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - result = variable.value[: self.node_data.limit.size] - return variable.model_copy(update={"value": result}) - - def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) - if value < 1: - raise ValueError(f"Invalid serial index: must be >= 1, got {value}") - if value > len(variable.value): - raise InvalidKeyError(f"Invalid serial index: must be <= {len(variable.value)}, got {value}") - value -= 1 - result = variable.value[value] - return variable.model_copy(update={"value": [result]}) - - -def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: - match key: - case "size": - return lambda x: x.size - case _: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: - match key: - case "name": - return lambda x: x.filename or "" - case "type": - return lambda x: str(x.type) - case "extension": - return lambda x: x.extension or "" - case "mime_type": - return lambda x: x.mime_type or "" - case "transfer_method": - return lambda x: str(x.transfer_method) - case "url": - return lambda x: x.remote_url or "" - case "related_id": - return lambda x: x.related_id or "" - case _: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: - match condition: - case "contains": - return _contains(value) - case "start with": - return _startswith(value) - case "end with": - return _endswith(value) - case "is": - return _is(value) - case "in": - return _in(value) - case "empty": - return lambda x: x == "" - case "not contains": - return _negation(_contains(value)) - case 
"is not": - return _negation(_is(value)) - case "not in": - return _negation(_in(value)) - case "not empty": - return lambda x: x != "" - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]: - match condition: - case "in": - return _in(value) - case "not in": - return _negation(_in(value)) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]: - match condition: - case "=": - return _eq(value) - case "โ‰ ": - return _ne(value) - case "<": - return _lt(value) - case "โ‰ค": - return _le(value) - case ">": - return _gt(value) - case "โ‰ฅ": - return _ge(value) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]: - match condition: - case FilterOperator.IS: - return _is(value) - case FilterOperator.IS_NOT: - return _negation(_is(value)) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: - if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str): - extract_func = _get_file_extract_string_func(key=key) - return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) - if key in {"type", "transfer_method"}: - extract_func = _get_file_extract_string_func(key=key) - return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x)) - elif key == "size" and isinstance(value, str): - extract_number = _get_file_extract_number_func(key=key) - return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x)) - else: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _contains(value: str) -> Callable[[str], bool]: - return lambda x: value in x - - -def _startswith(value: str) -> Callable[[str], bool]: - return lambda x: x.startswith(value) - - -def _endswith(value: str) -> Callable[[str], bool]: - return lambda x: x.endswith(value) - - -def _is(value: _T) -> Callable[[_T], bool]: - return lambda x: x == value - - -def _in(value: str | Sequence[str]) -> Callable[[str], bool]: - return lambda x: x in value - - -def _eq(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x == value - - -def _ne(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x != value - - -def _lt(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x < value - - -def _le(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x <= value - - -def _gt(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x > value - - -def _ge(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x >= value - - -def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]): - extract_func: Callable[[File], Any] - if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url", "related_id"}: - extract_func = _get_file_extract_string_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC) - elif order_by == "size": - extract_func = _get_file_extract_number_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), 
reverse=order == Order.DESC)
-    else:
-        raise InvalidKeyError(f"Invalid order key: {order_by}")
diff --git a/api/graphon/nodes/llm/__init__.py b/api/graphon/nodes/llm/__init__.py
deleted file mode 100644
index f7bc713f631..00000000000
--- a/api/graphon/nodes/llm/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from .entities import (
-    LLMNodeChatModelMessage,
-    LLMNodeCompletionModelPromptTemplate,
-    LLMNodeData,
-    ModelConfig,
-    VisionConfig,
-)
-from .node import LLMNode
-
-__all__ = [
-    "LLMNode",
-    "LLMNodeChatModelMessage",
-    "LLMNodeCompletionModelPromptTemplate",
-    "LLMNodeData",
-    "ModelConfig",
-    "VisionConfig",
-]
diff --git a/api/graphon/nodes/llm/entities.py b/api/graphon/nodes/llm/entities.py
deleted file mode 100644
index 196152548c0..00000000000
--- a/api/graphon/nodes/llm/entities.py
+++ /dev/null
@@ -1,100 +0,0 @@
-from collections.abc import Mapping, Sequence
-from typing import Any, Literal
-
-from pydantic import BaseModel, Field, field_validator
-
-from graphon.entities.base_node_data import BaseNodeData
-from graphon.enums import BuiltinNodeTypes, NodeType
-from graphon.model_runtime.entities import ImagePromptMessageContent, LLMMode
-from graphon.nodes.base.entities import VariableSelector
-from graphon.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
-
-
-class ModelConfig(BaseModel):
-    provider: str
-    name: str
-    mode: LLMMode
-    completion_params: dict[str, Any] = Field(default_factory=dict)
-
-
-class ContextConfig(BaseModel):
-    enabled: bool
-    variable_selector: list[str] | None = None
-
-
-class VisionConfigOptions(BaseModel):
-    variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
-    detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH
-
-
-class VisionConfig(BaseModel):
-    enabled: bool = False
-    configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)
-
-    @field_validator("configs", mode="before")
-    @classmethod
-    def convert_none_configs(cls, v: Any):
-        if v is None:
-            return VisionConfigOptions()
-        return v
-
-
-class PromptConfig(BaseModel):
-    jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
-
-    @field_validator("jinja2_variables", mode="before")
-    @classmethod
-    def convert_none_jinja2_variables(cls, v: Any):
-        if v is None:
-            return []
-        return v
-
-
-class LLMNodeChatModelMessage(ChatModelMessage):
-    text: str = ""
-    jinja2_text: str | None = None
-
-
-class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
-    jinja2_text: str | None = None
-
-
-class LLMNodeData(BaseNodeData):
-    type: NodeType = BuiltinNodeTypes.LLM
-    model: ModelConfig
-    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
-    prompt_config: PromptConfig = Field(default_factory=PromptConfig)
-    memory: MemoryConfig | None = None
-    context: ContextConfig
-    vision: VisionConfig = Field(default_factory=VisionConfig)
-    structured_output: Mapping[str, Any] | None = None
-    # We used 'structured_output_enabled' in the past, but it's not a good name.
-    structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
-    reasoning_format: Literal["separated", "tagged"] = Field(
-        # Keep tagged as default for backward compatibility
-        default="tagged",
-        description=(
-            """
-            Strategy for handling model reasoning output.
-
-            - separated: Return clean text (without <think> tags) + reasoning_content field.
-              Recommended for new workflows. 
Enables safe downstream parsing and
-              workflow variable access: {{#node_id.reasoning_content#}}
-
-            - tagged: Return original text (with <think> tags) + reasoning_content field.
-              Maintains full backward compatibility while still providing reasoning_content
-              for workflow automation. Frontend thinking panels work as before.
-            """
-        ),
-    )
-
-    @field_validator("prompt_config", mode="before")
-    @classmethod
-    def convert_none_prompt_config(cls, v: Any):
-        if v is None:
-            return PromptConfig()
-        return v
-
-    @property
-    def structured_output_enabled(self) -> bool:
-        return self.structured_output_switch_on and self.structured_output is not None
diff --git a/api/graphon/nodes/llm/exc.py b/api/graphon/nodes/llm/exc.py
deleted file mode 100644
index 4d160952963..00000000000
--- a/api/graphon/nodes/llm/exc.py
+++ /dev/null
@@ -1,45 +0,0 @@
-class LLMNodeError(ValueError):
-    """Base class for LLM Node errors."""
-
-
-class VariableNotFoundError(LLMNodeError):
-    """Raised when a required variable is not found."""
-
-
-class InvalidContextStructureError(LLMNodeError):
-    """Raised when the context structure is invalid."""
-
-
-class InvalidVariableTypeError(LLMNodeError):
-    """Raised when the variable type is invalid."""
-
-
-class ModelNotExistError(LLMNodeError):
-    """Raised when the specified model does not exist."""
-
-
-class LLMModeRequiredError(LLMNodeError):
-    """Raised when LLM mode is required but not provided."""
-
-
-class NoPromptFoundError(LLMNodeError):
-    """Raised when no prompt is found in the LLM configuration."""
-
-
-class TemplateTypeNotSupportError(LLMNodeError):
-    def __init__(self, *, type_name: str):
-        super().__init__(f"Prompt type {type_name} is not supported.")
-
-
-class MemoryRolePrefixRequiredError(LLMNodeError):
-    """Raised when memory role prefix is required for completion model."""
-
-
-class FileTypeNotSupportError(LLMNodeError):
-    def __init__(self, *, type_name: str):
-        super().__init__(f"{type_name} type is not supported by this model")
-
-
-class UnsupportedPromptContentTypeError(LLMNodeError):
-    def __init__(self, *, type_name: str):
-        super().__init__(f"Prompt content type {type_name} is not supported.")
diff --git a/api/graphon/nodes/llm/file_saver.py b/api/graphon/nodes/llm/file_saver.py
deleted file mode 100644
index 0bedb42f3a5..00000000000
--- a/api/graphon/nodes/llm/file_saver.py
+++ /dev/null
@@ -1,139 +0,0 @@
-import mimetypes
-import typing as tp
-
-from graphon.file import File, FileTransferMethod, FileType
-from graphon.file.constants import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
-from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol
-
-
-class LLMFileSaver(tp.Protocol):
-    """LLMFileSaver is responsible for saving multimodal output returned by
-    the LLM.
-    """
-
-    def save_binary_string(
-        self,
-        data: bytes,
-        mime_type: str,
-        file_type: FileType,
-        extension_override: str | None = None,
-    ) -> File:
-        """save_binary_string saves the inline file data returned by the LLM.
-
-        Currently (2025-04-30), only some Google Gemini models will return
-        multimodal output as inline data.
-
-        :param data: the contents of the file
-        :param mime_type: the media type of the file, specified by rfc6838
-            (https://datatracker.ietf.org/doc/html/rfc6838)
-        :param file_type: The file type of the inline file.
-        :param extension_override: Override the auto-detected file extension while saving this file. 
-            - The default value is `None`, which means do not override the file extension; the
-              extension is guessed from the `mime_type` attribute while saving the file.
-            - Setting it to a value other than `None` overrides the file's extension and
-              bypasses extension guessing while saving the file.
-            - In particular, setting it to the empty string (`""`) leaves the file extension empty.
-            - When it is not `None` or the empty string (`""`), it should be a string beginning with a
-              dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
-              and `tar.gz` are not.
-        """
-        raise NotImplementedError()
-
-    def save_remote_url(self, url: str, file_type: FileType) -> File:
-        """save_remote_url saves the file from a remote url returned by the LLM.
-
-        Currently (2025-04-30), no model returns multimodal output as a url.
-
-        :param url: the url of the file.
-        :param file_type: the file type of the file, check the `FileType` enum for reference.
-        """
-        raise NotImplementedError()
-
-
-class FileSaverImpl(LLMFileSaver):
-    _tool_file_manager: ToolFileManagerProtocol
-    _file_reference_factory: FileReferenceFactoryProtocol
-
-    def __init__(
-        self,
-        *,
-        tool_file_manager: ToolFileManagerProtocol,
-        file_reference_factory: FileReferenceFactoryProtocol,
-        http_client: HttpClientProtocol,
-    ):
-        self._tool_file_manager = tool_file_manager
-        self._file_reference_factory = file_reference_factory
-        self._http_client = http_client
-
-    def save_remote_url(self, url: str, file_type: FileType) -> File:
-        http_response = self._http_client.get(url)
-        http_response.raise_for_status()
-        data = http_response.content
-        mime_type_from_header = http_response.headers.get("Content-Type")
-        mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
-        return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
-
-    def save_binary_string(
-        self,
-        data: bytes,
-        mime_type: str,
-        file_type: FileType,
-        extension_override: str | None = None,
-    ) -> File:
-        tool_file = self._tool_file_manager.create_file_by_raw(
-            file_binary=data,
-            mimetype=mime_type,
-        )
-        extension_override = _validate_extension_override(extension_override)
-        extension = _get_extension(mime_type, extension_override)
-        return self._file_reference_factory.build_from_mapping(
-            mapping={
-                "type": file_type,
-                "transfer_method": FileTransferMethod.TOOL_FILE,
-                "filename": tool_file.name,
-                "extension": extension,
-                "mime_type": mime_type,
-                "size": len(data),
-                "tool_file_id": str(tool_file.id),
-                "related_id": str(tool_file.id),
-                "storage_key": tool_file.file_key,
-            }
-        )
-
-
-def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
-    """get_extension returns the extension of a file.
-
-    If the `extension_override` parameter is set, this function should honor it and
-    return its value.
-    """
-    if extension_override is not None:
-        return extension_override
-    return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
-
-
-def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
-    """_extract_content_type_and_extension tries to
-    guess the content type and extension of a file from its url and the `Content-Type` response header. 
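-
-    For example, a response whose Content-Type header is "image/png" yields
-    ("image/png", ".png"); when the header is missing, the type is guessed
-    from the url and the extension from that guessed type.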
- """
-    if content_type_header:
-        extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
-        return content_type_header, extension
-    content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
-    extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
-    return content_type, extension
-
-
-def _validate_extension_override(extension_override: str | None) -> str | None:
-    # `extension_override` is allowed to be `None` or `""`.
-    if extension_override is None:
-        return None
-    if extension_override == "":
-        return ""
-    if not extension_override.startswith("."):
-        raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
-    return extension_override
diff --git a/api/graphon/nodes/llm/llm_utils.py b/api/graphon/nodes/llm/llm_utils.py
deleted file mode 100644
index 11a1d83a9d7..00000000000
--- a/api/graphon/nodes/llm/llm_utils.py
+++ /dev/null
@@ -1,545 +0,0 @@
-from __future__ import annotations
-
-import json
-import logging
-import re
-from collections.abc import Mapping, Sequence
-from typing import Any
-
-from graphon.file import FileType, file_manager
-from graphon.file.models import File
-from graphon.model_runtime.entities import (
-    ImagePromptMessageContent,
-    PromptMessage,
-    PromptMessageContentType,
-    PromptMessageRole,
-    TextPromptMessageContent,
-)
-from graphon.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    PromptMessageContentUnionTypes,
-    SystemPromptMessage,
-    UserPromptMessage,
-)
-from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
-from graphon.model_runtime.memory import PromptMessageMemory
-from graphon.nodes.base.entities import VariableSelector
-from graphon.runtime import VariablePool
-from graphon.template_rendering import Jinja2TemplateRenderer
-from graphon.variables import ArrayFileSegment, FileSegment
-from graphon.variables.segments import ArrayAnySegment, NoneSegment
-
-from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig
-from .exc import (
-    InvalidVariableTypeError,
-    MemoryRolePrefixRequiredError,
-    NoPromptFoundError,
-    TemplateTypeNotSupportError,
-)
-from .runtime_protocols import PreparedLLMProtocol
-
-CONTEXT_PLACEHOLDER = "{{#context#}}"
-
-logger = logging.getLogger(__name__)
-
-VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}")
-MAX_RESOLVED_VALUE_LENGTH = 1024
-
-
-def fetch_model_schema(*, model_instance: PreparedLLMProtocol) -> AIModelEntity:
-    model_schema = model_instance.get_model_schema()
-    if not model_schema:
-        raise ValueError(f"Model schema not found for {getattr(model_instance, 'model_name', 'unknown model')}")
-    return model_schema
-
-
-def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]:
-    variable = variable_pool.get(selector)
-    if variable is None:
-        return []
-    elif isinstance(variable, FileSegment):
-        return [variable.value]
-    elif isinstance(variable, ArrayFileSegment):
-        return variable.value
-    elif isinstance(variable, NoneSegment | ArrayAnySegment):
-        return []
-    raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
-
-
-def convert_history_messages_to_text(
-    *,
-    history_messages: Sequence[PromptMessage],
-    human_prefix: str,
-    ai_prefix: str,
-) -> str:
-    string_messages: list[str] = []
-    for message in history_messages:
-        if message.role == PromptMessageRole.USER:
-            role = human_prefix
-        elif message.role == PromptMessageRole.ASSISTANT:
-            role = ai_prefix
-        else: 
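-            # roles other than user/assistant (e.g. system or tool) are skipped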
- continue - - if isinstance(message.content, list): - content_parts = [] - for content in message.content: - if isinstance(content, TextPromptMessageContent): - content_parts.append(content.data) - elif isinstance(content, ImagePromptMessageContent): - content_parts.append("[image]") - - inner_msg = "\n".join(content_parts) - string_messages.append(f"{role}: {inner_msg}") - else: - string_messages.append(f"{role}: {message.content}") - - return "\n".join(string_messages) - - -def fetch_memory_text( - *, - memory: PromptMessageMemory, - max_token_limit: int, - message_limit: int | None = None, - human_prefix: str = "Human", - ai_prefix: str = "Assistant", -) -> str: - history_messages = memory.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=message_limit, - ) - return convert_history_messages_to_text( - history_messages=history_messages, - human_prefix=human_prefix, - ai_prefix=ai_prefix, - ) - - -def fetch_prompt_messages( - *, - sys_query: str | None = None, - sys_files: Sequence[File], - context: str = "", - memory: PromptMessageMemory | None = None, - model_instance: PreparedLLMProtocol, - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, - stop: Sequence[str] | None = None, - memory_config: MemoryConfig | None = None, - vision_enabled: bool = False, - vision_detail: ImagePromptMessageContent.DETAIL, - variable_pool: VariablePool, - jinja2_variables: Sequence[VariableSelector], - context_files: list[File] | None = None, - template_renderer: Jinja2TemplateRenderer | None = None, -) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - prompt_messages: list[PromptMessage] = [] - model_schema = fetch_model_schema(model_instance=model_instance) - - if isinstance(prompt_template, list): - prompt_messages.extend( - handle_list_messages( - messages=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail, - template_renderer=template_renderer, - ) - ) - - prompt_messages.extend( - handle_memory_chat_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - ) - - if sys_query: - prompt_messages.extend( - handle_list_messages( - messages=[ - LLMNodeChatModelMessage( - text=sys_query, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context="", - jinja2_variables=[], - variable_pool=variable_pool, - vision_detail_config=vision_detail, - template_renderer=template_renderer, - ) - ) - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - prompt_messages.extend( - handle_completion_template( - template=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - ) - - memory_text = handle_memory_completion_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - prompt_content = prompt_messages[0].content - if isinstance(prompt_content, str): - prompt_content = str(prompt_content) - if "#histories#" in prompt_content: - prompt_content = prompt_content.replace("#histories#", memory_text) - else: - prompt_content = memory_text + "\n" + prompt_content - prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - if "#histories#" in content_item.data: - content_item.data = content_item.data.replace("#histories#", memory_text) - 
else: - content_item.data = memory_text + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - - if sys_query: - if isinstance(prompt_content, str): - prompt_messages[0].content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - content_item.data = sys_query + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - else: - raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) - - _append_file_prompts( - prompt_messages=prompt_messages, - files=sys_files, - vision_enabled=vision_enabled, - vision_detail=vision_detail, - ) - _append_file_prompts( - prompt_messages=prompt_messages, - files=context_files or [], - vision_enabled=vision_enabled, - vision_detail=vision_detail, - ) - - filtered_prompt_messages: list[PromptMessage] = [] - for prompt_message in prompt_messages: - if isinstance(prompt_message.content, list): - prompt_message_content: list[PromptMessageContentUnionTypes] = [] - for content_item in prompt_message.content: - if not model_schema.features: - if content_item.type == PromptMessageContentType.TEXT: - prompt_message_content.append(content_item) - continue - - if ( - ( - content_item.type == PromptMessageContentType.IMAGE - and ModelFeature.VISION not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.DOCUMENT - and ModelFeature.DOCUMENT not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.VIDEO - and ModelFeature.VIDEO not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.AUDIO - and ModelFeature.AUDIO not in model_schema.features - ) - ): - continue - prompt_message_content.append(content_item) - if not prompt_message_content: - continue - if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: - prompt_message.content = prompt_message_content[0].data - else: - prompt_message.content = prompt_message_content - filtered_prompt_messages.append(prompt_message) - elif not prompt_message.is_empty(): - filtered_prompt_messages.append(prompt_message) - - if len(filtered_prompt_messages) == 0: - raise NoPromptFoundError( - "No prompt found in the LLM configuration. Please ensure a prompt is properly configured before proceeding." 
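-            # reached when every message was empty or filtered out, e.g. when all
-            # contents were media types the selected model does not support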
- ) - - return filtered_prompt_messages, stop - - -def handle_list_messages( - *, - messages: Sequence[LLMNodeChatModelMessage], - context: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: Jinja2TemplateRenderer | None = None, -) -> Sequence[PromptMessage]: - prompt_messages: list[PromptMessage] = [] - for message in messages: - if message.edition_type == "jinja2": - result_text = render_jinja2_message( - template=message.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - prompt_messages.append( - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], - role=message.role, - ) - ) - continue - - template = message.text.replace(CONTEXT_PLACEHOLDER, context) - segment_group = variable_pool.convert_template(template) - file_contents: list[PromptMessageContentUnionTypes] = [] - for segment in segment_group.value: - if isinstance(segment, ArrayFileSegment): - for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_contents.append( - file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config) - ) - elif isinstance(segment, FileSegment): - file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_contents.append( - file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config) - ) - - if segment_group.text: - prompt_messages.append( - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=segment_group.text)], - role=message.role, - ) - ) - if file_contents: - prompt_messages.append(combine_message_content_with_role(contents=file_contents, role=message.role)) - - return prompt_messages - - -def render_jinja2_message( - *, - template: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - template_renderer: Jinja2TemplateRenderer | None = None, -) -> str: - if not template: - return "" - if template_renderer is None: - raise ValueError("template_renderer is required for jinja2 prompt rendering") - - jinja2_inputs: dict[str, Any] = {} - for jinja2_variable in jinja2_variables: - variable = variable_pool.get(jinja2_variable.value_selector) - jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" - return template_renderer.render_template(template, jinja2_inputs) - - -def handle_completion_template( - *, - template: LLMNodeCompletionModelPromptTemplate, - context: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - template_renderer: Jinja2TemplateRenderer | None = None, -) -> Sequence[PromptMessage]: - if template.edition_type == "jinja2": - result_text = render_jinja2_message( - template=template.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - else: - template_text = template.text.replace(CONTEXT_PLACEHOLDER, context) - result_text = variable_pool.convert_template(template_text).text - return [ - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], - role=PromptMessageRole.USER, - ) - ] - - -def combine_message_content_with_role( - *, - contents: str | list[PromptMessageContentUnionTypes] | None = None, - role: PromptMessageRole, -) -> PromptMessage: - match role: - 
case PromptMessageRole.USER: - return UserPromptMessage(content=contents) - case PromptMessageRole.ASSISTANT: - return AssistantPromptMessage(content=contents) - case PromptMessageRole.SYSTEM: - return SystemPromptMessage(content=contents) - case _: - raise NotImplementedError(f"Role {role} is not supported") - - -def calculate_rest_token( - *, - prompt_messages: list[PromptMessage], - model_instance: PreparedLLMProtocol, -) -> int: - rest_tokens = 2000 - runtime_model_schema = fetch_model_schema(model_instance=model_instance) - runtime_model_parameters = model_instance.parameters - - model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in runtime_model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - runtime_model_parameters.get(parameter_rule.name) - or runtime_model_parameters.get(str(parameter_rule.use_template)) - or 0 - ) - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - -def handle_memory_chat_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: PreparedLLMProtocol, -) -> Sequence[PromptMessage]: - if not memory or not memory_config: - return [] - rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance) - return memory.get_history_prompt_messages( - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - ) - - -def handle_memory_completion_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: PreparedLLMProtocol, -) -> str: - if not memory or not memory_config: - return "" - - rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance) - if not memory_config.role_prefix: - raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") - - return fetch_memory_text( - memory=memory, - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - human_prefix=memory_config.role_prefix.user, - ai_prefix=memory_config.role_prefix.assistant, - ) - - -def _append_file_prompts( - *, - prompt_messages: list[PromptMessage], - files: Sequence[File], - vision_enabled: bool, - vision_detail: ImagePromptMessageContent.DETAIL, -) -> None: - if not vision_enabled or not files: - return - - file_prompts = [file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in files] - if ( - prompt_messages - and isinstance(prompt_messages[-1], UserPromptMessage) - and isinstance(prompt_messages[-1].content, list) - ): - existing_contents = prompt_messages[-1].content - assert isinstance(existing_contents, list) - prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents) - else: - prompt_messages.append(UserPromptMessage(content=file_prompts)) - - -def _coerce_resolved_value(raw: str) -> int | float | bool | str: - """Try to restore the original type from a resolved template string. - - Variable references are always resolved to text, but completion params may - expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to - the ``temperature`` parameter). 
This helper attempts a JSON parse so that
-    ``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not
-    valid JSON literals are returned as-is.
-    """
-    stripped = raw.strip()
-    if not stripped:
-        return raw
-
-    try:
-        parsed: object = json.loads(stripped)
-    except (json.JSONDecodeError, ValueError):
-        return raw
-
-    if isinstance(parsed, (int, float, bool)):
-        return parsed
-    return raw
-
-
-def resolve_completion_params_variables(
-    completion_params: Mapping[str, Any],
-    variable_pool: VariablePool,
-) -> dict[str, Any]:
-    """Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params.
-
-    Security notes:
-    - Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to
-      prevent denial-of-service through excessively large variable payloads.
-    - This follows the same ``VariablePool.convert_template`` pattern used across
-      Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream
-      model plugin receives these values as structured JSON key-value pairs; they
-      are never concatenated into raw HTTP headers or SQL queries.
-    - Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are
-      restored to their native type rather than sent as a bare string.
-    """
-    resolved: dict[str, Any] = {}
-    for key, value in completion_params.items():
-        if isinstance(value, str) and VARIABLE_PATTERN.search(value):
-            segment_group = variable_pool.convert_template(value)
-            text = segment_group.text
-            if len(text) > MAX_RESOLVED_VALUE_LENGTH:
-                logger.warning(
-                    "Resolved value for param '%s' truncated from %d to %d chars",
-                    key,
-                    len(text),
-                    MAX_RESOLVED_VALUE_LENGTH,
-                )
-                text = text[:MAX_RESOLVED_VALUE_LENGTH]
-            resolved[key] = _coerce_resolved_value(text)
-        else:
-            resolved[key] = value
-    return resolved
diff --git a/api/graphon/nodes/llm/node.py b/api/graphon/nodes/llm/node.py
deleted file mode 100644
index 4de2a954653..00000000000
--- a/api/graphon/nodes/llm/node.py
+++ /dev/null
@@ -1,1372 +0,0 @@
-from __future__ import annotations
-
-import base64
-import io
-import json
-import logging
-import re
-import time
-from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Literal, cast
-
-from graphon.entities import GraphInitParams
-from graphon.entities.graph_config import NodeConfigDict
-from graphon.enums import (
-    BuiltinNodeTypes,
-    NodeType,
-    WorkflowNodeExecutionMetadataKey,
-    WorkflowNodeExecutionStatus,
-)
-from graphon.file import File, FileType, file_manager
-from graphon.model_runtime.entities import (
-    ImagePromptMessageContent,
-    PromptMessage,
-    PromptMessageContentType,
-    TextPromptMessageContent,
-)
-from graphon.model_runtime.entities.llm_entities import (
-    LLMResult,
-    LLMResultChunk,
-    LLMResultChunkWithStructuredOutput,
-    LLMResultWithStructuredOutput,
-    LLMStructuredOutput,
-    LLMUsage,
-)
-from graphon.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    PromptMessageContentUnionTypes,
-    PromptMessageRole,
-    SystemPromptMessage,
-    UserPromptMessage,
-)
-from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
-from graphon.model_runtime.memory import PromptMessageMemory
-from graphon.model_runtime.utils.encoders import jsonable_encoder
-from graphon.node_events import (
-    ModelInvokeCompletedEvent,
-    NodeEventBase,
-    NodeRunResult,
-    RunRetrieverResourceEvent,
-    StreamChunkEvent,
-    StreamCompletedEvent,
-)
-from graphon.nodes.base.entities import VariableSelector
-from graphon.nodes.base.node import Node
-from graphon.nodes.base.variable_template_parser import VariableTemplateParser
-from graphon.nodes.llm.runtime_protocols import (
-    PreparedLLMProtocol,
-    PromptMessageSerializerProtocol,
-    RetrieverAttachmentLoaderProtocol,
-)
-from graphon.nodes.protocols import HttpClientProtocol
-from graphon.prompt_entities import CompletionModelPromptTemplate, MemoryConfig
-from graphon.runtime import VariablePool
-from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError
-from graphon.variables import (
-    ArrayFileSegment,
-    ArraySegment,
-    FileSegment,
-    NoneSegment,
-    ObjectSegment,
-    StringSegment,
-)
-
-from . import llm_utils
-from .entities import (
-    LLMNodeChatModelMessage,
-    LLMNodeCompletionModelPromptTemplate,
-    LLMNodeData,
-)
-from .exc import (
-    InvalidContextStructureError,
-    InvalidVariableTypeError,
-    LLMNodeError,
-    MemoryRolePrefixRequiredError,
-    NoPromptFoundError,
-    TemplateTypeNotSupportError,
-    VariableNotFoundError,
-)
-from .file_saver import LLMFileSaver
-
-if TYPE_CHECKING:
-    from graphon.file.models import File
-    from graphon.runtime import GraphRuntimeState
-
-logger = logging.getLogger(__name__)
-
-
-class LLMNode(Node[LLMNodeData]):
-    node_type = BuiltinNodeTypes.LLM
-
-    # Compiled regex for extracting <think> blocks (with compatibility for attributes)
-    _THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
-
-    # Instance attributes specific to LLMNode.
-    # Output variable for file
-    _file_outputs: list[File]
-
-    _llm_file_saver: LLMFileSaver
-    _retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None
-    _prompt_message_serializer: PromptMessageSerializerProtocol
-    _jinja2_template_renderer: Jinja2TemplateRenderer | None
-    _model_instance: PreparedLLMProtocol
-    _memory: PromptMessageMemory | None
-    _default_query_selector: tuple[str, ...] | None
-
-    def __init__(
-        self,
-        id: str,
-        config: NodeConfigDict,
-        graph_init_params: GraphInitParams,
-        graph_runtime_state: GraphRuntimeState,
-        *,
-        credentials_provider: object | None = None,
-        model_factory: object | None = None,
-        model_instance: PreparedLLMProtocol,
-        http_client: HttpClientProtocol,
-        memory: PromptMessageMemory | None = None,
-        llm_file_saver: LLMFileSaver,
-        prompt_message_serializer: PromptMessageSerializerProtocol,
-        retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None = None,
-        jinja2_template_renderer: Jinja2TemplateRenderer | None = None,
-        default_query_selector: Sequence[str] | None = None,
-    ):
-        super().__init__(
-            id=id,
-            config=config,
-            graph_init_params=graph_init_params,
-            graph_runtime_state=graph_runtime_state,
-        )
-        # LLM file outputs, used for MultiModal outputs.
- self._file_outputs = [] - - _ = credentials_provider, model_factory, http_client - self._model_instance = model_instance - self._memory = memory - - self._llm_file_saver = llm_file_saver - self._prompt_message_serializer = prompt_message_serializer - self._retriever_attachment_loader = retriever_attachment_loader - self._jinja2_template_renderer = jinja2_template_renderer - self._default_query_selector = tuple(default_query_selector) if default_query_selector is not None else None - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator: - node_inputs: dict[str, Any] = {} - process_data: dict[str, Any] = {} - result_text = "" - clean_text = "" - usage = LLMUsage.empty_usage() - finish_reason = None - reasoning_content = None - variable_pool = self.graph_runtime_state.variable_pool - - try: - # init messages template - self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) - - # fetch variables and fetch values from variable pool - inputs = self._fetch_inputs(node_data=self.node_data) - - # fetch jinja2 inputs - jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) - - # merge inputs - inputs.update(jinja_inputs) - - # fetch files - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=self.node_data.vision.configs.variable_selector, - ) - if self.node_data.vision.enabled - else [] - ) - - if files: - node_inputs["#files#"] = [file.to_dict() for file in files] - - # fetch context value - generator = self._fetch_context(node_data=self.node_data) - context = None - context_files: list[File] = [] - if generator is not None: - for event in generator: - context = event.context - context_files = event.context_files or [] - yield event - if context: - node_inputs["#context#"] = context - - if context_files: - node_inputs["#context_files#"] = [file.model_dump() for file in context_files] - - # fetch model config - model_instance = self._model_instance - # Resolve variable references in string-typed completion params - model_instance.parameters = llm_utils.resolve_completion_params_variables( - model_instance.parameters, variable_pool - ) - model_name = model_instance.model_name - model_provider = model_instance.provider - model_stop = model_instance.stop - - memory = self._memory - - query: str | None = None - if self.node_data.memory: - query = self.node_data.memory.query_prompt_template - if ( - not query - and self._default_query_selector - and (query_variable := variable_pool.get(self._default_query_selector)) - ): - query = query_variable.text - - prompt_messages, stop = LLMNode.fetch_prompt_messages( - sys_query=query, - sys_files=files, - context=context or "", - memory=memory, - model_instance=model_instance, - stop=model_stop, - prompt_template=self.node_data.prompt_template, - memory_config=self.node_data.memory, - vision_enabled=self.node_data.vision.enabled, - vision_detail=self.node_data.vision.configs.detail, - variable_pool=variable_pool, - jinja2_variables=self.node_data.prompt_config.jinja2_variables, - context_files=context_files, - jinja2_template_renderer=self._jinja2_template_renderer, - ) - - # handle invoke result - generator = LLMNode.invoke_llm( - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - structured_output_enabled=self.node_data.structured_output_enabled, - structured_output=self.node_data.structured_output, - file_saver=self._llm_file_saver, - file_outputs=self._file_outputs, - node_id=self._node_id, - node_type=self.node_type, 
-                reasoning_format=self.node_data.reasoning_format,
-            )
-
-            structured_output: LLMStructuredOutput | None = None
-
-            for event in generator:
-                if isinstance(event, StreamChunkEvent):
-                    yield event
-                elif isinstance(event, ModelInvokeCompletedEvent):
-                    # Raw text
-                    result_text = event.text
-                    usage = event.usage
-                    finish_reason = event.finish_reason
-                    reasoning_content = event.reasoning_content or ""
-
-                    # For downstream nodes, determine clean text based on reasoning_format
-                    if self.node_data.reasoning_format == "tagged":
-                        # Keep <think> tags for backward compatibility
-                        clean_text = result_text
-                    else:
-                        # Extract clean text from <think> tags
-                        clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format)
-
-                    # Process structured output if available from the event.
-                    structured_output = (
-                        LLMStructuredOutput(structured_output=event.structured_output)
-                        if event.structured_output
-                        else None
-                    )
-
-                    break
-                elif isinstance(event, LLMStructuredOutput):
-                    structured_output = event
-
-            process_data = {
-                "model_mode": self.node_data.model.mode,
-                "prompts": self._prompt_message_serializer.serialize(
-                    model_mode=self.node_data.model.mode, prompt_messages=prompt_messages
-                ),
-                "usage": jsonable_encoder(usage),
-                "finish_reason": finish_reason,
-                "model_provider": model_provider,
-                "model_name": model_name,
-            }
-
-            outputs = {
-                "text": clean_text,
-                "reasoning_content": reasoning_content,
-                "usage": jsonable_encoder(usage),
-                "finish_reason": finish_reason,
-            }
-            if structured_output:
-                outputs["structured_output"] = structured_output.structured_output
-            if self._file_outputs:
-                outputs["files"] = ArrayFileSegment(value=self._file_outputs)
-
-            # Send final chunk event to indicate streaming is complete
-            yield StreamChunkEvent(
-                selector=[self._node_id, "text"],
-                chunk="",
-                is_final=True,
-            )
-
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                    inputs=node_inputs,
-                    process_data=process_data,
-                    outputs=outputs,
-                    metadata={
-                        WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
-                        WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
-                        WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
-                    },
-                    llm_usage=usage,
-                )
-            )
-        except ValueError as e:
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.FAILED,
-                    error=str(e),
-                    inputs=node_inputs,
-                    process_data=process_data,
-                    error_type=type(e).__name__,
-                    llm_usage=usage,
-                )
-            )
-        except Exception as e:
-            logger.exception("error while executing llm node")
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.FAILED,
-                    error=str(e),
-                    inputs=node_inputs,
-                    process_data=process_data,
-                    error_type=type(e).__name__,
-                    llm_usage=usage,
-                )
-            )
-
-    @staticmethod
-    def invoke_llm(
-        *,
-        model_instance: PreparedLLMProtocol,
-        prompt_messages: Sequence[PromptMessage],
-        stop: Sequence[str] | None = None,
-        structured_output_enabled: bool,
-        structured_output: Mapping[str, Any] | None = None,
-        file_saver: LLMFileSaver,
-        file_outputs: list[File],
-        node_id: str,
-        node_type: NodeType,
-        reasoning_format: Literal["separated", "tagged"] = "tagged",
-    ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
-        model_parameters = model_instance.parameters
-        invoke_model_parameters = dict(model_parameters)
-        invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
-        if structured_output_enabled:
-            output_schema = LLMNode.fetch_structured_output_schema(
-                structured_output=structured_output or {},
-            )
-            request_start_time = time.perf_counter()
-
-            invoke_result = cast(
-                LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
-                model_instance.invoke_llm_with_structured_output(
-                    prompt_messages=prompt_messages,
-                    json_schema=output_schema,
-                    model_parameters=invoke_model_parameters,
-                    stop=stop,
-                    stream=True,
-                ),
-            )
-        else:
-            request_start_time = time.perf_counter()
-
-            invoke_result = cast(
-                LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
-                model_instance.invoke_llm(
-                    prompt_messages=prompt_messages,
-                    model_parameters=invoke_model_parameters,
-                    tools=None,
-                    stop=stop,
-                    stream=True,
-                ),
-            )
-
-        return LLMNode.handle_invoke_result(
-            invoke_result=invoke_result,
-            file_saver=file_saver,
-            file_outputs=file_outputs,
-            node_id=node_id,
-            node_type=node_type,
-            model_instance=model_instance,
-            reasoning_format=reasoning_format,
-            request_start_time=request_start_time,
-        )
-
-    @staticmethod
-    def handle_invoke_result(
-        *,
-        invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
-        file_saver: LLMFileSaver,
-        file_outputs: list[File],
-        node_id: str,
-        node_type: NodeType,
-        model_instance: PreparedLLMProtocol | object,
-        reasoning_format: Literal["separated", "tagged"] = "tagged",
-        request_start_time: float | None = None,
-    ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
-        # For blocking mode
-        if isinstance(invoke_result, LLMResult):
-            duration = None
-            if request_start_time is not None:
-                duration = time.perf_counter() - request_start_time
-                invoke_result.usage.latency = round(duration, 3)
-            event = LLMNode.handle_blocking_result(
-                invoke_result=invoke_result,
-                saver=file_saver,
-                file_outputs=file_outputs,
-                reasoning_format=reasoning_format,
-                request_latency=duration,
-            )
-            yield event
-            return
-
-        # For streaming mode
-        model = ""
-        prompt_messages: list[PromptMessage] = []
-
-        usage = LLMUsage.empty_usage()
-        finish_reason = None
-        full_text_buffer = io.StringIO()
-
-        # Initialize streaming metrics tracking
-        start_time = request_start_time if request_start_time is not None else time.perf_counter()
-        first_token_time = None
-        has_content = False
-
-        collected_structured_output = None  # Collect structured_output from streaming chunks
-        # Consume the invoke result and handle generator exception
-        try:
-            for result in invoke_result:
-                if isinstance(result, LLMResultChunkWithStructuredOutput):
-                    # Collect structured_output from the chunk
-                    if result.structured_output is not None:
-                        collected_structured_output = dict(result.structured_output)
-                    yield result
-                if isinstance(result, LLMResultChunk):
-                    contents = result.delta.message.content
-                    for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
-                        contents=contents,
-                        file_saver=file_saver,
-                        file_outputs=file_outputs,
-                    ):
-                        # Detect first token for TTFT calculation
-                        if text_part and not has_content:
-                            first_token_time = time.perf_counter()
-                            has_content = True
-
-                        full_text_buffer.write(text_part)
-                        yield StreamChunkEvent(
-                            selector=[node_id, "text"],
-                            chunk=text_part,
-                            is_final=False,
-                        )
-
-                    # Update the whole metadata
-                    if not model and result.model:
-                        model = result.model
-                    if len(prompt_messages) == 0:
-                        # TODO(QuantumGhost): it seems that this update has no visible effect.
-                        # What's the purpose of the line below?
-                        prompt_messages = list(result.prompt_messages)
-                    if usage.prompt_tokens == 0 and result.delta.usage:
-                        usage = result.delta.usage
-                    if finish_reason is None and result.delta.finish_reason:
-                        finish_reason = result.delta.finish_reason
-        except Exception as e:
-            if hasattr(model_instance, "is_structured_output_parse_error") and cast(
-                PreparedLLMProtocol, model_instance
-            ).is_structured_output_parse_error(e):
-                raise LLMNodeError(f"Failed to parse structured output: {e}") from e
-            if type(e).__name__ == "OutputParserError":
-                raise LLMNodeError(f"Failed to parse structured output: {e}") from e
-            raise
-
-        # Extract reasoning content from <think> tags in the main text
-        full_text = full_text_buffer.getvalue()
-
-        if reasoning_format == "tagged":
-            # Keep <think> tags in text for backward compatibility
-            clean_text = full_text
-            reasoning_content = ""
-        else:
-            # Extract clean text and reasoning from <think> tags
-            clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
-
-        # Calculate streaming metrics
-        end_time = time.perf_counter()
-        total_duration = end_time - start_time
-        usage.latency = round(total_duration, 3)
-        if has_content and first_token_time:
-            gen_ai_server_time_to_first_token = first_token_time - start_time
-            llm_streaming_time_to_generate = end_time - first_token_time
-            usage.time_to_first_token = round(gen_ai_server_time_to_first_token, 3)
-            usage.time_to_generate = round(llm_streaming_time_to_generate, 3)
-
-        yield ModelInvokeCompletedEvent(
-            # Use clean_text for separated mode, full_text for tagged mode
-            text=clean_text if reasoning_format == "separated" else full_text,
-            usage=usage,
-            finish_reason=finish_reason,
-            # Reasoning content for workflow variables and downstream nodes
-            reasoning_content=reasoning_content,
-            # Pass structured output if collected from streaming chunks
-            structured_output=collected_structured_output,
-        )
-
-    @staticmethod
-    def _image_file_to_markdown(file: File, /):
-        text_chunk = f"![]({file.generate_url()})"
-        return text_chunk
-
-    @classmethod
-    def _split_reasoning(
-        cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged"
-    ) -> tuple[str, str]:
-        """
-        Split reasoning content from text based on reasoning_format strategy.
-
-        Args:
-            text: Full text that may contain <think> blocks
-            reasoning_format: Strategy for handling reasoning content
-                - "separated": Remove <think> tags and return clean text + reasoning_content field
-                - "tagged": Keep <think> tags in text, return empty reasoning_content
-
-        Returns:
-            tuple of (clean_text, reasoning_content)
-        """
-
-        if reasoning_format == "tagged":
-            return text, ""
-
-        # Find all <think>...</think> blocks (case-insensitive)
-        matches = cls._THINK_PATTERN.findall(text)
-
-        # Extract reasoning content from all <think> blocks
-        reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
-
-        # Remove all <think>...</think> blocks from original text
-        clean_text = cls._THINK_PATTERN.sub("", text)
-
-        # Clean up extra whitespace
-        clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
-
-        # Separated mode: always return clean text and reasoning_content
-        return clean_text, reasoning_content or ""
-
-    def _transform_chat_messages(
-        self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
-    ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
-        if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
-            if messages.edition_type == "jinja2" and messages.jinja2_text:
-                messages.text = messages.jinja2_text
-
-            return messages
-
-        for message in messages:
-            if message.edition_type == "jinja2" and message.jinja2_text:
-                message.text = message.jinja2_text
-
-        return messages
-
-    def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
-        variables: dict[str, Any] = {}
-
-        if not node_data.prompt_config:
-            return variables
-
-        for variable_selector in node_data.prompt_config.jinja2_variables or []:
-            variable_name = variable_selector.variable
-            variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
-            if variable is None:
-                raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
-
-            def parse_dict(input_dict: Mapping[str, Any]) -> str:
-                """
-                Parse dict into string
-                """
-                # check if it's a context structure
-                if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
-                    return str(input_dict["content"])
-
-                # else, parse the dict
-                try:
-                    return json.dumps(input_dict, ensure_ascii=False)
-                except Exception:
-                    return str(input_dict)
-
-            if isinstance(variable, ArraySegment):
-                result = ""
-                for item in variable.value:
-                    if isinstance(item, dict):
-                        result += parse_dict(item)
-                    else:
-                        result += str(item)
-                    result += "\n"
-                value = result.strip()
-            elif isinstance(variable, ObjectSegment):
-                value = parse_dict(variable.value)
-            else:
-                value = variable.text
-
-            variables[variable_name] = value
-
-        return variables
-
-    def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
-        inputs = {}
-        prompt_template = node_data.prompt_template
-
-        variable_selectors = []
-        if isinstance(prompt_template, list):
-            for prompt in prompt_template:
-                variable_template_parser = VariableTemplateParser(template=prompt.text)
-                variable_selectors.extend(variable_template_parser.extract_variable_selectors())
-        elif isinstance(prompt_template, CompletionModelPromptTemplate):
-            variable_template_parser = VariableTemplateParser(template=prompt_template.text)
-            variable_selectors = variable_template_parser.extract_variable_selectors()
-
-        for variable_selector in variable_selectors:
-            variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
-            if variable is None:
-                raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
-            if isinstance(variable, NoneSegment):
-                inputs[variable_selector.variable] = ""
-            inputs[variable_selector.variable] = variable.to_object()
-
-        memory = node_data.memory
-        if memory and memory.query_prompt_template:
-            query_variable_selectors = VariableTemplateParser(
-                template=memory.query_prompt_template
-            ).extract_variable_selectors()
-            for variable_selector in query_variable_selectors:
-                variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
-                if variable is None:
-                    raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
-                if
isinstance(variable, NoneSegment): - continue - inputs[variable_selector.variable] = variable.to_object() - - return inputs - - def _fetch_context(self, node_data: LLMNodeData): - if not node_data.context.enabled: - return - - if not node_data.context.variable_selector: - return - - context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) - if context_value_variable: - if isinstance(context_value_variable, StringSegment): - yield RunRetrieverResourceEvent( - retriever_resources=[], context=context_value_variable.value, context_files=[] - ) - elif isinstance(context_value_variable, ArraySegment): - context_str = "" - original_retriever_resource: list[dict[str, Any]] = [] - context_files: list[File] = [] - for item in context_value_variable.value: - if isinstance(item, str): - context_str += item + "\n" - else: - if "content" not in item: - raise InvalidContextStructureError(f"Invalid context structure: {item}") - - if item.get("summary"): - context_str += item["summary"] + "\n" - context_str += item["content"] + "\n" - - retriever_resource = self._convert_to_original_retriever_resource(item) - if retriever_resource: - original_retriever_resource.append(retriever_resource) - segment_id = retriever_resource.get("segment_id") - if not segment_id: - continue - if self._retriever_attachment_loader is not None: - context_files.extend(self._retriever_attachment_loader.load(segment_id=segment_id)) - yield RunRetrieverResourceEvent( - retriever_resources=original_retriever_resource, - context=context_str.strip(), - context_files=context_files, - ) - - def _convert_to_original_retriever_resource(self, context_dict: dict) -> dict[str, Any] | None: - if ( - "metadata" in context_dict - and "_source" in context_dict["metadata"] - and context_dict["metadata"]["_source"] == "knowledge" - ): - metadata = context_dict.get("metadata", {}) - - return { - "position": metadata.get("position"), - "dataset_id": metadata.get("dataset_id"), - "dataset_name": metadata.get("dataset_name"), - "document_id": metadata.get("document_id"), - "document_name": metadata.get("document_name"), - "data_source_type": metadata.get("data_source_type"), - "segment_id": metadata.get("segment_id"), - "retriever_from": metadata.get("retriever_from"), - "score": metadata.get("score"), - "hit_count": metadata.get("segment_hit_count"), - "word_count": metadata.get("segment_word_count"), - "segment_position": metadata.get("segment_position"), - "index_node_hash": metadata.get("segment_index_node_hash"), - "content": context_dict.get("content"), - "page": metadata.get("page"), - "doc_metadata": metadata.get("doc_metadata"), - "files": context_dict.get("files"), - "summary": context_dict.get("summary"), - } - - return None - - @staticmethod - def fetch_prompt_messages( - *, - sys_query: str | None = None, - sys_files: Sequence[File], - context: str = "", - memory: PromptMessageMemory | None = None, - model_instance: PreparedLLMProtocol, - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, - stop: Sequence[str] | None = None, - memory_config: MemoryConfig | None = None, - vision_enabled: bool = False, - vision_detail: ImagePromptMessageContent.DETAIL, - variable_pool: VariablePool, - jinja2_variables: Sequence[VariableSelector], - context_files: list[File] | None = None, - jinja2_template_renderer: Jinja2TemplateRenderer | None = None, - ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - prompt_messages: list[PromptMessage] = [] - model_schema = 
llm_utils.fetch_model_schema(model_instance=model_instance) - - if isinstance(prompt_template, list): - # For chat model - prompt_messages.extend( - LLMNode.handle_list_messages( - messages=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail, - jinja2_template_renderer=jinja2_template_renderer, - ) - ) - - # Get memory messages for chat mode - memory_messages = _handle_memory_chat_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - # Extend prompt_messages with memory messages - prompt_messages.extend(memory_messages) - - # Add current query to the prompt messages - if sys_query: - message = LLMNodeChatModelMessage( - text=sys_query, - role=PromptMessageRole.USER, - edition_type="basic", - ) - prompt_messages.extend( - LLMNode.handle_list_messages( - messages=[message], - context="", - jinja2_variables=[], - variable_pool=variable_pool, - vision_detail_config=vision_detail, - jinja2_template_renderer=jinja2_template_renderer, - ) - ) - - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - # For completion model - prompt_messages.extend( - _handle_completion_template( - template=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - jinja2_template_renderer=jinja2_template_renderer, - ) - ) - - # Get memory text for completion model - memory_text = _handle_memory_completion_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - # Insert histories into the prompt - prompt_content = prompt_messages[0].content - # For issue #11247 - Check if prompt content is a string or a list - if isinstance(prompt_content, str): - prompt_content = str(prompt_content) - if "#histories#" in prompt_content: - prompt_content = prompt_content.replace("#histories#", memory_text) - else: - prompt_content = memory_text + "\n" + prompt_content - prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - if "#histories#" in content_item.data: - content_item.data = content_item.data.replace("#histories#", memory_text) - else: - content_item.data = memory_text + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - - # Add current query to the prompt message - if sys_query: - if isinstance(prompt_content, str): - prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) - prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - content_item.data = sys_query + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - else: - raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) - - # The sys_files will be deprecated later - if vision_enabled and sys_files: - file_prompts = [] - for file in sys_files: - file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) - file_prompts.append(file_prompt) - # If last prompt is a user prompt, add files into its contents, - # otherwise append a new user prompt - if ( - len(prompt_messages) > 0 - and isinstance(prompt_messages[-1], UserPromptMessage) - and isinstance(prompt_messages[-1].content, list) - ): - prompt_messages[-1] = UserPromptMessage(content=file_prompts + 
prompt_messages[-1].content) - else: - prompt_messages.append(UserPromptMessage(content=file_prompts)) - - # The context_files - if vision_enabled and context_files: - file_prompts = [] - for file in context_files: - file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) - file_prompts.append(file_prompt) - # If last prompt is a user prompt, add files into its contents, - # otherwise append a new user prompt - if ( - len(prompt_messages) > 0 - and isinstance(prompt_messages[-1], UserPromptMessage) - and isinstance(prompt_messages[-1].content, list) - ): - prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) - else: - prompt_messages.append(UserPromptMessage(content=file_prompts)) - - # Remove empty messages and filter unsupported content - filtered_prompt_messages = [] - for prompt_message in prompt_messages: - if isinstance(prompt_message.content, list): - prompt_message_content: list[PromptMessageContentUnionTypes] = [] - for content_item in prompt_message.content: - # Skip content if features are not defined - if not model_schema.features: - if content_item.type != PromptMessageContentType.TEXT: - continue - prompt_message_content.append(content_item) - continue - - # Skip content if corresponding feature is not supported - if ( - ( - content_item.type == PromptMessageContentType.IMAGE - and ModelFeature.VISION not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.DOCUMENT - and ModelFeature.DOCUMENT not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.VIDEO - and ModelFeature.VIDEO not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.AUDIO - and ModelFeature.AUDIO not in model_schema.features - ) - ): - continue - prompt_message_content.append(content_item) - if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: - prompt_message.content = prompt_message_content[0].data - else: - prompt_message.content = prompt_message_content - if prompt_message.is_empty(): - continue - filtered_prompt_messages.append(prompt_message) - - if len(filtered_prompt_messages) == 0: - raise NoPromptFoundError( - "No prompt found in the LLM configuration. " - "Please ensure a prompt is properly configured before proceeding." 
- ) - - return filtered_prompt_messages, stop - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: LLMNodeData, - ) -> Mapping[str, Sequence[str]]: - # graph_config is not used in this node type - _ = graph_config # Explicitly mark as unused - prompt_template = node_data.prompt_template - variable_selectors = [] - if isinstance(prompt_template, list): - for prompt in prompt_template: - if prompt.edition_type != "jinja2": - variable_template_parser = VariableTemplateParser(template=prompt.text) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - if prompt_template.edition_type != "jinja2": - variable_template_parser = VariableTemplateParser(template=prompt_template.text) - variable_selectors = variable_template_parser.extract_variable_selectors() - else: - raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") - - variable_mapping: dict[str, Any] = {} - for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - memory = node_data.memory - if memory and memory.query_prompt_template: - query_variable_selectors = VariableTemplateParser( - template=memory.query_prompt_template - ).extract_variable_selectors() - for variable_selector in query_variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - if node_data.context.enabled: - variable_mapping["#context#"] = node_data.context.variable_selector - - if node_data.vision.enabled: - variable_mapping["#files#"] = node_data.vision.configs.variable_selector - - if node_data.prompt_config: - enable_jinja = False - - if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - if prompt_template.edition_type == "jinja2": - enable_jinja = True - else: - for prompt in prompt_template: - if prompt.edition_type == "jinja2": - enable_jinja = True - break - - if enable_jinja: - for variable_selector in node_data.prompt_config.jinja2_variables or []: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - variable_mapping = {node_id + "." 
+ key: value for key, value in variable_mapping.items()}
-
-        return variable_mapping
-
-    @classmethod
-    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
-        return {
-            "type": "llm",
-            "config": {
-                "prompt_templates": {
-                    "chat_model": {
-                        "prompts": [
-                            {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"}
-                        ]
-                    },
-                    "completion_model": {
-                        "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
-                        "prompt": {
-                            "text": "Here are the chat histories between human and assistant, inside "
-                            "<histories></histories> XML tags.\n\n<histories>\n{{"
-                            "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
-                            "edition_type": "basic",
-                        },
-                        "stop": ["Human:"],
-                    },
-                }
-            },
-        }
-
-    @staticmethod
-    def handle_list_messages(
-        *,
-        messages: Sequence[LLMNodeChatModelMessage],
-        context: str,
-        jinja2_variables: Sequence[VariableSelector],
-        variable_pool: VariablePool,
-        vision_detail_config: ImagePromptMessageContent.DETAIL,
-        jinja2_template_renderer: Jinja2TemplateRenderer | None = None,
-    ) -> Sequence[PromptMessage]:
-        prompt_messages: list[PromptMessage] = []
-        for message in messages:
-            if message.edition_type == "jinja2":
-                result_text = _render_jinja2_message(
-                    template=message.jinja2_text or "",
-                    jinja2_variables=jinja2_variables,
-                    variable_pool=variable_pool,
-                    jinja2_template_renderer=jinja2_template_renderer,
-                )
-                prompt_message = _combine_message_content_with_role(
-                    contents=[TextPromptMessageContent(data=result_text)], role=message.role
-                )
-                prompt_messages.append(prompt_message)
-            else:
-                # Get segment group from basic message
-                template = message.text.replace(llm_utils.CONTEXT_PLACEHOLDER, context)
-                segment_group = variable_pool.convert_template(template)
-
-                # Process segments for images
-                file_contents = []
-                for segment in segment_group.value:
-                    if isinstance(segment, ArrayFileSegment):
-                        for file in segment.value:
-                            if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
-                                file_content = file_manager.to_prompt_message_content(
-                                    file, image_detail_config=vision_detail_config
-                                )
-                                file_contents.append(file_content)
-                    elif isinstance(segment, FileSegment):
-                        file = segment.value
-                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
-                            file_content = file_manager.to_prompt_message_content(
-                                file, image_detail_config=vision_detail_config
-                            )
-                            file_contents.append(file_content)
-
-                # Create message with text from all segments
-                plain_text = segment_group.text
-                if plain_text:
-                    prompt_message = _combine_message_content_with_role(
-                        contents=[TextPromptMessageContent(data=plain_text)], role=message.role
-                    )
-                    prompt_messages.append(prompt_message)
-
-                if file_contents:
-                    # Create message with image contents
-                    prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
-                    prompt_messages.append(prompt_message)
-
-        return prompt_messages
-
-    @staticmethod
-    def handle_blocking_result(
-        *,
-        invoke_result: LLMResult | LLMResultWithStructuredOutput,
-        saver: LLMFileSaver,
-        file_outputs: list[File],
-        reasoning_format: Literal["separated", "tagged"] = "tagged",
-        request_latency: float | None = None,
-    ) -> ModelInvokeCompletedEvent:
-        buffer = io.StringIO()
-        for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
-            contents=invoke_result.message.content,
-            file_saver=saver,
-            file_outputs=file_outputs,
-        ):
-            buffer.write(text_part)
-
-        # Extract reasoning content from <think> tags in the main text
-        full_text = buffer.getvalue()
-
-        if reasoning_format == "tagged":
-            # Keep <think> tags in text for backward compatibility
-            clean_text = full_text
-            reasoning_content = ""
-        else:
-            # Extract clean text and reasoning from <think> tags
-            clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
-
-        event = ModelInvokeCompletedEvent(
-            # Use clean_text for separated mode, full_text for tagged mode
-            text=clean_text if reasoning_format == "separated" else full_text,
-            usage=invoke_result.usage,
-            finish_reason=None,
-            # Reasoning content for workflow variables and downstream nodes
-            reasoning_content=reasoning_content,
-            # Pass structured output if enabled
-            structured_output=getattr(invoke_result, "structured_output", None),
-        )
-        if request_latency is not None:
-            event.usage.latency = round(request_latency, 3)
-        return event
-
-    @staticmethod
-    def save_multimodal_image_output(
-        *,
-        content: ImagePromptMessageContent,
-        file_saver: LLMFileSaver,
-    ) -> File:
-        """_save_multimodal_output saves multi-modal contents generated by LLM plugins.
-
-        There are two kinds of multimodal outputs:
-
-        - Inlined data encoded in base64, which would be saved to storage directly.
-        - Remote files referenced by a URL, which would be downloaded and then saved to storage.
-
-        Currently, only image files are supported.
-        """
-        if content.url != "":
-            saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE)
-        else:
-            saved_file = file_saver.save_binary_string(
-                data=base64.b64decode(content.base64_data),
-                mime_type=content.mime_type,
-                file_type=FileType.IMAGE,
-            )
-        return saved_file
-
-    @staticmethod
-    def fetch_structured_output_schema(
-        *,
-        structured_output: Mapping[str, Any],
-    ) -> dict[str, Any]:
-        """
-        Fetch the structured output schema from the node data.
-
-        Returns:
-            dict[str, Any]: The structured output schema
-        """
-        if not structured_output:
-            raise LLMNodeError("Please provide a valid structured output schema")
-        structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
-        if not structured_output_schema:
-            raise LLMNodeError("Please provide a valid structured output schema")
-
-        try:
-            schema = json.loads(structured_output_schema)
-            if not isinstance(schema, dict):
-                raise LLMNodeError("structured_output_schema must be a JSON object")
-            return schema
-        except json.JSONDecodeError:
-            raise LLMNodeError("structured_output_schema is not valid JSON format")
-
-    @staticmethod
-    def _save_multimodal_output_and_convert_result_to_markdown(
-        *,
-        contents: str | list[PromptMessageContentUnionTypes] | None,
-        file_saver: LLMFileSaver,
-        file_outputs: list[File],
-    ) -> Generator[str, None, None]:
-        """Convert intermediate prompt messages into strings and yield them to the caller.
-
-        If the messages contain non-textual content (e.g., multimedia like images or videos),
-        it will be saved separately, and the corresponding Markdown representation will
-        be yielded to the caller.
-        """
-
-        # NOTE(QuantumGhost): This function should yield results to the caller immediately
-        # whenever new content or partial content is available. Avoid any intermediate buffering
-        # of results. Additionally, do not yield empty strings; instead, yield from an empty list
-        # if necessary.
- if contents is None: - yield from [] - return - if isinstance(contents, str): - yield contents - else: - for item in contents: - if isinstance(item, TextPromptMessageContent): - yield item.data - elif isinstance(item, ImagePromptMessageContent): - file = LLMNode.save_multimodal_image_output( - content=item, - file_saver=file_saver, - ) - file_outputs.append(file) - yield LLMNode._image_file_to_markdown(file) - else: - logger.warning("unknown item type encountered, type=%s", type(item)) - yield str(item) - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled - - @property - def model_instance(self) -> PreparedLLMProtocol: - return self._model_instance - - -def _combine_message_content_with_role( - *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole -): - match role: - case PromptMessageRole.USER: - return UserPromptMessage(content=contents) - case PromptMessageRole.ASSISTANT: - return AssistantPromptMessage(content=contents) - case PromptMessageRole.SYSTEM: - return SystemPromptMessage(content=contents) - case _: - raise NotImplementedError(f"Role {role} is not supported") - - -def _render_jinja2_message( - *, - template: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - jinja2_template_renderer: Jinja2TemplateRenderer | None, -): - if not template: - return "" - - jinja2_inputs = {} - for jinja2_variable in jinja2_variables: - variable = variable_pool.get(jinja2_variable.value_selector) - jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" - if jinja2_template_renderer is None: - raise TemplateRenderError("LLMNode requires an injected jinja2_template_renderer for jinja2 prompts.") - return jinja2_template_renderer.render_template(template, jinja2_inputs) - - -def _calculate_rest_token( - *, - prompt_messages: list[PromptMessage], - model_instance: PreparedLLMProtocol, -) -> int: - rest_tokens = 2000 - runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - runtime_model_parameters = model_instance.parameters - - model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in runtime_model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - runtime_model_parameters.get(parameter_rule.name) - or runtime_model_parameters.get(str(parameter_rule.use_template)) - or 0 - ) - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - -def _handle_memory_chat_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: PreparedLLMProtocol, -) -> Sequence[PromptMessage]: - memory_messages: Sequence[PromptMessage] = [] - # Get messages from memory for chat model - if memory and memory_config: - rest_tokens = _calculate_rest_token( - prompt_messages=[], - model_instance=model_instance, - ) - memory_messages = memory.get_history_prompt_messages( - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - ) - return memory_messages - - -def _handle_memory_completion_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - 
model_instance: PreparedLLMProtocol, -) -> str: - memory_text = "" - # Get history text from memory for completion model - if memory and memory_config: - rest_tokens = _calculate_rest_token( - prompt_messages=[], - model_instance=model_instance, - ) - if not memory_config.role_prefix: - raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") - memory_text = llm_utils.fetch_memory_text( - memory=memory, - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - human_prefix=memory_config.role_prefix.user, - ai_prefix=memory_config.role_prefix.assistant, - ) - return memory_text - - -def _handle_completion_template( - *, - template: LLMNodeCompletionModelPromptTemplate, - context: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - jinja2_template_renderer: Jinja2TemplateRenderer | None = None, -) -> Sequence[PromptMessage]: - """Handle completion template processing outside of LLMNode class. - - Args: - template: The completion model prompt template - context: Context string - jinja2_variables: Variables for jinja2 template rendering - variable_pool: Variable pool for template conversion - - Returns: - Sequence of prompt messages - """ - prompt_messages = [] - if template.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=template.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - jinja2_template_renderer=jinja2_template_renderer, - ) - else: - template_text = template.text.replace(llm_utils.CONTEXT_PLACEHOLDER, context) - result_text = variable_pool.convert_template(template_text).text - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER - ) - prompt_messages.append(prompt_message) - return prompt_messages diff --git a/api/graphon/nodes/llm/protocols.py b/api/graphon/nodes/llm/protocols.py deleted file mode 100644 index 65bfd533d11..00000000000 --- a/api/graphon/nodes/llm/protocols.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - -from typing import Any, Protocol - -from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol - - -class CredentialsProvider(Protocol): - """Port for loading runtime credentials for a provider/model pair.""" - - def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: - """Return credentials for the target provider/model or raise a domain error.""" - ... - - -class ModelFactory(Protocol): - """Port for creating prepared graph-facing LLM runtimes for execution.""" - - def init_model_instance(self, provider_name: str, model_name: str) -> PreparedLLMProtocol: - """Create a prepared LLM runtime that is ready for graph execution.""" - ... 
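Both ports above are structural (PEP 544) Protocols, so a host application can satisfy them without importing anything from the deleted module. A minimal sketch of the credentials port, where the class name and the in-memory (provider, model) table are illustrative assumptions rather than anything this patch ships:

    from typing import Any


    class StaticCredentialsProvider:
        """Illustrative CredentialsProvider backed by an in-memory table."""

        def __init__(self, table: dict[tuple[str, str], dict[str, Any]]):
            # Credentials keyed by (provider_name, model_name).
            self._table = table

        def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
            credentials = self._table.get((provider_name, model_name))
            if credentials is None:
                # The protocol leaves the error type open; LookupError stands in here.
                raise LookupError(f"no credentials for {provider_name}/{model_name}")
            return credentials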
diff --git a/api/graphon/nodes/llm/runtime_protocols.py b/api/graphon/nodes/llm/runtime_protocols.py deleted file mode 100644 index dbe415d3632..00000000000 --- a/api/graphon/nodes/llm/runtime_protocols.py +++ /dev/null @@ -1,77 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator, Mapping, Sequence -from typing import Any, Protocol - -from graphon.file import File -from graphon.model_runtime.entities import LLMMode, PromptMessage -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from graphon.model_runtime.entities.message_entities import PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity - - -class PreparedLLMProtocol(Protocol): - """A graph-facing LLM runtime with provider-specific setup already applied.""" - - @property - def provider(self) -> str: ... - - @property - def model_name(self) -> str: ... - - @property - def parameters(self) -> Mapping[str, Any]: ... - - @parameters.setter - def parameters(self, value: Mapping[str, Any]) -> None: ... - - @property - def stop(self) -> Sequence[str] | None: ... - - def get_model_schema(self) -> AIModelEntity: ... - - def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ... - - def invoke_llm( - self, - *, - prompt_messages: Sequence[PromptMessage], - model_parameters: Mapping[str, Any], - tools: Sequence[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: bool, - ) -> LLMResult | Generator[LLMResultChunk, None, None]: ... - - def invoke_llm_with_structured_output( - self, - *, - prompt_messages: Sequence[PromptMessage], - json_schema: Mapping[str, Any], - model_parameters: Mapping[str, Any], - stop: Sequence[str] | None, - stream: bool, - ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... - - def is_structured_output_parse_error(self, error: Exception) -> bool: ... - - -class PromptMessageSerializerProtocol(Protocol): - """Port for converting compiled prompt messages into persisted process data.""" - - def serialize( - self, - *, - model_mode: LLMMode, - prompt_messages: Sequence[PromptMessage], - ) -> Any: ... - - -class RetrieverAttachmentLoaderProtocol(Protocol): - """Port for resolving retriever segment attachments into graph file references.""" - - def load(self, *, segment_id: str) -> Sequence[File]: ... 
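PreparedLLMProtocol above is likewise structural, so a unit test can hand the node a hand-rolled double wherever a prepared runtime is expected. A sketch under that assumption; the stub deliberately raises on the invoke methods rather than guessing LLMResult constructor details:

    from collections.abc import Mapping, Sequence
    from typing import Any


    class StubPreparedLLM:
        """Illustrative test double for PreparedLLMProtocol."""

        provider = "stub-provider"
        model_name = "stub-model"

        def __init__(self) -> None:
            self._parameters: dict[str, Any] = {}

        @property
        def parameters(self) -> Mapping[str, Any]:
            return self._parameters

        @parameters.setter
        def parameters(self, value: Mapping[str, Any]) -> None:
            # LLMNode reassigns parameters after resolving variable references.
            self._parameters = dict(value)

        @property
        def stop(self) -> Sequence[str] | None:
            return None

        def get_llm_num_tokens(self, prompt_messages: Sequence[Any]) -> int:
            # Crude whitespace token count; enough for token-budget math in tests.
            return sum(len(str(getattr(m, "content", "") or "").split()) for m in prompt_messages)

        def is_structured_output_parse_error(self, error: Exception) -> bool:
            return False

        def get_model_schema(self):
            raise NotImplementedError("supply an AIModelEntity for schema-dependent tests")

        def invoke_llm(self, *, prompt_messages, model_parameters, tools, stop, stream):
            raise NotImplementedError("return an LLMResult or a chunk generator here")

        def invoke_llm_with_structured_output(
            self, *, prompt_messages, json_schema, model_parameters, stop, stream
        ):
            raise NotImplementedError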
diff --git a/api/graphon/nodes/loop/__init__.py b/api/graphon/nodes/loop/__init__.py deleted file mode 100644 index 9fe695607b9..00000000000 --- a/api/graphon/nodes/loop/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .entities import LoopNodeData -from .loop_end_node import LoopEndNode -from .loop_node import LoopNode -from .loop_start_node import LoopStartNode - -__all__ = ["LoopEndNode", "LoopNode", "LoopNodeData", "LoopStartNode"] diff --git a/api/graphon/nodes/loop/entities.py b/api/graphon/nodes/loop/entities.py deleted file mode 100644 index e7362769e97..00000000000 --- a/api/graphon/nodes/loop/entities.py +++ /dev/null @@ -1,107 +0,0 @@ -from enum import StrEnum -from typing import Annotated, Any, Literal - -from pydantic import AfterValidator, BaseModel, Field, field_validator - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base import BaseLoopNodeData, BaseLoopState -from graphon.utils.condition.entities import Condition -from graphon.variables.types import SegmentType - -_VALID_VAR_TYPE = frozenset( - [ - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.OBJECT, - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - ] -) - - -def _is_valid_var_type(seg_type: SegmentType) -> SegmentType: - if seg_type not in _VALID_VAR_TYPE: - raise ValueError(...) - return seg_type - - -class LoopVariableData(BaseModel): - """ - Loop Variable Data. - """ - - label: str - var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] - value_type: Literal["variable", "constant"] - value: Any | list[str] | None = None - - -class LoopNodeData(BaseLoopNodeData): - type: NodeType = BuiltinNodeTypes.LOOP - loop_count: int # Maximum number of loops - break_conditions: list[Condition] # Conditions to break the loop - logical_operator: Literal["and", "or"] - loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData]) - outputs: dict[str, Any] = Field(default_factory=dict) - - @field_validator("outputs", mode="before") - @classmethod - def validate_outputs(cls, v): - if v is None: - return {} - return v - - -class LoopStartNodeData(BaseNodeData): - """ - Loop Start Node Data. - """ - - type: NodeType = BuiltinNodeTypes.LOOP_START - - -class LoopEndNodeData(BaseNodeData): - """ - Loop End Node Data. - """ - - type: NodeType = BuiltinNodeTypes.LOOP_END - - -class LoopState(BaseLoopState): - """ - Loop State. - """ - - outputs: list[Any] = Field(default_factory=list) - current_output: Any = None - - class MetaData(BaseLoopState.MetaData): - """ - Data. - """ - - loop_length: int - - def get_last_output(self) -> Any: - """ - Get last output. - """ - if self.outputs: - return self.outputs[-1] - return None - - def get_current_output(self) -> Any: - """ - Get current output. 
- """ - return self.current_output - - -class LoopCompletedReason(StrEnum): - LOOP_BREAK = "loop_break" - LOOP_COMPLETED = "loop_completed" diff --git a/api/graphon/nodes/loop/loop_end_node.py b/api/graphon/nodes/loop/loop_end_node.py deleted file mode 100644 index c0562b59c4c..00000000000 --- a/api/graphon/nodes/loop/loop_end_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.loop.entities import LoopEndNodeData - - -class LoopEndNode(Node[LoopEndNodeData]): - """ - Loop End Node. - """ - - node_type = BuiltinNodeTypes.LOOP_END - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/graphon/nodes/loop/loop_node.py b/api/graphon/nodes/loop/loop_node.py deleted file mode 100644 index d574e9f7ae7..00000000000 --- a/api/graphon/nodes/loop/loop_node.py +++ /dev/null @@ -1,428 +0,0 @@ -import contextlib -import json -import logging -from collections.abc import Callable, Generator, Mapping, Sequence -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, Literal, cast - -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.graph_events import ( - GraphNodeEventBase, - GraphRunAbortedEvent, - GraphRunFailedEvent, - NodeRunSucceededEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import ( - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, - NodeEventBase, - NodeRunResult, - StreamCompletedEvent, -) -from graphon.nodes.base import LLMUsageTrackingMixin -from graphon.nodes.base.node import Node -from graphon.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData -from graphon.utils.condition.processor import ConditionProcessor -from graphon.variables import Segment, SegmentType, TypeMismatchError, build_segment_with_type, segment_to_variable - -if TYPE_CHECKING: - from graphon.graph_engine import GraphEngine - -logger = logging.getLogger(__name__) -_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" - - -class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): - """ - Loop Node. 
- """ - - node_type = BuiltinNodeTypes.LOOP - execution_type = NodeExecutionType.CONTAINER - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator: - """Run the node.""" - # Get inputs - loop_count = self.node_data.loop_count - break_conditions = self.node_data.break_conditions - logical_operator = self.node_data.logical_operator - - inputs = {"loop_count": loop_count} - - if not self.node_data.start_node_id: - raise ValueError(f"field start_node_id in loop {self._node_id} not found") - - root_node_id = self.node_data.start_node_id - - # Initialize loop variables in the original variable pool - loop_variable_selectors = {} - if self.node_data.loop_variables: - value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = { - "constant": lambda var: self._get_segment_for_constant(var.var_type, var.value), - "variable": lambda var: ( - self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None - ), - } - for loop_variable in self.node_data.loop_variables: - if loop_variable.value_type not in value_processor: - raise ValueError( - f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}" - ) - - processed_segment = value_processor[loop_variable.value_type](loop_variable) - if not processed_segment: - raise ValueError(f"Invalid value for loop variable {loop_variable.label}") - variable_selector = [self._node_id, loop_variable.label] - variable = segment_to_variable(segment=processed_segment, selector=variable_selector) - self.graph_runtime_state.variable_pool.add(variable_selector, variable.value) - loop_variable_selectors[loop_variable.label] = variable_selector - inputs[loop_variable.label] = processed_segment.value - - start_at = datetime.now(UTC).replace(tzinfo=None) - condition_processor = ConditionProcessor() - - loop_duration_map: dict[str, float] = {} - single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output - loop_usage = LLMUsage.empty_usage() - loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id) - - # Start Loop event - yield LoopStartedEvent( - start_at=start_at, - inputs=inputs, - metadata={"loop_length": loop_count}, - ) - - try: - reach_break_condition = False - if break_conditions: - with contextlib.suppress(ValueError): - _, _, reach_break_condition = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) - - if reach_break_condition: - loop_count = 0 - - for i in range(loop_count): - # Clear stale variables from previous loop iterations to avoid streaming old values - self._clear_loop_subgraph_variables(loop_node_ids) - graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id) - - loop_start_time = datetime.now(UTC).replace(tzinfo=None) - try: - reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) - finally: - loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) - # Track loop duration - loop_duration_map[str(i)] = (datetime.now(UTC).replace(tzinfo=None) - loop_start_time).total_seconds() - - # Accumulate outputs from the sub-graph's response nodes - for key, value in graph_engine.graph_runtime_state.outputs.items(): - if key == "answer": - # Concatenate answer outputs with newline - existing_answer = self.graph_runtime_state.get_output("answer", "") - if 
existing_answer: - self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}") - else: - self.graph_runtime_state.set_output("answer", value) - else: - # For other outputs, just update - self.graph_runtime_state.set_output(key, value) - - # Collect loop variable values after iteration - single_loop_variable = {} - for key, selector in loop_variable_selectors.items(): - segment = self.graph_runtime_state.variable_pool.get(selector) - single_loop_variable[key] = segment.value if segment else None - - single_loop_variable_map[str(i)] = single_loop_variable - - if reach_break_node: - break - - if break_conditions: - _, _, reach_break_condition = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) - if reach_break_condition: - break - - yield LoopNextEvent( - index=i + 1, - pre_loop_output=self.node_data.outputs, - ) - - self._accumulate_usage(loop_usage) - # Loop completed successfully - yield LoopSucceededEvent( - start_at=start_at, - inputs=inputs, - outputs=self.node_data.outputs, - steps=loop_count, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.COMPLETED_REASON: ( - LoopCompletedReason.LOOP_BREAK - if reach_break_condition - else LoopCompletedReason.LOOP_COMPLETED.value - ), - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - outputs=self.node_data.outputs, - inputs=inputs, - llm_usage=loop_usage, - ) - ) - - except Exception as e: - self._accumulate_usage(loop_usage) - yield LoopFailedEvent( - start_at=start_at, - inputs=inputs, - steps=loop_count, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - "completed_reason": "error", - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - error=str(e), - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - llm_usage=loop_usage, - ) - ) - - def _run_single_loop( - self, - *, - graph_engine: "GraphEngine", - current_index: int, - ) -> Generator[NodeEventBase | GraphNodeEventBase, None, bool]: - 
reach_break_node = False - for event in graph_engine.run(): - if isinstance(event, GraphNodeEventBase): - self._append_loop_info_to_event(event=event, loop_run_index=current_index) - - if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.LOOP_START: - continue - if isinstance(event, GraphNodeEventBase): - yield event - if isinstance(event, NodeRunSucceededEvent) and event.node_type == BuiltinNodeTypes.LOOP_END: - reach_break_node = True - if isinstance(event, GraphRunAbortedEvent): - raise RuntimeError(event.reason or _DEFAULT_CHILD_ABORT_REASON) - if isinstance(event, GraphRunFailedEvent): - raise Exception(event.error) - - for loop_var in self.node_data.loop_variables or []: - key, sel = loop_var.label, [self._node_id, loop_var.label] - segment = self.graph_runtime_state.variable_pool.get(sel) - self.node_data.outputs[key] = segment.value if segment else None - self.node_data.outputs["loop_round"] = current_index + 1 - - return reach_break_node - - def _append_loop_info_to_event( - self, - event: GraphNodeEventBase, - loop_run_index: int, - ): - event.in_loop_id = self._node_id - loop_metadata = { - WorkflowNodeExecutionMetadataKey.LOOP_ID: self._node_id, - WorkflowNodeExecutionMetadataKey.LOOP_INDEX: loop_run_index, - } - - current_metadata = event.node_run_result.metadata - if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata: - event.node_run_result.metadata = {**current_metadata, **loop_metadata} - - def _clear_loop_subgraph_variables(self, loop_node_ids: set[str]) -> None: - """ - Remove variables produced by loop sub-graph nodes from previous iterations. - - Keeping stale variables causes a freshly created response coordinator in the - next iteration to fall back to outdated values when no stream chunks exist. - """ - variable_pool = self.graph_runtime_state.variable_pool - for node_id in loop_node_ids: - variable_pool.remove([node_id]) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: LoopNodeData, - ) -> Mapping[str, Sequence[str]]: - variable_mapping = {} - - # Extract loop node IDs statically from graph_config - - loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id) - - # Get node configs from graph_config - node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} - for sub_node_id, sub_node_config in node_configs.items(): - if sub_node_config.get("data", {}).get("loop_id") != node_id: - continue - - # variable selector to variable mapping - try: - typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) - node_type = typed_sub_node_config["data"].type - node_mapping = Node.get_node_type_classes_mapping() - if node_type not in node_mapping: - continue - node_version = str(typed_sub_node_config["data"].version) - node_cls = node_mapping[node_type][node_version] - - sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=typed_sub_node_config - ) - sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) - except NotImplementedError: - sub_node_variable_mapping = {} - - # remove loop variables - sub_node_variable_mapping = { - sub_node_id + "." 
+ key: value - for key, value in sub_node_variable_mapping.items() - if value[0] != node_id - } - - variable_mapping.update(sub_node_variable_mapping) - - for loop_variable in node_data.loop_variables or []: - if loop_variable.value_type == "variable": - assert loop_variable.value is not None, "Loop variable value must be provided for variable type" - # add loop variable to variable mapping - selector = loop_variable.value - variable_mapping[f"{node_id}.{loop_variable.label}"] = selector - - # remove variable out from loop - variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids} - - return variable_mapping - - @classmethod - def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]: - """ - Extract node IDs that belong to a specific loop from graph configuration. - - This method statically analyzes the graph configuration to find all nodes - that are part of the specified loop, without creating actual node instances. - - :param graph_config: the complete graph configuration - :param loop_node_id: the ID of the loop node - :return: set of node IDs that belong to the loop - """ - loop_node_ids = set() - - # Find all nodes that belong to this loop - nodes = graph_config.get("nodes", []) - for node in nodes: - node_data = node.get("data", {}) - if node_data.get("loop_id") == loop_node_id: - node_id = node.get("id") - if node_id: - loop_node_ids.add(node_id) - - return loop_node_ids - - @staticmethod - def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: - """Get the appropriate segment type for a constant value.""" - # TODO: Refactor for maintainability: - # 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py) - # 2. Consider moving this method to LoopVariableData class for better encapsulation - if not var_type.is_array_type() or var_type == SegmentType.ARRAY_BOOLEAN: - value = original_value - elif var_type in [ - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_STRING, - ]: - if original_value and isinstance(original_value, str): - value = json.loads(original_value) - else: - logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type) - value = [] - else: - raise AssertionError("this statement should be unreachable.") - try: - return build_segment_with_type(var_type, value=value) - except TypeMismatchError as type_exc: - # Attempt to parse the value as a JSON-encoded string, if applicable. - if not isinstance(original_value, str): - raise - try: - value = json.loads(original_value) - except ValueError: - raise type_exc - return build_segment_with_type(var_type, value) - - def _create_graph_engine(self, start_at: datetime, root_node_id: str): - from graphon.entities import GraphInitParams - - # Create GraphInitParams for child graph execution. 
- graph_init_params = GraphInitParams( - workflow_id=self.workflow_id, - graph_config=self.graph_config, - run_context=self.run_context, - call_depth=self.workflow_call_depth, - ) - - return self.graph_runtime_state.create_child_engine( - workflow_id=self.workflow_id, - graph_init_params=graph_init_params, - root_node_id=root_node_id, - ) diff --git a/api/graphon/nodes/loop/loop_start_node.py b/api/graphon/nodes/loop/loop_start_node.py deleted file mode 100644 index 2b17054ae22..00000000000 --- a/api/graphon/nodes/loop/loop_start_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.loop.entities import LoopStartNodeData - - -class LoopStartNode(Node[LoopStartNodeData]): - """ - Loop Start Node. - """ - - node_type = BuiltinNodeTypes.LOOP_START - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/graphon/nodes/parameter_extractor/__init__.py b/api/graphon/nodes/parameter_extractor/__init__.py deleted file mode 100644 index bdbf19a7d36..00000000000 --- a/api/graphon/nodes/parameter_extractor/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .parameter_extractor_node import ParameterExtractorNode - -__all__ = ["ParameterExtractorNode"] diff --git a/api/graphon/nodes/parameter_extractor/entities.py b/api/graphon/nodes/parameter_extractor/entities.py deleted file mode 100644 index 8fda1b9e79c..00000000000 --- a/api/graphon/nodes/parameter_extractor/entities.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Annotated, Any, Literal - -from pydantic import ( - BaseModel, - BeforeValidator, - Field, - field_validator, -) - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.llm.entities import ModelConfig, VisionConfig -from graphon.prompt_entities import MemoryConfig -from graphon.variables.types import SegmentType - -_OLD_BOOL_TYPE_NAME = "bool" -_OLD_SELECT_TYPE_NAME = "select" - -_VALID_PARAMETER_TYPES = frozenset( - [ - SegmentType.STRING, # "string", - SegmentType.NUMBER, # "number", - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - _OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node - _OLD_SELECT_TYPE_NAME, # string type with enumeration choices. - ] -) - - -def _validate_type(parameter_type: str) -> SegmentType: - if parameter_type not in _VALID_PARAMETER_TYPES: - raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.") - - if parameter_type == _OLD_BOOL_TYPE_NAME: - return SegmentType.BOOLEAN - elif parameter_type == _OLD_SELECT_TYPE_NAME: - return SegmentType.STRING - return SegmentType(parameter_type) - - -class ParameterConfig(BaseModel): - """ - Parameter Config. 
- """ - - name: str - type: Annotated[SegmentType, BeforeValidator(_validate_type)] - options: list[str] | None = None - description: str - required: bool - - @field_validator("name", mode="before") - @classmethod - def validate_name(cls, value) -> str: - if not value: - raise ValueError("Parameter name is required") - if value in {"__reason", "__is_success"}: - raise ValueError("Invalid parameter name, __reason and __is_success are reserved") - return str(value) - - def is_array_type(self) -> bool: - return self.type.is_array_type() - - def element_type(self) -> SegmentType: - """Return the element type of the parameter. - - Raises a ValueError if the parameter's type is not an array type. - """ - element_type = self.type.element_type() - # At this point, self.type is guaranteed to be one of `ARRAY_STRING`, - # `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`. - # - # See: _VALID_PARAMETER_TYPES for reference. - assert element_type is not None, f"the element type should not be None, {self.type=}" - return element_type - - -class ParameterExtractorNodeData(BaseNodeData): - """ - Parameter Extractor Node Data. - """ - - type: NodeType = BuiltinNodeTypes.PARAMETER_EXTRACTOR - model: ModelConfig - query: list[str] - parameters: list[ParameterConfig] - instruction: str | None = None - memory: MemoryConfig | None = None - reasoning_mode: Literal["function_call", "prompt"] - vision: VisionConfig = Field(default_factory=VisionConfig) - - @field_validator("reasoning_mode", mode="before") - @classmethod - def set_reasoning_mode(cls, v) -> str: - return v or "function_call" - - def get_parameter_json_schema(self): - """ - Get parameter json schema. - - :return: parameter json schema - """ - parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []} - - for parameter in self.parameters: - parameter_schema: dict[str, Any] = {"description": parameter.description} - - if parameter.type == SegmentType.STRING: - parameter_schema["type"] = "string" - elif parameter.type.is_array_type(): - parameter_schema["type"] = "array" - element_type = parameter.type.element_type() - if element_type is None: - raise AssertionError("element type should not be None.") - parameter_schema["items"] = {"type": element_type.value} - else: - parameter_schema["type"] = parameter.type - - if parameter.options: - parameter_schema["enum"] = parameter.options - - parameters["properties"][parameter.name] = parameter_schema - - if parameter.required: - parameters["required"].append(parameter.name) - - return parameters diff --git a/api/graphon/nodes/parameter_extractor/exc.py b/api/graphon/nodes/parameter_extractor/exc.py deleted file mode 100644 index faa90313c1f..00000000000 --- a/api/graphon/nodes/parameter_extractor/exc.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Any - -from graphon.variables.types import SegmentType - - -class ParameterExtractorNodeError(ValueError): - """Base error for ParameterExtractorNode.""" - - -class InvalidModelTypeError(ParameterExtractorNodeError): - """Raised when the model is not a Large Language Model.""" - - -class ModelSchemaNotFoundError(ParameterExtractorNodeError): - """Raised when the model schema is not found.""" - - -class InvalidInvokeResultError(ParameterExtractorNodeError): - """Raised when the invoke result is invalid.""" - - -class InvalidTextContentTypeError(ParameterExtractorNodeError): - """Raised when the text content type is invalid.""" - - -class InvalidNumberOfParametersError(ParameterExtractorNodeError): - """Raised when the number of 
parameters is invalid.""" - - -class RequiredParameterMissingError(ParameterExtractorNodeError): - """Raised when a required parameter is missing.""" - - -class InvalidSelectValueError(ParameterExtractorNodeError): - """Raised when a select value is invalid.""" - - -class InvalidNumberValueError(ParameterExtractorNodeError): - """Raised when a number value is invalid.""" - - -class InvalidBoolValueError(ParameterExtractorNodeError): - """Raised when a bool value is invalid.""" - - -class InvalidStringValueError(ParameterExtractorNodeError): - """Raised when a string value is invalid.""" - - -class InvalidArrayValueError(ParameterExtractorNodeError): - """Raised when an array value is invalid.""" - - -class InvalidModelModeError(ParameterExtractorNodeError): - """Raised when the model mode is invalid.""" - - -class InvalidValueTypeError(ParameterExtractorNodeError): - def __init__( - self, - /, - parameter_name: str, - expected_type: SegmentType, - actual_type: SegmentType | None, - value: Any, - ): - message = ( - f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, " - f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}" - ) - super().__init__(message) - self.parameter_name = parameter_name - self.expected_type = expected_type - self.actual_type = actual_type - self.value = value diff --git a/api/graphon/nodes/parameter_extractor/parameter_extractor_node.py b/api/graphon/nodes/parameter_extractor/parameter_extractor_node.py deleted file mode 100644 index 25379e325c2..00000000000 --- a/api/graphon/nodes/parameter_extractor/parameter_extractor_node.py +++ /dev/null @@ -1,846 +0,0 @@ -import contextlib -import json -import logging -import uuid -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import File -from graphon.model_runtime.entities import ImagePromptMessageContent, LLMMode -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageRole, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.node_events import NodeRunResult -from graphon.nodes.base import variable_template_parser -from graphon.nodes.base.node import Node -from graphon.nodes.llm import LLMNode, llm_utils -from graphon.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate -from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol -from graphon.runtime import VariablePool -from graphon.variables import build_segment_with_type -from graphon.variables.types import ArrayValidation, SegmentType - -from .entities import ParameterExtractorNodeData -from .exc import ( - InvalidModelModeError, - InvalidModelTypeError, - InvalidNumberOfParametersError, - InvalidSelectValueError, - InvalidTextContentTypeError, - InvalidValueTypeError, - ModelSchemaNotFoundError, - ParameterExtractorNodeError, - RequiredParameterMissingError, -) -from .prompts import ( - CHAT_EXAMPLE, - 
CHAT_GENERATE_JSON_PROMPT, - CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, - COMPLETION_GENERATE_JSON_PROMPT, - FUNCTION_CALLING_EXTRACTOR_EXAMPLE, - FUNCTION_CALLING_EXTRACTOR_NAME, - FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT, - FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE, -) - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -def extract_json(text): - """ - From a given JSON started from '{' or '[' extract the complete JSON object. - """ - stack = [] - for i, c in enumerate(text): - if c in {"{", "["}: - stack.append(c) - elif c in {"}", "]"}: - # check if stack is empty - if not stack: - return text[:i] - # check if the last element in stack is matching - if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): - stack.pop() - if not stack: - return text[: i + 1] - else: - return text[:i] - return None - - -class ParameterExtractorNode(Node[ParameterExtractorNodeData]): - """ - Parameter Extractor Node. - """ - - node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR - - _model_instance: PreparedLLMProtocol - _prompt_message_serializer: PromptMessageSerializerProtocol - _memory: PromptMessageMemory | None - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - credentials_provider: object | None = None, - model_factory: object | None = None, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None = None, - prompt_message_serializer: PromptMessageSerializerProtocol, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - _ = credentials_provider, model_factory - self._model_instance = model_instance - self._prompt_message_serializer = prompt_message_serializer - self._memory = memory - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "model": { - "prompt_templates": { - "completion_model": { - "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, - "stop": ["Human:"], - } - } - } - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self): - """ - Run the node. 
- """ - node_data = self.node_data - variable = self.graph_runtime_state.variable_pool.get(node_data.query) - query = variable.text if variable else "" - - variable_pool = self.graph_runtime_state.variable_pool - - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=node_data.vision.configs.variable_selector, - ) - if node_data.vision.enabled - else [] - ) - - model_instance = self._model_instance - # Resolve variable references in string-typed completion params - model_instance.parameters = llm_utils.resolve_completion_params_variables( - model_instance.parameters, variable_pool - ) - try: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - except ValueError as exc: - raise ModelSchemaNotFoundError("Model schema not found") from exc - if model_schema.model_type != ModelType.LLM: - raise InvalidModelTypeError("Model is not a Large Language Model") - memory = self._memory - - if ( - set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} - and node_data.reasoning_mode == "function_call" - ): - # use function call - prompt_messages, prompt_message_tools = self._generate_function_call_prompt( - node_data=node_data, - query=query, - variable_pool=self.graph_runtime_state.variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=node_data.vision.configs.detail, - ) - else: - # use prompt engineering - prompt_messages = self._generate_prompt_engineering_prompt( - data=node_data, - query=query, - variable_pool=self.graph_runtime_state.variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=node_data.vision.configs.detail, - ) - - prompt_message_tools = [] - - inputs = { - "query": query, - "files": [f.to_dict() for f in files], - "parameters": jsonable_encoder(node_data.parameters), - "instruction": jsonable_encoder(node_data.instruction), - } - - process_data = { - "model_mode": node_data.model.mode, - "prompts": self._prompt_message_serializer.serialize( - model_mode=node_data.model.mode, - prompt_messages=prompt_messages, - ), - "usage": None, - "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), - "tool_call": None, - "model_provider": model_instance.provider, - "model_name": model_instance.model_name, - } - - try: - text, usage, tool_call = self._invoke( - model_instance=model_instance, - prompt_messages=prompt_messages, - tools=prompt_message_tools, - stop=model_instance.stop, - ) - process_data["usage"] = jsonable_encoder(usage) - process_data["tool_call"] = jsonable_encoder(tool_call) - process_data["llm_text"] = text - except ParameterExtractorNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - outputs={"__is_success": 0, "__reason": str(e)}, - error=str(e), - metadata={}, - ) - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)}, - error=str(e), - metadata={}, - ) - - error = None - - if tool_call: - result = self._extract_json_from_tool_call(tool_call) - else: - result = self._extract_complete_json_response(text) - if not result: - result = self._generate_default_result(node_data) - error = "Failed to extract result from function call or text response, using empty result." 
- - try: - result = self._validate_result(data=node_data, result=result or {}) - except ParameterExtractorNodeError as e: - error = str(e) - - # transform result into standard format - result = self._transform_result(data=node_data, result=result or {}) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={ - "__is_success": 1 if not error else 0, - "__reason": error, - "__usage": jsonable_encoder(usage), - **result, - }, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - - def _invoke( - self, - model_instance: PreparedLLMProtocol, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - stop: Sequence[str] | None, - ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: - invoke_result = cast( - LLMResult, - model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=dict(model_instance.parameters), - tools=tools or None, - stop=stop, - stream=False, - ), - ) - - # handle invoke result - - text = invoke_result.message.get_text_content() - if not isinstance(text, str): - raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.") - - usage = invoke_result.usage - tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None - - return text, usage, tool_call - - def _generate_function_call_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: - """ - Generate function call prompt. - """ - query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format( - content=query, structure=json.dumps(node_data.get_parameter_json_schema()) - ) - - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_function_calling_prompt_template( - node_data, query, variable_pool, memory, rest_token - ) - prompt_messages = self._compile_prompt_messages( - model_instance=model_instance, - prompt_template=prompt_template, - files=files, - vision_enabled=node_data.vision.enabled, - image_detail_config=vision_detail, - ) - - # find last user message - last_user_message_idx = -1 - for i, prompt_message in enumerate(prompt_messages): - if prompt_message.role == PromptMessageRole.USER: - last_user_message_idx = i - - # add function call messages before last user message - example_messages = [] - for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE: - id = uuid.uuid4().hex - example_messages.extend( - [ - UserPromptMessage(content=example["user"]["query"]), - AssistantPromptMessage( - content=example["assistant"]["text"], - tool_calls=[ - AssistantPromptMessage.ToolCall( - id=id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=example["assistant"]["function_call"]["name"], - arguments=json.dumps(example["assistant"]["function_call"]["parameters"]), - ), - ) - ], - ), - ToolPromptMessage( - content="Great! 
You have called the function with the correct parameters.", tool_call_id=id - ), - AssistantPromptMessage( - content="I have extracted the parameters, let's move on.", - ), - ] - ) - - prompt_messages = ( - prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] - ) - - # generate tool - tool = PromptMessageTool( - name=FUNCTION_CALLING_EXTRACTOR_NAME, - description="Extract parameters from the natural language text", - parameters=node_data.get_parameter_json_schema(), - ) - - return prompt_messages, [tool] - - def _generate_prompt_engineering_prompt( - self, - data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate prompt engineering prompt. - """ - if data.model.mode == LLMMode.COMPLETION: - return self._generate_prompt_engineering_completion_prompt( - node_data=data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=vision_detail, - ) - if data.model.mode == LLMMode.CHAT: - return self._generate_prompt_engineering_chat_prompt( - node_data=data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=vision_detail, - ) - raise InvalidModelModeError(f"Invalid model mode: {data.model.mode}") - - def _generate_prompt_engineering_completion_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate completion prompt. - """ - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_engineering_prompt_template( - node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token - ) - return self._compile_prompt_messages( - model_instance=model_instance, - prompt_template=prompt_template, - files=files, - vision_enabled=node_data.vision.enabled, - image_detail_config=vision_detail, - ) - - def _generate_prompt_engineering_chat_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate chat prompt. 
- """ - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_engineering_prompt_template( - node_data=node_data, - query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(node_data.get_parameter_json_schema()), text=query - ), - variable_pool=variable_pool, - memory=memory, - max_token_limit=rest_token, - ) - - prompt_messages = self._compile_prompt_messages( - model_instance=model_instance, - prompt_template=prompt_template, - files=files, - vision_enabled=node_data.vision.enabled, - image_detail_config=vision_detail, - ) - - # find last user message - last_user_message_idx = -1 - for i, prompt_message in enumerate(prompt_messages): - if prompt_message.role == PromptMessageRole.USER: - last_user_message_idx = i - - # add example messages before last user message - example_messages = [] - for example in CHAT_EXAMPLE: - example_messages.extend( - [ - UserPromptMessage( - content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(example["user"]["json"]), - text=example["user"]["query"], - ) - ), - AssistantPromptMessage( - content=json.dumps(example["assistant"]["json"]), - ), - ] - ) - - prompt_messages = ( - prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] - ) - - return prompt_messages - - def _validate_result(self, data: ParameterExtractorNodeData, result: dict): - if len(data.parameters) != len(result): - raise InvalidNumberOfParametersError("Invalid number of parameters") - - for parameter in data.parameters: - if parameter.required and parameter.name not in result: - raise RequiredParameterMissingError(f"Parameter {parameter.name} is required") - - param_value = result.get(parameter.name) - if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL): - inferred_type = SegmentType.infer_segment_type(param_value) - raise InvalidValueTypeError( - parameter_name=parameter.name, - expected_type=parameter.type, - actual_type=inferred_type, - value=param_value, - ) - if parameter.type == SegmentType.STRING and parameter.options: - if param_value not in parameter.options: - raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") - return result - - @staticmethod - def _transform_number(value: int | float | str | bool) -> int | float | None: - """ - Attempts to transform the input into an integer or float. - - Returns: - int or float: The transformed number if the conversion is successful. - None: If the transformation fails. - - Note: - Boolean values `True` and `False` are converted to integers `1` and `0`, respectively. - This behavior ensures compatibility with existing workflows that may use boolean types as integers. - """ - if isinstance(value, bool): - return int(value) - elif isinstance(value, (int, float)): - return value - elif isinstance(value, str): - if "." in value: - try: - return float(value) - except ValueError: - return None - else: - try: - return int(value) - except ValueError: - return None - else: - return None - - def _transform_result(self, data: ParameterExtractorNodeData, result: dict): - """ - Transform result into standard format. 
- """ - transformed_result: dict[str, Any] = {} - for parameter in data.parameters: - if parameter.name in result: - param_value = result[parameter.name] - # transform value - if parameter.type == SegmentType.NUMBER: - transformed = self._transform_number(param_value) - if transformed is not None: - transformed_result[parameter.name] = transformed - elif parameter.type == SegmentType.BOOLEAN: - if isinstance(result[parameter.name], (bool, int)): - transformed_result[parameter.name] = bool(result[parameter.name]) - # elif isinstance(result[parameter.name], str): - # if result[parameter.name].lower() in ["true", "false"]: - # transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true") - elif parameter.type == SegmentType.STRING: - if isinstance(param_value, str): - transformed_result[parameter.name] = param_value - elif parameter.is_array_type(): - if isinstance(param_value, list): - nested_type = parameter.element_type() - assert nested_type is not None - segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[]) - transformed_result[parameter.name] = segment_value - for item in param_value: - if nested_type == SegmentType.NUMBER: - transformed = self._transform_number(item) - if transformed is not None: - segment_value.value.append(transformed) - elif nested_type == SegmentType.STRING: - if isinstance(item, str): - segment_value.value.append(item) - elif nested_type == SegmentType.OBJECT: - if isinstance(item, dict): - segment_value.value.append(item) - elif nested_type == SegmentType.BOOLEAN: - if isinstance(item, bool): - segment_value.value.append(item) - - if parameter.name not in transformed_result: - if parameter.type.is_array_type(): - transformed_result[parameter.name] = build_segment_with_type( - segment_type=SegmentType(parameter.type), value=[] - ) - elif parameter.type in (SegmentType.STRING, SegmentType.SECRET): - transformed_result[parameter.name] = "" - elif parameter.type == SegmentType.NUMBER: - transformed_result[parameter.name] = 0 - elif parameter.type == SegmentType.BOOLEAN: - transformed_result[parameter.name] = False - else: - raise AssertionError("this statement should be unreachable.") - - return transformed_result - - def _extract_complete_json_response(self, result: str) -> dict | None: - """ - Extract complete json response. - """ - - # extract json from the text - for idx in range(len(result)): - if result[idx] == "{" or result[idx] == "[": - json_str = extract_json(result[idx:]) - if json_str: - with contextlib.suppress(Exception): - return cast(dict, json.loads(json_str)) - logger.info("extra error: %s", result) - return None - - def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None: - """ - Extract json from tool call. - """ - if not tool_call or not tool_call.function.arguments: - return None - - result = tool_call.function.arguments - # extract json from the arguments - for idx in range(len(result)): - if result[idx] == "{" or result[idx] == "[": - json_str = extract_json(result[idx:]) - if json_str: - with contextlib.suppress(Exception): - return cast(dict, json.loads(json_str)) - - logger.info("extra error: %s", result) - return None - - def _generate_default_result(self, data: ParameterExtractorNodeData): - """ - Generate default result. 
- """ - result: dict[str, Any] = {} - for parameter in data.parameters: - if parameter.type == "number": - result[parameter.name] = 0 - elif parameter.type == "boolean": - result[parameter.name] = False - elif parameter.type in {"string", "select"}: - result[parameter.name] = "" - - return result - - def _get_function_calling_prompt_template( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ) -> list[LLMNodeChatModelMessage]: - input_text = query - memory_str = "" - instruction = variable_pool.convert_template(node_data.instruction or "").text - - if memory and node_data.memory and node_data.memory.window: - memory_str = llm_utils.fetch_memory_text( - memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size - ) - if node_data.model.mode == LLMMode.CHAT: - system_prompt_messages = LLMNodeChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), - ) - user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) - return [system_prompt_messages, user_prompt_message] - raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") - - def _get_prompt_engineering_prompt_template( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: - input_text = query - memory_str = "" - instruction = variable_pool.convert_template(node_data.instruction or "").text - - if memory and node_data.memory and node_data.memory.window: - memory_str = llm_utils.fetch_memory_text( - memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size - ) - if node_data.model.mode == LLMMode.CHAT: - system_prompt_messages = LLMNodeChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction), - ) - user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) - return [system_prompt_messages, user_prompt_message] - if node_data.model.mode == LLMMode.COMPLETION: - return LLMNodeCompletionModelPromptTemplate( - text=COMPLETION_GENERATE_JSON_PROMPT.format( - histories=memory_str, text=input_text, instruction=instruction - ) - .replace("{ฮณฮณฮณ", "") - .replace("}ฮณฮณฮณ", "") - .replace("{ structure }", json.dumps(node_data.get_parameter_json_schema())), - ) - raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") - - def _calculate_rest_token( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - context: str | None, - ) -> int: - try: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - except ValueError as exc: - raise ModelSchemaNotFoundError("Model schema not found") from exc - - prompt_template: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: - prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) - else: - prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) - - 
prompt_messages = self._compile_prompt_messages( - model_instance=model_instance, - prompt_template=prompt_template, - files=[], - vision_enabled=False, - context=context, - ) - rest_tokens = 2000 - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + 1000 - - max_tokens = 0 - for parameter_rule in model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_instance.parameters.get(parameter_rule.name) - or model_instance.parameters.get(parameter_rule.use_template or "") - ) or 0 - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - def _compile_prompt_messages( - self, - *, - model_instance: PreparedLLMProtocol, - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, - files: Sequence[File], - vision_enabled: bool, - context: str | None = "", - image_detail_config: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - prompt_messages, _ = LLMNode.fetch_prompt_messages( - sys_query="", - sys_files=files, - context=context or "", - memory=None, - model_instance=model_instance, - prompt_template=prompt_template, - stop=model_instance.stop, - memory_config=None, - vision_enabled=vision_enabled, - vision_detail=image_detail_config or ImagePromptMessageContent.DETAIL.HIGH, - variable_pool=self.graph_runtime_state.variable_pool, - jinja2_variables=[], - ) - return list(prompt_messages) - - @property - def model_instance(self) -> PreparedLLMProtocol: - return self._model_instance - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ParameterExtractorNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} - - if node_data.instruction: - selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) - for selector in selectors: - variable_mapping[selector.variable] = selector.value_selector - - variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} - - return variable_mapping diff --git a/api/graphon/nodes/parameter_extractor/prompts.py b/api/graphon/nodes/parameter_extractor/prompts.py deleted file mode 100644 index 1b29be4418d..00000000000 --- a/api/graphon/nodes/parameter_extractor/prompts.py +++ /dev/null @@ -1,184 +0,0 @@ -from typing import Any - -FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" - -FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. -### Task -Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters. Ensure that the information extraction is contextual and aligns with the provided criteria. -### Memory -Here is the chat history between the human and assistant, provided within tags: - -\x7bhistories\x7d - -### Instructions: -Some additional information is provided below. Always adhere to these instructions as closely as possible: - -\x7binstruction\x7d - -Steps: -1. 
Review the chat history provided within the tags. -2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text. -3. Generate a well-formatted output using the defined functions and arguments. -4. Use the `extract_parameter` function to create structured outputs with appropriate parameters. -5. Do not include any XML tags in your output. -### Example -To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples. -### Final Output -Produce well-formatted function calls in json without XML tags, as shown in the example. -""" # noqa: E501 - -FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside XML tags. - -\x7bcontent\x7d - - - -\x7bstructure\x7d - -""" # noqa: E501 - -FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [ - { - "user": { - "query": "What is the weather today in SF?", - "function": { - "name": FUNCTION_CALLING_EXTRACTOR_NAME, - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The location to get the weather information", - "required": True, - }, - }, - "required": ["location"], - }, - }, - }, - "assistant": { - "text": "I need always call the function with the correct parameters." - " in this case, I need to call the function with the location parameter.", - "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}}, - }, - }, - { - "user": { - "query": "I want to eat some apple pie.", - "function": { - "name": FUNCTION_CALLING_EXTRACTOR_NAME, - "parameters": { - "type": "object", - "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, - "required": ["food"], - }, - }, - }, - "assistant": { - "text": "I need always call the function with the correct parameters." - " in this case, I need to call the function with the food parameter.", - "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}}, - }, - }, -] - -COMPLETION_GENERATE_JSON_PROMPT = """### Instructions: -Some extra information are provided below, I should always follow the instructions as possible as I can. - -{instruction} - - -### Extract parameter Workflow -I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted. - -{{ structure }} - - -Step 1: Carefully read the input and understand the structure of the expected output. -Step 2: Extract relevant parameters from the provided text based on the name and description of object. -Step 3: Structure the extracted parameters to JSON object as specified in . -Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted. - -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - - -### Structure -Here is the structure of the expected output, I should always follow the output structure. 
-{{ฮณฮณฮณ - 'properties1': 'relevant text extracted from input', - 'properties2': 'relevant text extracted from input', -}}ฮณฮณฮณ - -### Input Text -Inside XML tags, there is a text that I should extract parameters and convert to a JSON object. - -{text} - - -### Answer -I should always output a valid JSON object. Output nothing other than the JSON object. -```JSON -""" # noqa: E501 - -CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object. -The structure of the JSON object you can found in the instructions. - -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - - -### Instructions: -Some extra information are provided below, you should always follow the instructions as possible as you can. - -{instructions} - -""" - -CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure -Here is the structure of the JSON object, you should always follow the structure. - -{structure} - - -### Text to be converted to JSON -Inside XML tags, there is a text that you should convert to a JSON object. - -{text} - -""" - -CHAT_EXAMPLE = [ - { - "user": { - "query": "What is the weather today in SF?", - "json": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The location to get the weather information", - "required": True, - } - }, - "required": ["location"], - }, - }, - "assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}}, - }, - { - "user": { - "query": "I want to eat some apple pie.", - "json": { - "type": "object", - "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, - "required": ["food"], - }, - }, - "assistant": {"text": "I need to output a valid JSON object.", "json": {"food": "apple pie"}}, - }, -] diff --git a/api/graphon/nodes/protocols.py b/api/graphon/nodes/protocols.py deleted file mode 100644 index 4b050c113c7..00000000000 --- a/api/graphon/nodes/protocols.py +++ /dev/null @@ -1,46 +0,0 @@ -from collections.abc import Generator, Mapping -from typing import Any, Protocol - -import httpx - -from graphon.file import File - - -class HttpClientProtocol(Protocol): - @property - def max_retries_exceeded_error(self) -> type[Exception]: ... - - @property - def request_error(self) -> type[Exception]: ... - - def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - -class FileManagerProtocol(Protocol): - def download(self, f: File, /) -> bytes: ... - - -class ToolFileManagerProtocol(Protocol): - def create_file_by_raw( - self, - *, - file_binary: bytes, - mimetype: str, - filename: str | None = None, - ) -> Any: ... - - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, File | None]: ... - - -class FileReferenceFactoryProtocol(Protocol): - def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: ... 
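The deleted protocols.py relied on `typing.Protocol` structural typing: any object whose methods match the declared signatures satisfies the protocol, with no inheritance or registration. A minimal illustration of why that decouples nodes from concrete HTTP and file implementations (the in-memory class is invented for the example, and `download` is simplified to take a URL rather than a `File`):

from typing import Protocol

class FileManagerProtocol(Protocol):
    def download(self, url: str, /) -> bytes: ...

class InMemoryFileManager:
    # Satisfies the protocol structurally; it never imports or subclasses it.
    def __init__(self, blobs: dict[str, bytes]):
        self._blobs = blobs

    def download(self, url: str, /) -> bytes:
        return self._blobs[url]

def fetch_size(manager: FileManagerProtocol, url: str) -> int:
    return len(manager.download(url))

assert fetch_size(InMemoryFileManager({"a": b"xyz"}), "a") == 3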
diff --git a/api/graphon/nodes/question_classifier/__init__.py b/api/graphon/nodes/question_classifier/__init__.py deleted file mode 100644 index 4d06b6bea36..00000000000 --- a/api/graphon/nodes/question_classifier/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .entities import QuestionClassifierNodeData -from .question_classifier_node import QuestionClassifierNode - -__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"] diff --git a/api/graphon/nodes/question_classifier/entities.py b/api/graphon/nodes/question_classifier/entities.py deleted file mode 100644 index 8d5f1173157..00000000000 --- a/api/graphon/nodes/question_classifier/entities.py +++ /dev/null @@ -1,30 +0,0 @@ -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.llm import ModelConfig, VisionConfig -from graphon.prompt_entities import MemoryConfig - - -class ClassConfig(BaseModel): - id: str - name: str - - -class QuestionClassifierNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.QUESTION_CLASSIFIER - query_variable_selector: list[str] - model: ModelConfig - classes: list[ClassConfig] - instruction: str | None = None - memory: MemoryConfig | None = None - vision: VisionConfig = Field(default_factory=VisionConfig) - - @property - def structured_output_enabled(self) -> bool: - # NOTE(QuantumGhost): Temporary workaround for issue #20725 - # (https://github.com/langgenius/dify/issues/20725). - # - # The proper fix would be to make `QuestionClassifierNode` inherit - # from `BaseNode` instead of `LLMNode`. - return False diff --git a/api/graphon/nodes/question_classifier/exc.py b/api/graphon/nodes/question_classifier/exc.py deleted file mode 100644 index 2c6354e2a70..00000000000 --- a/api/graphon/nodes/question_classifier/exc.py +++ /dev/null @@ -1,6 +0,0 @@ -class QuestionClassifierNodeError(ValueError): - """Base class for QuestionClassifierNode errors.""" - - -class InvalidModelTypeError(QuestionClassifierNodeError): - """Raised when the model is not a Large Language Model.""" diff --git a/api/graphon/nodes/question_classifier/question_classifier_node.py b/api/graphon/nodes/question_classifier/question_classifier_node.py deleted file mode 100644 index a30ffbb1495..00000000000 --- a/api/graphon/nodes/question_classifier/question_classifier_node.py +++ /dev/null @@ -1,395 +0,0 @@ -import json -import re -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.model_runtime.entities import LLMMode, LLMUsage, ModelPropertyKey, PromptMessageRole -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.node_events import ModelInvokeCompletedEvent, NodeRunResult -from graphon.nodes.base.entities import VariableSelector -from graphon.nodes.base.node import Node -from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from graphon.nodes.llm import ( - LLMNode, - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - llm_utils, -) -from graphon.nodes.llm.file_saver import LLMFileSaver -from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol -from 
graphon.nodes.protocols import HttpClientProtocol -from graphon.template_rendering import Jinja2TemplateRenderer -from graphon.utils.json_in_md_parser import parse_and_check_json_markdown - -from .entities import QuestionClassifierNodeData -from .exc import InvalidModelTypeError -from .template_prompts import ( - QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, - QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2, - QUESTION_CLASSIFIER_COMPLETION_PROMPT, - QUESTION_CLASSIFIER_SYSTEM_PROMPT, - QUESTION_CLASSIFIER_USER_PROMPT_1, - QUESTION_CLASSIFIER_USER_PROMPT_2, - QUESTION_CLASSIFIER_USER_PROMPT_3, -) - -if TYPE_CHECKING: - from graphon.file.models import File - from graphon.runtime import GraphRuntimeState - - -class _PassthroughPromptMessageSerializer: - def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any: - _ = model_mode - return list(prompt_messages) - - -class QuestionClassifierNode(Node[QuestionClassifierNodeData]): - node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER - execution_type = NodeExecutionType.BRANCH - - _file_outputs: list["File"] - _llm_file_saver: LLMFileSaver - _prompt_message_serializer: PromptMessageSerializerProtocol - _model_instance: PreparedLLMProtocol - _memory: PromptMessageMemory | None - _template_renderer: Jinja2TemplateRenderer - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - credentials_provider: object | None = None, - model_factory: object | None = None, - model_instance: PreparedLLMProtocol, - http_client: HttpClientProtocol, - template_renderer: Jinja2TemplateRenderer, - memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver, - prompt_message_serializer: PromptMessageSerializerProtocol | None = None, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - # LLM file outputs, used for MultiModal outputs. 
- self._file_outputs = [] - - _ = credentials_provider, model_factory, http_client - self._model_instance = model_instance - self._memory = memory - self._template_renderer = template_renderer - - self._llm_file_saver = llm_file_saver - self._prompt_message_serializer = prompt_message_serializer or _PassthroughPromptMessageSerializer() - - @classmethod - def version(cls): - return "1" - - def _run(self): - node_data = self.node_data - variable_pool = self.graph_runtime_state.variable_pool - - # extract variables - variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None - query = variable.value if variable else None - variables = {"query": query} - # fetch model instance - model_instance = self._model_instance - # Resolve variable references in string-typed completion params - model_instance.parameters = llm_utils.resolve_completion_params_variables( - model_instance.parameters, variable_pool - ) - memory = self._memory - # fetch instruction - node_data.instruction = node_data.instruction or "" - node_data.instruction = variable_pool.convert_template(node_data.instruction).text - - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=node_data.vision.configs.variable_selector, - ) - if node_data.vision.enabled - else [] - ) - - # fetch prompt messages - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query or "", - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_template( - node_data=node_data, - query=query or "", - memory=memory, - max_token_limit=rest_token, - ) - # Some models (e.g. Gemma, Mistral) force roles alternation (user/assistant/user/assistant...). - # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt, - # two consecutive user prompts will be generated, causing model's error. - # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end. 
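The comment above concerns models that reject two consecutive messages with the same role. The hazard, and the usual mitigation of merging adjacent same-role turns, fits in a few lines; this is a generic sketch, not what `fetch_prompt_messages` itself does:

def merge_consecutive_roles(messages: list[dict[str, str]]) -> list[dict[str, str]]:
    # Collapse runs of same-role messages so the sequence strictly alternates.
    merged: list[dict[str, str]] = []
    for msg in messages:
        if merged and merged[-1]["role"] == msg["role"]:
            merged[-1]["content"] += "\n" + msg["content"]
        else:
            merged.append(dict(msg))
    return merged

msgs = [
    {"role": "user", "content": "templated user prompt"},
    {"role": "user", "content": "sys_query"},  # the duplicate the comment warns about
]
assert merge_consecutive_roles(msgs) == [
    {"role": "user", "content": "templated user prompt\nsys_query"}
]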
- prompt_messages, stop = llm_utils.fetch_prompt_messages( - prompt_template=prompt_template, - sys_query="", - memory=memory, - model_instance=model_instance, - stop=model_instance.stop, - sys_files=files, - vision_enabled=node_data.vision.enabled, - vision_detail=node_data.vision.configs.detail, - variable_pool=variable_pool, - jinja2_variables=[], - template_renderer=self._template_renderer, - ) - - result_text = "" - usage = LLMUsage.empty_usage() - finish_reason = None - - try: - # handle invoke result - generator = LLMNode.invoke_llm( - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - structured_output_enabled=False, - structured_output=None, - file_saver=self._llm_file_saver, - file_outputs=self._file_outputs, - node_id=self._node_id, - node_type=self.node_type, - ) - - for event in generator: - if isinstance(event, ModelInvokeCompletedEvent): - result_text = event.text - usage = event.usage - finish_reason = event.finish_reason - break - - rendered_classes = [ - c.model_copy(update={"name": variable_pool.convert_template(c.name).text}) for c in node_data.classes - ] - - category_name = rendered_classes[0].name - category_id = rendered_classes[0].id - if "" in result_text: - result_text = re.sub(r"]*>[\s\S]*?", "", result_text, flags=re.IGNORECASE) - result_text_json = parse_and_check_json_markdown(result_text, []) - # result_text_json = json.loads(result_text.strip('```JSON\n')) - if "category_name" in result_text_json and "category_id" in result_text_json: - category_id_result = result_text_json["category_id"] - classes = rendered_classes - classes_map = {class_.id: class_.name for class_ in classes} - category_ids = [_class.id for _class in classes] - if category_id_result in category_ids: - category_name = classes_map[category_id_result] - category_id = category_id_result - process_data = { - "model_mode": node_data.model.mode, - "prompts": self._prompt_message_serializer.serialize( - model_mode=node_data.model.mode, prompt_messages=prompt_messages - ), - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - "model_provider": model_instance.provider, - "model_name": model_instance.model_name, - } - outputs = { - "class_name": category_name, - "class_id": category_id, - "usage": jsonable_encoder(usage), - } - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - process_data=process_data, - outputs=outputs, - edge_source_handle=category_id, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - except ValueError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e), - error_type=type(e).__name__, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - - @property - def model_instance(self) -> PreparedLLMProtocol: - return self._model_instance - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: QuestionClassifierNodeData, - ) -> Mapping[str, Sequence[str]]: - # graph_config is not used in this node type - variable_mapping = {"query": node_data.query_variable_selector} - 
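The mapping started above is keyed as `node_id + "." + variable`, the same convention the loop and parameter-extractor mappers use, with the selector path as the value. A sketch of the template-scanning step that fills in the instruction-derived entries, assuming a `{{#node.field#}}` placeholder syntax (both the regex and the syntax are illustrative):

import re

_SELECTOR = re.compile(r"\{\{#([A-Za-z0-9_]+(?:\.[A-Za-z0-9_]+)+)#\}\}")

def extract_selector_mapping(node_id: str, template: str) -> dict[str, list[str]]:
    # Key each discovered variable by the consuming node's id; the value is
    # the selector path used to look it up in the variable pool.
    mapping: dict[str, list[str]] = {}
    for match in _SELECTOR.finditer(template):
        mapping[f"{node_id}.{match.group(1)}"] = match.group(1).split(".")
    return mapping

assert extract_selector_mapping("qc", "Classify {{#start.query#}} for me") == {
    "qc.start.query": ["start", "query"]
}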
variable_selectors: list[VariableSelector] = [] - if node_data.instruction: - variable_template_parser = VariableTemplateParser(template=node_data.instruction) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = list(variable_selector.value_selector) - - variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} - - return variable_mapping - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters (not used in this implementation). - :return: - """ - # filters parameter is not used in this node type - return {"type": "question-classifier", "config": {"instructions": ""}} - - def _calculate_rest_token( - self, - node_data: QuestionClassifierNodeData, - query: str, - model_instance: PreparedLLMProtocol, - context: str | None, - ) -> int: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - - prompt_template = self._get_prompt_template(node_data, query, None, 2000) - prompt_messages, _ = llm_utils.fetch_prompt_messages( - prompt_template=prompt_template, - sys_query="", - sys_files=[], - context=context or "", - memory=None, - model_instance=model_instance, - stop=model_instance.stop, - memory_config=node_data.memory, - vision_enabled=False, - vision_detail=node_data.vision.configs.detail, - variable_pool=self.graph_runtime_state.variable_pool, - jinja2_variables=[], - template_renderer=self._template_renderer, - ) - rest_tokens = 2000 - - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_instance.parameters.get(parameter_rule.name) - or model_instance.parameters.get(parameter_rule.use_template or "") - ) or 0 - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - def _get_prompt_template( - self, - node_data: QuestionClassifierNodeData, - query: str, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ): - model_mode = LLMMode(node_data.model.mode) - classes = node_data.classes - categories = [] - for class_ in classes: - category = {"category_id": class_.id, "category_name": class_.name} - categories.append(category) - instruction = node_data.instruction or "" - input_text = query - memory_str = "" - if memory: - memory_str = llm_utils.fetch_memory_text( - memory=memory, - max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None, - ) - prompt_messages: list[LLMNodeChatModelMessage] = [] - if model_mode == LLMMode.CHAT: - system_prompt_messages = LLMNodeChatModelMessage( - role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) - ) - prompt_messages.append(system_prompt_messages) - user_prompt_message_1 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1 - ) - prompt_messages.append(user_prompt_message_1) - assistant_prompt_message_1 = 
LLMNodeChatModelMessage(
-                role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
-            )
-            prompt_messages.append(assistant_prompt_message_1)
-            user_prompt_message_2 = LLMNodeChatModelMessage(
-                role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
-            )
-            prompt_messages.append(user_prompt_message_2)
-            assistant_prompt_message_2 = LLMNodeChatModelMessage(
-                role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
-            )
-            prompt_messages.append(assistant_prompt_message_2)
-            user_prompt_message_3 = LLMNodeChatModelMessage(
-                role=PromptMessageRole.USER,
-                text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(
-                    input_text=input_text,
-                    categories=json.dumps(categories, ensure_ascii=False),
-                    classification_instructions=instruction,
-                ),
-            )
-            prompt_messages.append(user_prompt_message_3)
-            return prompt_messages
-        elif model_mode == LLMMode.COMPLETION:
-            return LLMNodeCompletionModelPromptTemplate(
-                text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
-                    histories=memory_str,
-                    input_text=input_text,
-                    categories=json.dumps(categories, ensure_ascii=False),
-                    classification_instructions=instruction,
-                )
-            )
-
-        else:
-            raise InvalidModelTypeError(f"Model mode {model_mode} is not supported.")
diff --git a/api/graphon/nodes/question_classifier/template_prompts.py b/api/graphon/nodes/question_classifier/template_prompts.py
deleted file mode 100644
index a615c323836..00000000000
--- a/api/graphon/nodes/question_classifier/template_prompts.py
+++ /dev/null
@@ -1,76 +0,0 @@
-QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
-### Job Description
-You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
-### Task
-Your task is to assign ONLY one category to the input text, and only one category may be returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
-### Format
-The input text is in the variable input_text. Categories are specified as a category list with two fields, category_id and category_name, in the variable categories. Classification instructions may be included to improve the classification accuracy.
-### Constraint
-DO NOT include anything other than the JSON array in your response.
-### Memory
-Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
-<histories>
-{histories}
-</histories>
-"""  # noqa: E501
-
-QUESTION_CLASSIFIER_USER_PROMPT_1 = """
-    {"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],
-    "categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}],
-    "classification_instructions": ["classify the text based on the feedback provided by customer"]}
-"""  # noqa: E501
-
-QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """
-```json
-    {"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],
-    "category_id": "f5660049-284f-41a7-b301-fd24176a711c",
-    "category_name": "Customer Service"}
-```
-"""
-
-QUESTION_CLASSIFIER_USER_PROMPT_2 = """
-    {"input_text": ["bad service, slow to bring the food"],
-    "categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}],
-    "classification_instructions": []}
-"""  # noqa: E501
-
-QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """
-```json
-    {"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],
-    "category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f",
-    "category_name": "Experience"}
-```
-"""
-
-QUESTION_CLASSIFIER_USER_PROMPT_3 = """
-    {{"input_text": ["{input_text}"],
-    "categories": {categories},
-    "classification_instructions": ["{classification_instructions}"]}}
-"""
-
-QUESTION_CLASSIFIER_COMPLETION_PROMPT = """
-### Job Description
-You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
-### Task
-Your task is to assign ONLY one category to the input text, and only one category may be returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
-### Format
-The input text is in the variable input_text. Categories are specified as a category list with two fields, category_id and category_name, in the variable categories. Classification instructions may be included to improve the classification accuracy.
-### Constraint
-DO NOT include anything other than the JSON array in your response.
-### Example
-Here is the chat example between human and assistant, inside <example></example> XML tags.
-<example>
-User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], "categories": [{{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"}},{{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"}},{{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"}},{{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}}], "classification_instructions": ["classify the text based on the feedback provided by customer"]}}
-Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}}
-User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}}
-Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}}
-</example>
-### Memory
-Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
-<histories>
-{histories}
-</histories>
-### User Input
-{{"input_text" : ["{input_text}"], "categories" : {categories},"classification_instruction" : ["{classification_instructions}"]}}
-### Assistant Output
-"""  # noqa: E501
diff --git a/api/graphon/nodes/runtime.py b/api/graphon/nodes/runtime.py
deleted file mode 100644
index 650299898c4..00000000000
--- a/api/graphon/nodes/runtime.py
+++ /dev/null
@@ -1,106 +0,0 @@
-from __future__ import annotations
-
-from collections.abc import Generator, Mapping, Sequence
-from datetime import datetime
-from typing import TYPE_CHECKING, Any, Protocol
-
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.nodes.tool_runtime_entities import (
-    ToolRuntimeHandle,
-    ToolRuntimeMessage,
-    ToolRuntimeParameter,
-)
-
-if TYPE_CHECKING:
-    from graphon.nodes.human_input.entities import HumanInputNodeData
-    from graphon.nodes.human_input.enums import HumanInputFormStatus
-    from graphon.nodes.tool.entities import ToolNodeData
-    from graphon.runtime import VariablePool
-
-
-class ToolNodeRuntimeProtocol(Protocol):
-    """Workflow-layer adapter owned by `core.workflow` and consumed by `graphon`.
-
-    The graph package depends only on these DTOs and lets the workflow layer
-    translate between graph-owned abstractions and `core.tools` internals.
-    """
-
-    def get_runtime(
-        self,
-        *,
-        node_id: str,
-        node_data: ToolNodeData,
-        variable_pool: VariablePool | None,
-    ) -> ToolRuntimeHandle: ...
-
-    def get_runtime_parameters(
-        self,
-        *,
-        tool_runtime: ToolRuntimeHandle,
-    ) -> Sequence[ToolRuntimeParameter]: ...
-
-    def invoke(
-        self,
-        *,
-        tool_runtime: ToolRuntimeHandle,
-        tool_parameters: Mapping[str, Any],
-        workflow_call_depth: int,
-        provider_name: str,
-    ) -> Generator[ToolRuntimeMessage, None, None]: ...
-
-    def get_usage(
-        self,
-        *,
-        tool_runtime: ToolRuntimeHandle,
-    ) -> LLMUsage: ...
-
-    def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any: ...
-
-    def resolve_provider_icons(
-        self,
-        *,
-        provider_name: str,
-        default_icon: str | None = None,
-    ) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: ...
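[Editor's aside] Because ToolNodeRuntimeProtocol is a typing.Protocol, conformance is structural: any object with matching methods satisfies it. A minimal in-memory stub for tests might look like the hypothetical sketch below; the imports assume the graphon modules being deleted in this patch, and LLMUsage.empty_usage() mirrors its use elsewhere in this diff.

from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.nodes.tool_runtime_entities import (
    ToolRuntimeHandle,
    ToolRuntimeMessage,
    ToolRuntimeParameter,
)

class StubToolNodeRuntime:
    """Hypothetical structural implementation of ToolNodeRuntimeProtocol for tests."""

    def get_runtime(self, *, node_id, node_data, variable_pool):
        return ToolRuntimeHandle(raw={"node_id": node_id})  # opaque, workflow-owned payload

    def get_runtime_parameters(self, *, tool_runtime):
        return [ToolRuntimeParameter(name="query", required=True)]

    def invoke(self, *, tool_runtime, tool_parameters, workflow_call_depth, provider_name):
        # Echo the query back as a single TEXT message.
        yield ToolRuntimeMessage(
            type=ToolRuntimeMessage.MessageType.TEXT,
            message=ToolRuntimeMessage.TextMessage(text=f"echo: {tool_parameters.get('query')}"),
        )

    def get_usage(self, *, tool_runtime):
        return LLMUsage.empty_usage()

    def build_file_reference(self, *, mapping):
        return dict(mapping)  # tests can treat the raw mapping as the file reference

    def resolve_provider_icons(self, *, provider_name, default_icon=None):
        return default_icon, None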
- - -class HumanInputNodeRuntimeProtocol(Protocol): - """Workflow-layer adapter for human-input runtime persistence and delivery.""" - - def get_form( - self, - *, - node_id: str, - ) -> HumanInputFormStateProtocol | None: ... - - def create_form( - self, - *, - node_id: str, - node_data: HumanInputNodeData, - rendered_content: str, - resolved_default_values: Mapping[str, Any], - ) -> HumanInputFormStateProtocol: ... - - -class HumanInputFormStateProtocol(Protocol): - @property - def id(self) -> str: ... - - @property - def rendered_content(self) -> str: ... - - @property - def selected_action_id(self) -> str | None: ... - - @property - def submitted_data(self) -> Mapping[str, Any] | None: ... - - @property - def submitted(self) -> bool: ... - - @property - def status(self) -> HumanInputFormStatus: ... - - @property - def expiration_time(self) -> datetime: ... diff --git a/api/graphon/nodes/start/__init__.py b/api/graphon/nodes/start/__init__.py deleted file mode 100644 index 54117804231..00000000000 --- a/api/graphon/nodes/start/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .start_node import StartNode - -__all__ = ["StartNode"] diff --git a/api/graphon/nodes/start/entities.py b/api/graphon/nodes/start/entities.py deleted file mode 100644 index 7df62e1b2bb..00000000000 --- a/api/graphon/nodes/start/entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from collections.abc import Sequence - -from pydantic import Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.variables.input_entities import VariableEntity - - -class StartNodeData(BaseNodeData): - """ - Start Node Data - """ - - type: NodeType = BuiltinNodeTypes.START - variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/graphon/nodes/start/start_node.py b/api/graphon/nodes/start/start_node.py deleted file mode 100644 index cb3f4c1e7dd..00000000000 --- a/api/graphon/nodes/start/start_node.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Any - -from jsonschema import Draft7Validator, ValidationError - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.start.entities import StartNodeData -from graphon.variables.input_entities import VariableEntityType - - -class StartNode(Node[StartNodeData]): - node_type = BuiltinNodeTypes.START - execution_type = NodeExecutionType.ROOT - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) - self._validate_and_normalize_json_object_inputs(node_inputs) - outputs = dict(self.graph_runtime_state.variable_pool.flatten(unprefixed_node_id=self.id)) - outputs.update(node_inputs) - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) - - def _validate_and_normalize_json_object_inputs(self, node_inputs: dict[str, Any]) -> None: - for variable in self.node_data.variables: - if variable.type != VariableEntityType.JSON_OBJECT: - continue - - key = variable.variable - value = node_inputs.get(key) - - if value is None and variable.required: - raise ValueError(f"{key} is required in input form") - - # If no value provided, skip further processing for this key - if not value: - continue - - if not isinstance(value, dict): - raise ValueError(f"JSON object for '{key}' must be an object") - - # 
Overwrite with normalized dict to ensure downstream consistency - node_inputs[key] = value - - # If schema exists, then validate against it - schema = variable.json_schema - if not schema: - continue - - try: - Draft7Validator(schema).validate(value) - except ValidationError as e: - raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}") diff --git a/api/graphon/nodes/template_transform/__init__.py b/api/graphon/nodes/template_transform/__init__.py deleted file mode 100644 index 43863b9d59a..00000000000 --- a/api/graphon/nodes/template_transform/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .template_transform_node import TemplateTransformNode - -__all__ = ["TemplateTransformNode"] diff --git a/api/graphon/nodes/template_transform/entities.py b/api/graphon/nodes/template_transform/entities.py deleted file mode 100644 index a27a57f34fe..00000000000 --- a/api/graphon/nodes/template_transform/entities.py +++ /dev/null @@ -1,13 +0,0 @@ -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base.entities import VariableSelector - - -class TemplateTransformNodeData(BaseNodeData): - """ - Template Transform Node Data. - """ - - type: NodeType = BuiltinNodeTypes.TEMPLATE_TRANSFORM - variables: list[VariableSelector] - template: str diff --git a/api/graphon/nodes/template_transform/template_transform_node.py b/api/graphon/nodes/template_transform/template_transform_node.py deleted file mode 100644 index 4206fb0c1a5..00000000000 --- a/api/graphon/nodes/template_transform/template_transform_node.py +++ /dev/null @@ -1,119 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.entities import VariableSelector -from graphon.nodes.base.node import Node -from graphon.nodes.template_transform.entities import TemplateTransformNodeData -from graphon.template_rendering import ( - Jinja2TemplateRenderer, - TemplateRenderError, -) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - -DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000 - - -class TemplateTransformNode(Node[TemplateTransformNodeData]): - node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM - _jinja2_template_renderer: Jinja2TemplateRenderer - _max_output_length: int - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - jinja2_template_renderer: Jinja2TemplateRenderer, - max_output_length: int | None = None, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._jinja2_template_renderer = jinja2_template_renderer - - if max_output_length is not None and max_output_length <= 0: - raise ValueError("max_output_length must be a positive integer") - self._max_output_length = max_output_length or DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters. 
- :return: - """ - return { - "type": "template-transform", - "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get variables - variables: dict[str, Any] = {} - for variable_selector in self.node_data.variables: - variable_name = variable_selector.variable - value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - variables[variable_name] = value.to_object() if value else None - # Run code - try: - rendered = self._jinja2_template_renderer.render_template(self.node_data.template, variables) - except TemplateRenderError as e: - return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) - - if len(rendered) > self._max_output_length: - return NodeRunResult( - inputs=variables, - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Output length exceeds {self._max_output_length} characters", - ) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered} - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: TemplateTransformNodeData | Mapping[str, Any], - ) -> Mapping[str, Sequence[str]]: - _ = graph_config - raw_variables = ( - node_data.variables if isinstance(node_data, TemplateTransformNodeData) else node_data.get("variables", []) - ) - variable_mapping: dict[str, Sequence[str]] = {} - for variable_selector in raw_variables: - if isinstance(variable_selector, VariableSelector): - variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector - continue - - if not isinstance(variable_selector, Mapping): - continue - - variable = variable_selector.get("variable") - value_selector = variable_selector.get("value_selector") - if ( - isinstance(variable, str) - and isinstance(value_selector, Sequence) - and all(isinstance(selector_part, str) for selector_part in value_selector) - ): - variable_mapping[node_id + "." + variable] = list(value_selector) - - return variable_mapping diff --git a/api/graphon/nodes/tool/__init__.py b/api/graphon/nodes/tool/__init__.py deleted file mode 100644 index f4982e655d1..00000000000 --- a/api/graphon/nodes/tool/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .tool_node import ToolNode - -__all__ = ["ToolNode"] diff --git a/api/graphon/nodes/tool/entities.py b/api/graphon/nodes/tool/entities.py deleted file mode 100644 index 54e60480339..00000000000 --- a/api/graphon/nodes/tool/entities.py +++ /dev/null @@ -1,101 +0,0 @@ -from enum import StrEnum, auto -from typing import Any, Literal, Union - -from pydantic import BaseModel, field_validator -from pydantic_core.core_schema import ValidationInfo - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class ToolProviderType(StrEnum): - """ - Graph-owned enum for persisted tool provider kinds. 
-    """
-
-    PLUGIN = auto()
-    BUILT_IN = "builtin"
-    WORKFLOW = auto()
-    API = auto()
-    APP = auto()
-    DATASET_RETRIEVAL = "dataset-retrieval"
-    MCP = auto()
-
-
-class ToolEntity(BaseModel):
-    provider_id: str
-    provider_type: ToolProviderType
-    provider_name: str  # redundancy
-    tool_name: str
-    tool_label: str  # redundancy
-    tool_configurations: dict[str, Any]
-    credential_id: str | None = None
-    plugin_unique_identifier: str | None = None  # redundancy
-
-    @field_validator("tool_configurations", mode="before")
-    @classmethod
-    def validate_tool_configurations(cls, value, values: ValidationInfo):
-        if not isinstance(value, dict):
-            raise ValueError("tool_configurations must be a dictionary")
-
-        for key in values.data.get("tool_configurations", {}):
-            value = values.data.get("tool_configurations", {}).get(key)
-            if not isinstance(value, str | int | float | bool):
-                raise ValueError(f"{key} must be a string, int, float, or bool")
-
-        return value
-
-
-class ToolNodeData(BaseNodeData, ToolEntity):
-    type: NodeType = BuiltinNodeTypes.TOOL
-
-    class ToolInput(BaseModel):
-        # TODO: check this type
-        value: Union[Any, list[str]]
-        type: Literal["mixed", "variable", "constant"]
-
-        @field_validator("type", mode="before")
-        @classmethod
-        def check_type(cls, value, validation_info: ValidationInfo):
-            typ = value
-            value = validation_info.data.get("value")
-
-            if value is None:
-                return typ
-
-            if typ == "mixed" and not isinstance(value, str):
-                raise ValueError("value must be a string")
-            elif typ == "variable":
-                if not isinstance(value, list):
-                    raise ValueError("value must be a list")
-                for val in value:
-                    if not isinstance(val, str):
-                        raise ValueError("value must be a list of strings")
-            elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
-                raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
-            return typ
-
-    tool_parameters: dict[str, ToolInput]
-    # The version of the tool parameter.
-    # If this value is None, it indicates this is a previous version
-    # and requires using the legacy parameter parsing rules.
- tool_node_version: str | None = None - - @field_validator("tool_parameters", mode="before") - @classmethod - def filter_none_tool_inputs(cls, value): - if not isinstance(value, dict): - return value - - return { - key: tool_input - for key, tool_input in value.items() - if tool_input is not None and cls._has_valid_value(tool_input) - } - - @staticmethod - def _has_valid_value(tool_input): - """Check if the value is valid""" - if isinstance(tool_input, dict): - return tool_input.get("value") is not None - return getattr(tool_input, "value", None) is not None diff --git a/api/graphon/nodes/tool/exc.py b/api/graphon/nodes/tool/exc.py deleted file mode 100644 index 1a309e1084b..00000000000 --- a/api/graphon/nodes/tool/exc.py +++ /dev/null @@ -1,28 +0,0 @@ -class ToolNodeError(ValueError): - """Base exception for tool node errors.""" - - pass - - -class ToolRuntimeResolutionError(ToolNodeError): - """Raised when the workflow layer cannot construct a tool runtime.""" - - pass - - -class ToolRuntimeInvocationError(ToolNodeError): - """Raised when the workflow layer fails while invoking a tool runtime.""" - - pass - - -class ToolParameterError(ToolNodeError): - """Exception raised for errors in tool parameters.""" - - pass - - -class ToolFileError(ToolNodeError): - """Exception raised for errors related to tool files.""" - - pass diff --git a/api/graphon/nodes/tool/tool_node.py b/api/graphon/nodes/tool/tool_node.py deleted file mode 100644 index 57ab8ce5d65..00000000000 --- a/api/graphon/nodes/tool/tool_node.py +++ /dev/null @@ -1,432 +0,0 @@ -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from graphon.nodes.protocols import ToolFileManagerProtocol -from graphon.nodes.runtime import ToolNodeRuntimeProtocol -from graphon.nodes.tool_runtime_entities import ( - ToolRuntimeHandle, - ToolRuntimeMessage, - ToolRuntimeParameter, -) -from graphon.variables.segments import ArrayFileSegment - -from .entities import ToolNodeData -from .exc import ( - ToolFileError, - ToolNodeError, - ToolParameterError, -) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - -class ToolNode(Node[ToolNodeData]): - """ - Tool Node - """ - - node_type = BuiltinNodeTypes.TOOL - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - tool_file_manager_factory: ToolFileManagerProtocol, - runtime: ToolNodeRuntimeProtocol | None = None, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._tool_file_manager_factory = tool_file_manager_factory - if runtime is None: - raise ValueError("runtime is required") - self._runtime = runtime - - @classmethod - def version(cls) -> str: - return "1" - - def populate_start_event(self, event) -> None: - event.provider_id = 
self.node_data.provider_id - event.provider_type = self.node_data.provider_type - - def _run(self) -> Generator[NodeEventBase, None, None]: - """ - Run the tool node - """ - # fetch tool icon - tool_info = { - "provider_type": self.node_data.provider_type.value, - "provider_id": self.node_data.provider_id, - "plugin_unique_identifier": self.node_data.plugin_unique_identifier, - } - - # get tool runtime - try: - # This is an issue that caused problems before. - # Logically, we shouldn't use the node_data.version field for judgment - # But for backward compatibility with historical data - # this version field judgment is still preserved here. - variable_pool: VariablePool | None = None - if self.node_data.version != "1" or self.node_data.tool_node_version is not None: - variable_pool = self.graph_runtime_state.variable_pool - tool_runtime = self._runtime.get_runtime( - node_id=self._node_id, - node_data=self.node_data, - variable_pool=variable_pool, - ) - except ToolNodeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs={}, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to get tool runtime: {str(e)}", - error_type=type(e).__name__, - ) - ) - return - - # get parameters - tool_parameters = self._runtime.get_runtime_parameters(tool_runtime=tool_runtime) - parameters = self._generate_parameters( - tool_parameters=tool_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, - ) - parameters_for_log = self._generate_parameters( - tool_parameters=tool_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, - for_log=True, - ) - try: - message_stream = self._runtime.invoke( - tool_runtime=tool_runtime, - tool_parameters=parameters, - workflow_call_depth=self.workflow_call_depth, - provider_name=self.node_data.provider_name, - ) - except ToolNodeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool: {str(e)}", - error_type=type(e).__name__, - ) - ) - return - - try: - # convert tool messages - _ = yield from self._transform_message( - messages=message_stream, - tool_info=tool_info, - parameters_for_log=parameters_for_log, - node_id=self._node_id, - tool_runtime=tool_runtime, - ) - except ToolNodeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=str(e), - error_type=type(e).__name__, - ) - ) - - def _generate_parameters( - self, - *, - tool_parameters: Sequence[ToolRuntimeParameter], - variable_pool: "VariablePool", - node_data: ToolNodeData, - for_log: bool = False, - ) -> dict[str, Any]: - """ - Generate parameters based on the given tool parameters, variable pool, and node data. - - Args: - tool_parameters (Sequence[ToolRuntimeParameter]): The list of tool parameters. - variable_pool (VariablePool): The variable pool containing the variables. - node_data (ToolNodeData): The data associated with the tool node. - - Returns: - Mapping[str, Any]: A dictionary containing the generated parameters. 
-
-        """
-        tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
-
-        result: dict[str, Any] = {}
-        for parameter_name in node_data.tool_parameters:
-            parameter = tool_parameters_dictionary.get(parameter_name)
-            if not parameter:
-                result[parameter_name] = None
-                continue
-            tool_input = node_data.tool_parameters[parameter_name]
-            if tool_input.type == "variable":
-                variable = variable_pool.get(tool_input.value)
-                if variable is None:
-                    if parameter.required:
-                        raise ToolParameterError(f"Variable {tool_input.value} does not exist")
-                    continue
-                parameter_value = variable.value
-            elif tool_input.type in {"mixed", "constant"}:
-                segment_group = variable_pool.convert_template(str(tool_input.value))
-                parameter_value = segment_group.log if for_log else segment_group.text
-            else:
-                raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
-            result[parameter_name] = parameter_value
-
-        return result
-
-    def _transform_message(
-        self,
-        messages: Generator[ToolRuntimeMessage, None, None],
-        tool_info: Mapping[str, Any],
-        parameters_for_log: dict[str, Any],
-        node_id: str,
-        tool_runtime: ToolRuntimeHandle,
-        **_: Any,
-    ) -> Generator[NodeEventBase, None, LLMUsage]:
-        """
-        Convert graph-owned tool runtime messages into node outputs.
-        """
-        text = ""
-        files: list[File] = []
-        json: list[dict | list] = []
-
-        variables: dict[str, Any] = {}
-
-        for message in messages:
-            if message.type in {
-                ToolRuntimeMessage.MessageType.IMAGE_LINK,
-                ToolRuntimeMessage.MessageType.BINARY_LINK,
-                ToolRuntimeMessage.MessageType.IMAGE,
-            }:
-                assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
-
-                url = message.message.text
-                if message.meta:
-                    transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
-                    tool_file_id = message.meta.get("tool_file_id")
-                else:
-                    transfer_method = FileTransferMethod.TOOL_FILE
-                    tool_file_id = None
-                if not isinstance(tool_file_id, str) or not tool_file_id:
-                    raise ToolFileError("tool message is missing tool_file_id metadata")
-
-                _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
-                if not tool_file:
-                    raise ToolFileError(f"tool file {tool_file_id} not found")
-                if tool_file.mime_type is None:
-                    raise ToolFileError(f"tool file {tool_file_id} is missing mime type")
-
-                file_mapping: dict[str, Any] = {
-                    "tool_file_id": tool_file_id,
-                    "type": get_file_type_by_mime_type(tool_file.mime_type),
-                    "transfer_method": transfer_method,
-                    "url": url,
-                }
-                file = self._runtime.build_file_reference(mapping=file_mapping)
-                files.append(file)
-            elif message.type == ToolRuntimeMessage.MessageType.BLOB:
-                # get tool file id
-                assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
-                assert message.meta
-
-                tool_file_id = message.meta.get("tool_file_id")
-                if not isinstance(tool_file_id, str) or not tool_file_id:
-                    raise ToolFileError("tool blob message is missing tool_file_id metadata")
-                _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
-                if not tool_file:
-                    raise ToolFileError(f"tool file {tool_file_id} does not exist")
-
-                blob_file_mapping: dict[str, Any] = {
-                    "tool_file_id": tool_file_id,
-                    "transfer_method": FileTransferMethod.TOOL_FILE,
-                }
-
-                files.append(self._runtime.build_file_reference(mapping=blob_file_mapping))
-            elif message.type == ToolRuntimeMessage.MessageType.TEXT:
-                assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
-                text += message.message.text
-                yield
StreamChunkEvent(
-                    selector=[node_id, "text"],
-                    chunk=message.message.text,
-                    is_final=False,
-                )
-            elif message.type == ToolRuntimeMessage.MessageType.JSON:
-                assert isinstance(message.message, ToolRuntimeMessage.JsonMessage)
-                # JSON message handling for tool node
-                if message.message.json_object:
-                    json.append(message.message.json_object)
-            elif message.type == ToolRuntimeMessage.MessageType.LINK:
-                assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
-
-                # Check if this LINK message is a file link
-                file_obj = (message.meta or {}).get("file")
-                if isinstance(file_obj, File):
-                    files.append(file_obj)
-                    stream_text = f"File: {message.message.text}\n"
-                else:
-                    stream_text = f"Link: {message.message.text}\n"
-
-                text += stream_text
-                yield StreamChunkEvent(
-                    selector=[node_id, "text"],
-                    chunk=stream_text,
-                    is_final=False,
-                )
-            elif message.type == ToolRuntimeMessage.MessageType.VARIABLE:
-                assert isinstance(message.message, ToolRuntimeMessage.VariableMessage)
-                variable_name = message.message.variable_name
-                variable_value = message.message.variable_value
-                if message.message.stream:
-                    if not isinstance(variable_value, str):
-                        raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
-                    if variable_name not in variables:
-                        variables[variable_name] = ""
-                    variables[variable_name] += variable_value
-
-                    yield StreamChunkEvent(
-                        selector=[node_id, variable_name],
-                        chunk=variable_value,
-                        is_final=False,
-                    )
-                else:
-                    variables[variable_name] = variable_value
-            elif message.type == ToolRuntimeMessage.MessageType.FILE:
-                assert message.meta is not None
-                assert isinstance(message.meta, dict)
-                # Validate that meta contains a 'file' key
-                if "file" not in message.meta:
-                    raise ToolNodeError("File message is missing 'file' key in meta")
-
-                # Validate that the file is an instance of File
-                if not isinstance(message.meta["file"], File):
-                    raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
-                files.append(message.meta["file"])
-            elif message.type == ToolRuntimeMessage.MessageType.LOG:
-                assert isinstance(message.message, ToolRuntimeMessage.LogMessage)
-                if message.message.metadata:
-                    icon = tool_info.get("icon", "")
-                    dict_metadata = dict(message.message.metadata)
-                    if dict_metadata.get("provider"):
-                        icon, icon_dark = self._runtime.resolve_provider_icons(
-                            provider_name=dict_metadata["provider"],
-                            default_icon=icon,
-                        )
-                        dict_metadata["icon"] = icon
-                        dict_metadata["icon_dark"] = icon_dark
-                    message.message.metadata = dict_metadata
-
-        # Add agent_logs to outputs['json'] to ensure frontend can access thinking process
-        json_output: list[dict[str, Any] | list[Any]] = []
-
-        # Normalize JSON output into list[dict]: keep collected JSON messages, or default to {"data": []}
-        if json:
-            json_output.extend(json)
-        else:
-            json_output.append({"data": []})
-
-        # Send final chunk events for all streamed outputs
-        # Final chunk for text stream
-        yield StreamChunkEvent(
-            selector=[self._node_id, "text"],
-            chunk="",
-            is_final=True,
-        )
-
-        # Final chunks for any streamed variables
-        for var_name in variables:
-            yield StreamChunkEvent(
-                selector=[self._node_id, var_name],
-                chunk="",
-                is_final=True,
-            )
-
-        usage = self._runtime.get_usage(tool_runtime=tool_runtime)
-
-        metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
-            WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
-        }
-        if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
-            metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
-
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price - metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, - metadata=metadata, - inputs=parameters_for_log, - llm_usage=usage, - ) - ) - - return usage - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ToolNodeData, - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - _ = graph_config # Explicitly mark as unused - typed_node_data = node_data - result = {} - for parameter_name in typed_node_data.tool_parameters: - input = typed_node_data.tool_parameters[parameter_name] - match input.type: - case "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - case "variable": - selector_key = ".".join(input.value) - result[f"#{selector_key}#"] = input.value - case "constant": - pass - - result = {node_id + "." + key: value for key, value in result.items()} - - return result - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled diff --git a/api/graphon/nodes/tool_runtime_entities.py b/api/graphon/nodes/tool_runtime_entities.py deleted file mode 100644 index 5bb0c165738..00000000000 --- a/api/graphon/nodes/tool_runtime_entities.py +++ /dev/null @@ -1,105 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from enum import StrEnum, auto -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field - - -class _ToolRuntimeModel(BaseModel): - model_config = ConfigDict(extra="forbid") - - -@dataclass(frozen=True, slots=True) -class ToolRuntimeHandle: - """Opaque graph-owned handle for a workflow-layer tool runtime. - - Workflow-specific execution context must stay behind `raw` so the graph - contract does not absorb application-owned concepts. 
- """ - - raw: object - - -@dataclass(frozen=True, slots=True) -class ToolRuntimeParameter: - """Graph-owned parameter shape used by tool nodes.""" - - name: str - required: bool = False - - -class ToolRuntimeMessage(_ToolRuntimeModel): - """Graph-owned tool invocation message DTO.""" - - class TextMessage(_ToolRuntimeModel): - text: str - - class JsonMessage(_ToolRuntimeModel): - json_object: dict[str, Any] | list[Any] - suppress_output: bool = Field(default=False) - - class BlobMessage(_ToolRuntimeModel): - blob: bytes - - class BlobChunkMessage(_ToolRuntimeModel): - id: str - sequence: int - total_length: int - blob: bytes - end: bool - - class FileMessage(_ToolRuntimeModel): - file_marker: str = Field(default="file_marker") - - class VariableMessage(_ToolRuntimeModel): - variable_name: str - variable_value: dict[str, Any] | list[Any] | str | int | float | bool | None - stream: bool = Field(default=False) - - class LogMessage(_ToolRuntimeModel): - class LogStatus(StrEnum): - START = auto() - ERROR = auto() - SUCCESS = auto() - - id: str - label: str - parent_id: str | None = None - error: str | None = None - status: LogStatus - data: dict[str, Any] - metadata: dict[str, Any] = Field(default_factory=dict) - - class RetrieverResourceMessage(_ToolRuntimeModel): - retriever_resources: list[dict[str, Any]] - context: str - - class MessageType(StrEnum): - TEXT = auto() - IMAGE = auto() - LINK = auto() - BLOB = auto() - JSON = auto() - IMAGE_LINK = auto() - BINARY_LINK = auto() - VARIABLE = auto() - FILE = auto() - LOG = auto() - BLOB_CHUNK = auto() - RETRIEVER_RESOURCES = auto() - - type: MessageType = MessageType.TEXT - message: ( - JsonMessage - | TextMessage - | BlobChunkMessage - | BlobMessage - | LogMessage - | FileMessage - | None - | VariableMessage - | RetrieverResourceMessage - ) - meta: dict[str, Any] | None = None diff --git a/api/graphon/nodes/variable_aggregator/__init__.py b/api/graphon/nodes/variable_aggregator/__init__.py deleted file mode 100644 index 0b6bf2a5b62..00000000000 --- a/api/graphon/nodes/variable_aggregator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .variable_aggregator_node import VariableAggregatorNode - -__all__ = ["VariableAggregatorNode"] diff --git a/api/graphon/nodes/variable_aggregator/entities.py b/api/graphon/nodes/variable_aggregator/entities.py deleted file mode 100644 index 136fd28f8ca..00000000000 --- a/api/graphon/nodes/variable_aggregator/entities.py +++ /dev/null @@ -1,35 +0,0 @@ -from pydantic import BaseModel - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.variables.types import SegmentType - - -class AdvancedSettings(BaseModel): - """ - Advanced setting. - """ - - group_enabled: bool - - class Group(BaseModel): - """ - Group. - """ - - output_type: SegmentType - variables: list[list[str]] - group_name: str - - groups: list[Group] - - -class VariableAggregatorNodeData(BaseNodeData): - """ - Variable Aggregator Node Data. 
- """ - - type: NodeType = BuiltinNodeTypes.VARIABLE_AGGREGATOR - output_type: str - variables: list[list[str]] - advanced_settings: AdvancedSettings | None = None diff --git a/api/graphon/nodes/variable_aggregator/variable_aggregator_node.py b/api/graphon/nodes/variable_aggregator/variable_aggregator_node.py deleted file mode 100644 index 71b221e1961..00000000000 --- a/api/graphon/nodes/variable_aggregator/variable_aggregator_node.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping - -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.variable_aggregator.entities import VariableAggregatorNodeData -from graphon.variables.segments import Segment - - -class VariableAggregatorNode(Node[VariableAggregatorNodeData]): - node_type = BuiltinNodeTypes.VARIABLE_AGGREGATOR - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get variables - outputs: dict[str, Segment | Mapping[str, Segment]] = {} - inputs = {} - - if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: - for selector in self.node_data.variables: - variable = self.graph_runtime_state.variable_pool.get(selector) - if variable is not None: - outputs = {"output": variable} - - inputs = {".".join(selector[1:]): variable.to_object()} - break - else: - for group in self.node_data.advanced_settings.groups: - for selector in group.variables: - variable = self.graph_runtime_state.variable_pool.get(selector) - - if variable is not None: - outputs[group.group_name] = {"output": variable} - inputs[".".join(selector[1:])] = variable.to_object() - break - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs) diff --git a/api/graphon/nodes/variable_assigner/__init__.py b/api/graphon/nodes/variable_assigner/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/graphon/nodes/variable_assigner/common/__init__.py b/api/graphon/nodes/variable_assigner/common/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/graphon/nodes/variable_assigner/common/exc.py b/api/graphon/nodes/variable_assigner/common/exc.py deleted file mode 100644 index f8dbedc2901..00000000000 --- a/api/graphon/nodes/variable_assigner/common/exc.py +++ /dev/null @@ -1,4 +0,0 @@ -class VariableOperatorNodeError(ValueError): - """Base error type, don't use directly.""" - - pass diff --git a/api/graphon/nodes/variable_assigner/common/helpers.py b/api/graphon/nodes/variable_assigner/common/helpers.py deleted file mode 100644 index 4c30e009f28..00000000000 --- a/api/graphon/nodes/variable_assigner/common/helpers.py +++ /dev/null @@ -1,55 +0,0 @@ -from collections.abc import Mapping, MutableMapping, Sequence -from typing import Any, TypeVar - -from pydantic import BaseModel - -from graphon.variables import Segment -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.types import SegmentType - -# Use double underscore (`__`) prefix for internal variables -# to minimize risk of collision with user-defined variable names. 
-_UPDATED_VARIABLES_KEY = "__updated_variables" - - -class UpdatedVariable(BaseModel): - name: str - selector: Sequence[str] - value_type: SegmentType - new_value: Any = None - - -_T = TypeVar("_T", bound=MutableMapping[str, Any]) - - -def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable: - if len(selector) < SELECTORS_LENGTH: - raise Exception("selector too short") - _, var_name = selector[:2] - return UpdatedVariable( - name=var_name, - selector=list(selector[:2]), - value_type=seg.value_type, - new_value=seg.value, - ) - - -def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T: - m[_UPDATED_VARIABLES_KEY] = updates - return m - - -def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None: - updated_values = m.get(_UPDATED_VARIABLES_KEY, None) - if updated_values is None: - return None - result = [] - for items in updated_values: - if isinstance(items, UpdatedVariable): - result.append(items) - elif isinstance(items, dict): - items = UpdatedVariable.model_validate(items) - result.append(items) - else: - raise TypeError(f"Invalid updated variable: {items}, type={type(items)}") - return result diff --git a/api/graphon/nodes/variable_assigner/v1/__init__.py b/api/graphon/nodes/variable_assigner/v1/__init__.py deleted file mode 100644 index 7eb1428e503..00000000000 --- a/api/graphon/nodes/variable_assigner/v1/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import VariableAssignerNode - -__all__ = ["VariableAssignerNode"] diff --git a/api/graphon/nodes/variable_assigner/v1/node.py b/api/graphon/nodes/variable_assigner/v1/node.py deleted file mode 100644 index 19ded5f1232..00000000000 --- a/api/graphon/nodes/variable_assigner/v1/node.py +++ /dev/null @@ -1,106 +0,0 @@ -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.variable_assigner.common import helpers as common_helpers -from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from graphon.variables import SegmentType, Variable, VariableBase - -from .node_data import VariableAssignerData, WriteMode - -if TYPE_CHECKING: - from graphon.runtime import GraphRuntimeState - - -class VariableAssignerNode(Node[VariableAssignerData]): - node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this Variable Assigner node blocks the output of specific variables. - - Returns True if this node updates any of the requested conversation variables. 
- """ - assigned_selector = tuple(self.node_data.assigned_variable_selector) - return assigned_selector in variable_selectors - - @classmethod - def version(cls) -> str: - return "1" - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: VariableAssignerData, - ) -> Mapping[str, Sequence[str]]: - mapping = {} - selector_key = ".".join(node_data.assigned_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.assigned_variable_selector - - selector_key = ".".join(node_data.input_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.input_variable_selector - return mapping - - def _run(self) -> Generator[NodeEventBase, None, None]: - assigned_variable_selector = self.node_data.assigned_variable_selector - # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) - if not isinstance(original_variable, VariableBase): - raise VariableOperatorNodeError("assigned variable not found") - - match self.node_data.write_mode: - case WriteMode.OVER_WRITE: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) - if not income_value: - raise VariableOperatorNodeError("input value not found") - updated_variable = original_variable.model_copy(update={"value": income_value.value}) - - case WriteMode.APPEND: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) - if not income_value: - raise VariableOperatorNodeError("input value not found") - updated_value = original_variable.value + [income_value.value] - updated_variable = original_variable.model_copy(update={"value": updated_value}) - - case WriteMode.CLEAR: - income_value = SegmentType.get_zero_value(original_variable.value_type) - updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) - - updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - yield VariableUpdatedEvent(variable=cast(Variable, updated_variable)) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={ - "value": income_value.to_object(), - }, - # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, - # we still set `output_variables` as a list to ensure the schema of output is - # compatible with `v2.VariableAssignerNode`. 
- process_data=common_helpers.set_updated_variables({}, updated_variables), - outputs={}, - ) - ) diff --git a/api/graphon/nodes/variable_assigner/v1/node_data.py b/api/graphon/nodes/variable_assigner/v1/node_data.py deleted file mode 100644 index 4f630bc76c2..00000000000 --- a/api/graphon/nodes/variable_assigner/v1/node_data.py +++ /dev/null @@ -1,18 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class WriteMode(StrEnum): - OVER_WRITE = "over-write" - APPEND = "append" - CLEAR = "clear" - - -class VariableAssignerData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER - assigned_variable_selector: Sequence[str] - write_mode: WriteMode - input_variable_selector: Sequence[str] diff --git a/api/graphon/nodes/variable_assigner/v2/__init__.py b/api/graphon/nodes/variable_assigner/v2/__init__.py deleted file mode 100644 index 7eb1428e503..00000000000 --- a/api/graphon/nodes/variable_assigner/v2/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import VariableAssignerNode - -__all__ = ["VariableAssignerNode"] diff --git a/api/graphon/nodes/variable_assigner/v2/entities.py b/api/graphon/nodes/variable_assigner/v2/entities.py deleted file mode 100644 index d1c68c8e8c5..00000000000 --- a/api/graphon/nodes/variable_assigner/v2/entities.py +++ /dev/null @@ -1,28 +0,0 @@ -from collections.abc import Sequence -from typing import Any - -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - -from .enums import InputType, Operation - - -class VariableOperationItem(BaseModel): - variable_selector: Sequence[str] - input_type: InputType - operation: Operation - # NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context: - # - # 1. For CONSTANT input_type: Contains the literal value to be used in the operation. - # 2. For VARIABLE input_type: Initially contains the selector of the source variable. - # 3. During the variable updating procedure: The `value` field is reassigned to hold - # the resolved actual value that will be applied to the target variable. 
- value: Any = None - - -class VariableAssignerNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER - version: str = "2" - items: Sequence[VariableOperationItem] = Field(default_factory=list) diff --git a/api/graphon/nodes/variable_assigner/v2/enums.py b/api/graphon/nodes/variable_assigner/v2/enums.py deleted file mode 100644 index 291b1208d46..00000000000 --- a/api/graphon/nodes/variable_assigner/v2/enums.py +++ /dev/null @@ -1,20 +0,0 @@ -from enum import StrEnum - - -class Operation(StrEnum): - OVER_WRITE = "over-write" - CLEAR = "clear" - APPEND = "append" - EXTEND = "extend" - SET = "set" - ADD = "+=" - SUBTRACT = "-=" - MULTIPLY = "*=" - DIVIDE = "/=" - REMOVE_FIRST = "remove-first" - REMOVE_LAST = "remove-last" - - -class InputType(StrEnum): - VARIABLE = "variable" - CONSTANT = "constant" diff --git a/api/graphon/nodes/variable_assigner/v2/exc.py b/api/graphon/nodes/variable_assigner/v2/exc.py deleted file mode 100644 index 90d76485740..00000000000 --- a/api/graphon/nodes/variable_assigner/v2/exc.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Sequence -from typing import Any - -from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError - -from .enums import InputType, Operation - - -class OperationNotSupportedError(VariableOperatorNodeError): - def __init__(self, *, operation: Operation, variable_type: str): - super().__init__(f"Operation {operation} is not supported for type {variable_type}") - - -class InputTypeNotSupportedError(VariableOperatorNodeError): - def __init__(self, *, input_type: InputType, operation: Operation): - super().__init__(f"Input type {input_type} is not supported for operation {operation}") - - -class VariableNotFoundError(VariableOperatorNodeError): - def __init__(self, *, variable_selector: Sequence[str]): - super().__init__(f"Variable {variable_selector} not found") - - -class InvalidInputValueError(VariableOperatorNodeError): - def __init__(self, *, value: Any): - super().__init__(f"Invalid input value {value}") - - -class ConversationIDNotFoundError(VariableOperatorNodeError): - def __init__(self): - super().__init__("conversation_id not found") - - -class InvalidDataError(VariableOperatorNodeError): - def __init__(self, message: str): - super().__init__(message) diff --git a/api/graphon/nodes/variable_assigner/v2/helpers.py b/api/graphon/nodes/variable_assigner/v2/helpers.py deleted file mode 100644 index ebc6c794762..00000000000 --- a/api/graphon/nodes/variable_assigner/v2/helpers.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import Any - -from graphon.variables import SegmentType - -from .enums import Operation - - -def is_operation_supported(*, variable_type: SegmentType, operation: Operation): - match operation: - case Operation.OVER_WRITE | Operation.CLEAR: - return True - case Operation.SET: - return variable_type in { - SegmentType.OBJECT, - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.INTEGER, - SegmentType.FLOAT, - SegmentType.BOOLEAN, - } - case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE: - # Only number variable can be added, subtracted, multiplied or divided - return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT} - case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST: - # Only array variable can be appended or extended - # Only array variable can have elements removed - return variable_type.is_array_type() - - -def is_variable_input_supported(*, operation: Operation): 
- if operation in {Operation.SET, Operation.ADD, Operation.SUBTRACT, Operation.MULTIPLY, Operation.DIVIDE}: - return False - return True - - -def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation): - match variable_type: - case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN: - return operation in {Operation.OVER_WRITE, Operation.SET} - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return operation in { - Operation.OVER_WRITE, - Operation.SET, - Operation.ADD, - Operation.SUBTRACT, - Operation.MULTIPLY, - Operation.DIVIDE, - } - case _: - return False - - -def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any): - if operation in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}: - return True - match variable_type: - case SegmentType.STRING: - return isinstance(value, str) - - case SegmentType.BOOLEAN: - return isinstance(value, bool) - - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - if not isinstance(value, int | float): - return False - if operation == Operation.DIVIDE and value == 0: - return False - return True - - case SegmentType.OBJECT: - return isinstance(value, dict) - - # Array & Append - case SegmentType.ARRAY_ANY if operation == Operation.APPEND: - return isinstance(value, str | float | int | dict) - case SegmentType.ARRAY_STRING if operation == Operation.APPEND: - return isinstance(value, str) - case SegmentType.ARRAY_NUMBER if operation == Operation.APPEND: - return isinstance(value, int | float) - case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND: - return isinstance(value, dict) - case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND: - return isinstance(value, bool) - - # Array & Extend / Overwrite - case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, str | float | int | dict) for item in value) - case SegmentType.ARRAY_STRING if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, str) for item in value) - case SegmentType.ARRAY_NUMBER if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, int | float) for item in value) - case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, dict) for item in value) - case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, bool) for item in value) - - case _: - return False diff --git a/api/graphon/nodes/variable_assigner/v2/node.py b/api/graphon/nodes/variable_assigner/v2/node.py deleted file mode 100644 index 887bd1b604e..00000000000 --- a/api/graphon/nodes/variable_assigner/v2/node.py +++ /dev/null @@ -1,257 +0,0 @@ -import json -from collections.abc import Generator, Mapping, MutableMapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.variable_assigner.common import helpers as common_helpers -from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from 
graphon.variables import SegmentType, Variable, VariableBase -from graphon.variables.consts import SELECTORS_LENGTH - -from . import helpers -from .entities import VariableAssignerNodeData, VariableOperationItem -from .enums import InputType, Operation -from .exc import ( - InputTypeNotSupportedError, - InvalidDataError, - InvalidInputValueError, - OperationNotSupportedError, - VariableNotFoundError, -) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - selector_str = ".".join(item.variable_selector) - key = f"{node_id}.#{selector_str}#" - mapping[key] = item.variable_selector - - -def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - # Keep this in sync with the logic in _run methods... - if item.input_type != InputType.VARIABLE: - return - selector = item.value - if not isinstance(selector, list): - raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") - if len(selector) < SELECTORS_LENGTH: - raise InvalidDataError(f"selector too short, {node_id=}, {item=}") - selector_str = ".".join(selector) - key = f"{node_id}.#{selector_str}#" - mapping[key] = selector - - -class VariableAssignerNode(Node[VariableAssignerNodeData]): - node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this Variable Assigner node blocks the output of specific variables. - - Returns True if this node updates any of the requested conversation variables. - """ - # Check each item in this Variable Assigner node - for item in self.node_data.items: - # Convert the item's variable_selector to tuple for comparison - item_selector_tuple = tuple(item.variable_selector) - - # Check if this item updates any of the requested variables - if item_selector_tuple in variable_selectors: - return True - - return False - - @classmethod - def version(cls) -> str: - return "2" - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: VariableAssignerNodeData, - ) -> Mapping[str, Sequence[str]]: - var_mapping: dict[str, Sequence[str]] = {} - for item in node_data.items: - _target_mapping_from_item(var_mapping, node_id, item) - _source_mapping_from_item(var_mapping, node_id, item) - return var_mapping - - def _run(self) -> Generator[NodeEventBase, None, None]: - inputs = self.node_data.model_dump() - process_data: dict[str, Any] = {} - # NOTE: This node has no outputs - updated_variable_selectors: list[Sequence[str]] = [] - # Preserve intra-node read-after-write behavior without mutating the shared pool - # until the engine processes the emitted VariableUpdatedEvent instances. 
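To make the read-after-write contract described in the comment above concrete, here is a
minimal, self-contained sketch of the idea; a plain dict stands in for the real
VariablePool, and all names are illustrative only:

    from copy import deepcopy

    # Toy stand-in for the variable pool: selector tuple -> value.
    pool = {("conversation", "count"): 1}

    # Work on a deep copy so the shared pool stays untouched until the
    # engine applies the emitted update events.
    working = deepcopy(pool)

    # Item 1 writes to the working copy...
    working[("conversation", "count")] += 2
    # ...and item 2 can immediately read what item 1 wrote.
    assert working[("conversation", "count")] == 3
    # The shared pool still holds the original value.
    assert pool[("conversation", "count")] == 1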
-        working_variable_pool = self.graph_runtime_state.variable_pool.model_copy(deep=True)
-
-        try:
-            for item in self.node_data.items:
-                variable = working_variable_pool.get(item.variable_selector)
-
-                # ==================== Validation Part
-
-                # Check if variable exists
-                if not isinstance(variable, VariableBase):
-                    raise VariableNotFoundError(variable_selector=item.variable_selector)
-
-                # Check if operation is supported
-                if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation):
-                    raise OperationNotSupportedError(operation=item.operation, variable_type=variable.value_type)
-
-                # Check if variable input is supported
-                if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported(
-                    operation=item.operation
-                ):
-                    raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation)
-
-                # Check if constant input is supported
-                if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported(
-                    variable_type=variable.value_type, operation=item.operation
-                ):
-                    raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation)
-
-                # Get value from variable pool
-                input_value = item.value
-                if (
-                    item.input_type == InputType.VARIABLE
-                    and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}
-                    and item.value is not None
-                ):
-                    value = working_variable_pool.get(item.value)
-                    if value is None:
-                        raise VariableNotFoundError(variable_selector=item.value)
-                    # Skip if value is NoneSegment
-                    if value.value_type == SegmentType.NONE:
-                        continue
-                    input_value = value.value
-
-                # When setting a string / bytes / bytearray value into an object variable,
-                # try to parse the string as JSON first.
-                if (
-                    item.operation == Operation.SET
-                    and variable.value_type == SegmentType.OBJECT
-                    and isinstance(input_value, str | bytes | bytearray)
-                ):
-                    try:
-                        input_value = json.loads(input_value)
-                    except json.JSONDecodeError:
-                        raise InvalidInputValueError(value=input_value)
-
-                # Check if input value is valid
-                if not helpers.is_input_value_valid(
-                    variable_type=variable.value_type, operation=item.operation, value=input_value
-                ):
-                    raise InvalidInputValueError(value=input_value)
-
-                # ==================== Execution Part
-
-                updated_value = self._handle_item(
-                    variable=variable,
-                    operation=item.operation,
-                    value=input_value,
-                )
-                updated_variable = variable.model_copy(update={"value": updated_value})
-                working_variable_pool.add(updated_variable.selector, updated_variable)
-                updated_variable_selectors.append(updated_variable.selector)
-        except VariableOperatorNodeError as e:
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.FAILED,
-                    inputs=inputs,
-                    process_data=process_data,
-                    error=str(e),
-                )
-            )
-            return
-
-        # `updated_variable_selectors` is a list of list[str], which are not hashable;
-        # convert to tuples and deduplicate while preserving the first-update order.
- updated_variable_selectors = list(dict.fromkeys(map(tuple, updated_variable_selectors))) - - for selector in updated_variable_selectors: - variable = working_variable_pool.get(selector) - if not isinstance(variable, VariableBase): - raise VariableNotFoundError(variable_selector=selector) - process_data[variable.name] = variable.value - - updated_variables = [ - common_helpers.variable_to_processed_data(selector, seg) - for selector in updated_variable_selectors - if (seg := working_variable_pool.get(selector)) is not None - ] - - process_data = common_helpers.set_updated_variables(process_data, updated_variables) - for selector in updated_variable_selectors: - variable = working_variable_pool.get(selector) - if not isinstance(variable, VariableBase): - raise VariableNotFoundError(variable_selector=selector) - yield VariableUpdatedEvent(variable=cast(Variable, variable)) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={}, - ) - ) - - def _handle_item( - self, - *, - variable: VariableBase, - operation: Operation, - value: Any, - ): - match operation: - case Operation.OVER_WRITE: - return value - case Operation.CLEAR: - return SegmentType.get_zero_value(variable.value_type).to_object() - case Operation.APPEND: - return variable.value + [value] - case Operation.EXTEND: - return variable.value + value - case Operation.SET: - return value - case Operation.ADD: - return variable.value + value - case Operation.SUBTRACT: - return variable.value - value - case Operation.MULTIPLY: - return variable.value * value - case Operation.DIVIDE: - return variable.value / value - case Operation.REMOVE_FIRST: - # If array is empty, do nothing - if not variable.value: - return variable.value - return variable.value[1:] - case Operation.REMOVE_LAST: - # If array is empty, do nothing - if not variable.value: - return variable.value - return variable.value[:-1] diff --git a/api/graphon/prompt_entities.py b/api/graphon/prompt_entities.py deleted file mode 100644 index 2b8b106c6c7..00000000000 --- a/api/graphon/prompt_entities.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel - -from graphon.model_runtime.entities.message_entities import PromptMessageRole - - -class ChatModelMessage(BaseModel): - """Graph-owned chat prompt template message.""" - - text: str - role: PromptMessageRole - edition_type: Literal["basic", "jinja2"] | None = None - - -class CompletionModelPromptTemplate(BaseModel): - """Graph-owned completion prompt template.""" - - text: str - edition_type: Literal["basic", "jinja2"] | None = None - - -class MemoryConfig(BaseModel): - """Graph-owned memory configuration for prompt assembly.""" - - class RolePrefix(BaseModel): - """Role labels used when serializing completion-model histories.""" - - user: str - assistant: str - - class WindowConfig(BaseModel): - """History windowing controls.""" - - enabled: bool - size: int | None = None - - role_prefix: RolePrefix | None = None - window: WindowConfig - query_prompt_template: str | None = None - - -__all__ = [ - "ChatModelMessage", - "CompletionModelPromptTemplate", - "MemoryConfig", -] diff --git a/api/graphon/runtime/__init__.py b/api/graphon/runtime/__init__.py deleted file mode 100644 index adca07e59a7..00000000000 --- a/api/graphon/runtime/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .graph_runtime_state import ( - ChildEngineBuilderNotConfiguredError, - ChildEngineError, - 
ChildGraphNotFoundError, - GraphRuntimeState, -) -from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool -from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper -from .variable_pool import VariablePool, VariableValue - -__all__ = [ - "ChildEngineBuilderNotConfiguredError", - "ChildEngineError", - "ChildGraphNotFoundError", - "GraphRuntimeState", - "ReadOnlyGraphRuntimeState", - "ReadOnlyGraphRuntimeStateWrapper", - "ReadOnlyVariablePool", - "ReadOnlyVariablePoolWrapper", - "VariablePool", - "VariableValue", -] diff --git a/api/graphon/runtime/graph_runtime_state.py b/api/graphon/runtime/graph_runtime_state.py deleted file mode 100644 index 8453830f284..00000000000 --- a/api/graphon/runtime/graph_runtime_state.py +++ /dev/null @@ -1,704 +0,0 @@ -from __future__ import annotations - -import importlib -import json -from collections.abc import Mapping, Sequence -from contextlib import AbstractContextManager, nullcontext -from copy import deepcopy -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Protocol - -from pydantic import BaseModel, Field -from pydantic.json import pydantic_encoder - -from graphon.enums import NodeExecutionType, NodeState, NodeType -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime.variable_pool import VariablePool - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.entities.pause_reason import PauseReason - - -class ReadyQueueProtocol(Protocol): - """Structural interface required from ready queue implementations.""" - - def put(self, item: str) -> None: - """Enqueue the identifier of a node that is ready to run.""" - ... - - def get(self, timeout: float | None = None) -> str: - """Return the next node identifier, blocking until available or timeout expires.""" - ... - - def task_done(self) -> None: - """Signal that the most recently dequeued node has completed processing.""" - ... - - def empty(self) -> bool: - """Return True when the queue contains no pending nodes.""" - ... - - def qsize(self) -> int: - """Approximate the number of pending nodes awaiting execution.""" - ... - - def dumps(self) -> str: - """Serialize the queue contents for persistence.""" - ... - - def loads(self, data: str) -> None: - """Restore the queue contents from a serialized payload.""" - ... - - -class NodeExecutionProtocol(Protocol): - """Structural interface for persisted per-node execution state.""" - - execution_id: str | None - - -class GraphExecutionProtocol(Protocol): - """Structural interface for graph execution aggregate. - - Defines the minimal set of attributes and methods required from a GraphExecution entity - for runtime orchestration and state management. - """ - - workflow_id: str - started: bool - completed: bool - aborted: bool - error: Exception | None - exceptions_count: int - pause_reasons: list[PauseReason] - - @property - def node_executions(self) -> Mapping[str, NodeExecutionProtocol]: - """Return the persisted node execution state keyed by node id.""" - ... - - def start(self) -> None: - """Transition execution into the running state.""" - ... - - def complete(self) -> None: - """Mark execution as successfully completed.""" - ... - - def abort(self, reason: str) -> None: - """Abort execution in response to an external stop request.""" - ... - - def fail(self, error: Exception) -> None: - """Record an unrecoverable error and end execution.""" - ... 
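Because GraphExecutionProtocol is a typing.Protocol, conformance is structural: any
object exposing these attributes and methods satisfies it, with no inheritance
required. A hypothetical stand-in covering just the lifecycle portion shown so far
(not the real GraphExecution in graphon.graph_engine.domain; it omits
node_executions and the dumps/loads persistence hooks declared below) could look like:

    class StubGraphExecution:
        # Illustrative only; names and behavior are assumptions for the sketch.
        def __init__(self, workflow_id: str = "") -> None:
            self.workflow_id = workflow_id
            self.started = False
            self.completed = False
            self.aborted = False
            self.error: Exception | None = None
            self.exceptions_count = 0

        def start(self) -> None:
            self.started = True

        def complete(self) -> None:
            self.completed = True

        def abort(self, reason: str) -> None:
            self.aborted = True
            self.error = RuntimeError(reason)

        def fail(self, error: Exception) -> None:
            self.error = error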
- - def dumps(self) -> str: - """Serialize execution state into a JSON payload.""" - ... - - def loads(self, data: str) -> None: - """Restore execution state from a previously serialized payload.""" - ... - - -class ResponseStreamCoordinatorProtocol(Protocol): - """Structural interface for response stream coordinator.""" - - def register(self, response_node_id: str) -> None: - """Register a response node so its outputs can be streamed.""" - ... - - def loads(self, data: str) -> None: - """Restore coordinator state from a serialized payload.""" - ... - - def dumps(self) -> str: - """Serialize coordinator state for persistence.""" - ... - - -class NodeProtocol(Protocol): - """Structural interface for graph nodes.""" - - id: str - state: NodeState - execution_type: NodeExecutionType - node_type: ClassVar[NodeType] - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ... - - -class EdgeProtocol(Protocol): - id: str - state: NodeState - tail: str - head: str - source_handle: str - - -class GraphProtocol(Protocol): - """Structural interface required from graph instances attached to the runtime state.""" - - nodes: Mapping[str, NodeProtocol] - edges: Mapping[str, EdgeProtocol] - root_node: NodeProtocol - - def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... - - -class ChildGraphEngineBuilderProtocol(Protocol): - def build_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - parent_graph_runtime_state: GraphRuntimeState, - root_node_id: str, - variable_pool: VariablePool | None = None, - ) -> Any: ... - - -class ChildEngineError(ValueError): - """Base error type for child-engine creation failures.""" - - -class ChildEngineBuilderNotConfiguredError(ChildEngineError): - """Raised when child-engine creation is requested without a bound builder.""" - - -class ChildGraphNotFoundError(ChildEngineError): - """Raised when the requested child graph entry point cannot be resolved.""" - - -class _GraphStateSnapshot(BaseModel): - """Serializable graph state snapshot for node/edge states.""" - - nodes: dict[str, NodeState] = Field(default_factory=dict) - edges: dict[str, NodeState] = Field(default_factory=dict) - - -@dataclass(slots=True) -class _GraphRuntimeStateSnapshot: - """Immutable view of a serialized runtime state snapshot.""" - - start_at: float - total_tokens: int - node_run_steps: int - llm_usage: LLMUsage - outputs: dict[str, Any] - variable_pool: VariablePool - has_variable_pool: bool - ready_queue_dump: str | None - graph_execution_dump: str | None - response_coordinator_dump: str | None - paused_nodes: tuple[str, ...] - deferred_nodes: tuple[str, ...] - graph_node_states: dict[str, NodeState] - graph_edge_states: dict[str, NodeState] - - -class GraphRuntimeState: - """Mutable runtime state shared across graph execution components. - - `GraphRuntimeState` encapsulates the runtime state of workflow execution, - including scheduling details, variable values, and timing information. - - Values that are initialized prior to workflow execution and remain constant - throughout the execution should be part of `GraphInitParams` instead. 
- """ - - def __init__( - self, - *, - variable_pool: VariablePool, - start_at: float, - total_tokens: int = 0, - llm_usage: LLMUsage | None = None, - outputs: dict[str, object] | None = None, - node_run_steps: int = 0, - ready_queue: ReadyQueueProtocol | None = None, - graph_execution: GraphExecutionProtocol | None = None, - response_coordinator: ResponseStreamCoordinatorProtocol | None = None, - graph: GraphProtocol | None = None, - execution_context: AbstractContextManager[object] | None = None, - ) -> None: - self._variable_pool = variable_pool - self._start_at = start_at - - if total_tokens < 0: - raise ValueError("total_tokens must be non-negative") - self._total_tokens = total_tokens - - self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy() - self._outputs = deepcopy(outputs) if outputs is not None else {} - - if node_run_steps < 0: - raise ValueError("node_run_steps must be non-negative") - self._node_run_steps = node_run_steps - - self._graph: GraphProtocol | None = None - - self._ready_queue = ready_queue - self._graph_execution = graph_execution - self._response_coordinator = response_coordinator - # Application code injects this when worker threads must restore request - # or framework-local state. It is intentionally excluded from snapshots. - self._execution_context = execution_context if execution_context is not None else nullcontext(None) - self._pending_response_coordinator_dump: str | None = None - self._pending_graph_execution_workflow_id: str | None = None - self._paused_nodes: set[str] = set() - self._deferred_nodes: set[str] = set() - self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None - - # Node and edges states needed to be restored into - # graph object. - # - # These two fields are non-None only when resuming from a snapshot. - # Once the graph is attached, these two fields will be set to None. 
- self._pending_graph_node_states: dict[str, NodeState] | None = None - self._pending_graph_edge_states: dict[str, NodeState] | None = None - - if graph is not None: - self.attach_graph(graph) - - # ------------------------------------------------------------------ - # Context binding helpers - # ------------------------------------------------------------------ - def attach_graph(self, graph: GraphProtocol) -> None: - """Attach the materialized graph to the runtime state.""" - if self._graph is not None and self._graph is not graph: - raise ValueError("GraphRuntimeState already attached to a different graph instance") - - self._graph = graph - - if self._response_coordinator is None: - self._response_coordinator = self._build_response_coordinator(graph) - - if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None: - self._response_coordinator.loads(self._pending_response_coordinator_dump) - self._pending_response_coordinator_dump = None - self._apply_pending_graph_state() - - def configure(self, *, graph: GraphProtocol | None = None) -> None: - """Ensure core collaborators are initialized with the provided context.""" - if graph is not None: - self.attach_graph(graph) - - # Ensure collaborators are instantiated - _ = self.ready_queue - _ = self.graph_execution - if self._graph is not None: - _ = self.response_coordinator - - def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None: - self._child_engine_builder = builder - - def create_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - root_node_id: str, - variable_pool: VariablePool | None = None, - ) -> Any: - """Create a child graph engine that derives its runtime state from the parent.""" - if self._child_engine_builder is None: - raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.") - - return self._child_engine_builder.build_child_engine( - workflow_id=workflow_id, - graph_init_params=graph_init_params, - parent_graph_runtime_state=self, - root_node_id=root_node_id, - variable_pool=variable_pool, - ) - - # ------------------------------------------------------------------ - # Primary collaborators - # ------------------------------------------------------------------ - @property - def variable_pool(self) -> VariablePool: - return self._variable_pool - - @property - def ready_queue(self) -> ReadyQueueProtocol: - if self._ready_queue is None: - self._ready_queue = self._build_ready_queue() - return self._ready_queue - - @property - def graph_execution(self) -> GraphExecutionProtocol: - if self._graph_execution is None: - self._graph_execution = self._build_graph_execution() - return self._graph_execution - - @property - def response_coordinator(self) -> ResponseStreamCoordinatorProtocol: - if self._response_coordinator is None: - if self._graph is None: - raise ValueError("Graph must be attached before accessing response coordinator") - self._response_coordinator = self._build_response_coordinator(self._graph) - return self._response_coordinator - - @property - def execution_context(self) -> AbstractContextManager[object]: - return self._execution_context - - @execution_context.setter - def execution_context(self, value: AbstractContextManager[object] | None) -> None: - self._execution_context = value if value is not None else nullcontext(None) - - # ------------------------------------------------------------------ - # Scalar state - # 
------------------------------------------------------------------ - @property - def start_at(self) -> float: - return self._start_at - - @start_at.setter - def start_at(self, value: float) -> None: - self._start_at = value - - @property - def total_tokens(self) -> int: - return self._total_tokens - - @total_tokens.setter - def total_tokens(self, value: int) -> None: - if value < 0: - raise ValueError("total_tokens must be non-negative") - self._total_tokens = value - - @property - def llm_usage(self) -> LLMUsage: - return self._llm_usage.model_copy() - - @llm_usage.setter - def llm_usage(self, value: LLMUsage) -> None: - self._llm_usage = value.model_copy() - - @property - def outputs(self) -> dict[str, Any]: - return deepcopy(self._outputs) - - @outputs.setter - def outputs(self, value: dict[str, Any]) -> None: - self._outputs = deepcopy(value) - - def set_output(self, key: str, value: object) -> None: - self._outputs[key] = deepcopy(value) - - def get_output(self, key: str, default: object = None) -> object: - return deepcopy(self._outputs.get(key, default)) - - def update_outputs(self, updates: dict[str, object]) -> None: - for key, value in updates.items(): - self._outputs[key] = deepcopy(value) - - @property - def node_run_steps(self) -> int: - return self._node_run_steps - - @node_run_steps.setter - def node_run_steps(self, value: int) -> None: - if value < 0: - raise ValueError("node_run_steps must be non-negative") - self._node_run_steps = value - - def increment_node_run_steps(self) -> None: - self._node_run_steps += 1 - - def add_tokens(self, tokens: int) -> None: - if tokens < 0: - raise ValueError("tokens must be non-negative") - self._total_tokens += tokens - - # ------------------------------------------------------------------ - # Serialization - # ------------------------------------------------------------------ - def dumps(self) -> str: - """Serialize runtime state into a JSON string.""" - - snapshot: dict[str, Any] = { - "version": "1.0", - "start_at": self._start_at, - "total_tokens": self._total_tokens, - "node_run_steps": self._node_run_steps, - "llm_usage": self._llm_usage.model_dump(mode="json"), - "outputs": self.outputs, - "variable_pool": self.variable_pool.model_dump(mode="json"), - "ready_queue": self.ready_queue.dumps(), - "graph_execution": self.graph_execution.dumps(), - "paused_nodes": list(self._paused_nodes), - "deferred_nodes": list(self._deferred_nodes), - } - - graph_state = self._snapshot_graph_state() - if graph_state is not None: - snapshot["graph_state"] = graph_state - - if self._response_coordinator is not None and self._graph is not None: - snapshot["response_coordinator"] = self._response_coordinator.dumps() - - return json.dumps(snapshot, default=pydantic_encoder) - - @classmethod - def from_snapshot(cls, data: str | Mapping[str, Any]) -> GraphRuntimeState: - """Restore runtime state from a serialized snapshot.""" - - snapshot = cls._parse_snapshot_payload(data) - - state = cls( - variable_pool=snapshot.variable_pool, - start_at=snapshot.start_at, - total_tokens=snapshot.total_tokens, - llm_usage=snapshot.llm_usage, - outputs=snapshot.outputs, - node_run_steps=snapshot.node_run_steps, - ) - state._apply_snapshot(snapshot) - return state - - def loads(self, data: str | Mapping[str, Any]) -> None: - """Restore runtime state from a serialized snapshot (legacy API).""" - - snapshot = self._parse_snapshot_payload(data) - self._apply_snapshot(snapshot) - - def register_paused_node(self, node_id: str) -> None: - """Record a node that should resume 
when execution is continued.""" - - self._paused_nodes.add(node_id) - - def get_paused_nodes(self) -> list[str]: - """Retrieve the list of paused nodes without mutating internal state.""" - - return list(self._paused_nodes) - - def consume_paused_nodes(self) -> list[str]: - """Retrieve and clear the list of paused nodes awaiting resume.""" - - nodes = list(self._paused_nodes) - self._paused_nodes.clear() - return nodes - - def register_deferred_node(self, node_id: str) -> None: - """Record a node that became ready during pause and should resume later.""" - - self._deferred_nodes.add(node_id) - - def get_deferred_nodes(self) -> list[str]: - """Retrieve deferred nodes without mutating internal state.""" - - return list(self._deferred_nodes) - - def consume_deferred_nodes(self) -> list[str]: - """Retrieve and clear deferred nodes awaiting resume.""" - - nodes = list(self._deferred_nodes) - self._deferred_nodes.clear() - return nodes - - # ------------------------------------------------------------------ - # Builders - # ------------------------------------------------------------------ - def _build_ready_queue(self) -> ReadyQueueProtocol: - # Import lazily to avoid breaching architecture boundaries enforced by import-linter. - module = importlib.import_module("graphon.graph_engine.ready_queue") - in_memory_cls = module.InMemoryReadyQueue - return in_memory_cls() - - def _build_graph_execution(self) -> GraphExecutionProtocol: - # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("graphon.graph_engine.domain.graph_execution") - graph_execution_cls = module.GraphExecution - workflow_id = self._pending_graph_execution_workflow_id or "" - self._pending_graph_execution_workflow_id = None - return graph_execution_cls(workflow_id=workflow_id) # type: ignore[invalid-return-type] - - def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol: - # Lazily import to keep the runtime domain decoupled from graph_engine modules. 
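The lazy-import pattern used by these builders is easy to reproduce. A minimal sketch,
with the stdlib json module standing in for the graph_engine package so the snippet
runs on its own:

    import importlib

    def build_encoder():
        # Import lazily so this module never holds a hard top-level dependency
        # on the heavier package, keeping import-linter style boundaries intact.
        module = importlib.import_module("json")
        encoder_cls = module.JSONEncoder
        return encoder_cls()

    print(build_encoder().encode({"ready": True}))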
- module = importlib.import_module("graphon.graph_engine.response_coordinator") - coordinator_cls = module.ResponseStreamCoordinator - return coordinator_cls(variable_pool=self.variable_pool, graph=graph) - - # ------------------------------------------------------------------ - # Snapshot helpers - # ------------------------------------------------------------------ - @classmethod - def _parse_snapshot_payload(cls, data: str | Mapping[str, Any]) -> _GraphRuntimeStateSnapshot: - payload: dict[str, Any] - if isinstance(data, str): - payload = json.loads(data) - else: - payload = dict(data) - - version = payload.get("version") - if version != "1.0": - raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}") - - start_at = float(payload.get("start_at", 0.0)) - - total_tokens = int(payload.get("total_tokens", 0)) - if total_tokens < 0: - raise ValueError("total_tokens must be non-negative") - - node_run_steps = int(payload.get("node_run_steps", 0)) - if node_run_steps < 0: - raise ValueError("node_run_steps must be non-negative") - - llm_usage_payload = payload.get("llm_usage", {}) - llm_usage = LLMUsage.model_validate(llm_usage_payload) - - outputs_payload = deepcopy(payload.get("outputs", {})) - - variable_pool_payload = payload.get("variable_pool") - has_variable_pool = variable_pool_payload is not None - variable_pool = VariablePool.model_validate(variable_pool_payload) if has_variable_pool else VariablePool() - - ready_queue_payload = payload.get("ready_queue") - graph_execution_payload = payload.get("graph_execution") - response_payload = payload.get("response_coordinator") - paused_nodes_payload = payload.get("paused_nodes", []) - deferred_nodes_payload = payload.get("deferred_nodes", []) - graph_state_payload = payload.get("graph_state", {}) or {} - graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes") - graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges") - - return _GraphRuntimeStateSnapshot( - start_at=start_at, - total_tokens=total_tokens, - node_run_steps=node_run_steps, - llm_usage=llm_usage, - outputs=outputs_payload, - variable_pool=variable_pool, - has_variable_pool=has_variable_pool, - ready_queue_dump=ready_queue_payload, - graph_execution_dump=graph_execution_payload, - response_coordinator_dump=response_payload, - paused_nodes=tuple(map(str, paused_nodes_payload)), - deferred_nodes=tuple(map(str, deferred_nodes_payload)), - graph_node_states=graph_node_states, - graph_edge_states=graph_edge_states, - ) - - def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None: - self._start_at = snapshot.start_at - self._total_tokens = snapshot.total_tokens - self._node_run_steps = snapshot.node_run_steps - self._llm_usage = snapshot.llm_usage.model_copy() - self._outputs = deepcopy(snapshot.outputs) - if snapshot.has_variable_pool or self._variable_pool is None: - self._variable_pool = snapshot.variable_pool - - self._restore_ready_queue(snapshot.ready_queue_dump) - self._restore_graph_execution(snapshot.graph_execution_dump) - self._restore_response_coordinator(snapshot.response_coordinator_dump) - self._paused_nodes = set(snapshot.paused_nodes) - self._deferred_nodes = set(snapshot.deferred_nodes) - self._pending_graph_node_states = snapshot.graph_node_states or None - self._pending_graph_edge_states = snapshot.graph_edge_states or None - self._apply_pending_graph_state() - - def _restore_ready_queue(self, payload: str | None) -> None: - if payload is not None: - self._ready_queue = 
self._build_ready_queue() - self._ready_queue.loads(payload) - else: - self._ready_queue = None - - def _restore_graph_execution(self, payload: str | None) -> None: - self._graph_execution = None - self._pending_graph_execution_workflow_id = None - - if payload is None: - return - - try: - execution_payload = json.loads(payload) - self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id") - except (json.JSONDecodeError, TypeError, AttributeError): - self._pending_graph_execution_workflow_id = None - - self.graph_execution.loads(payload) - - def _restore_response_coordinator(self, payload: str | None) -> None: - if payload is None: - self._pending_response_coordinator_dump = None - self._response_coordinator = None - return - - if self._graph is not None: - self.response_coordinator.loads(payload) - self._pending_response_coordinator_dump = None - return - - self._pending_response_coordinator_dump = payload - self._response_coordinator = None - - def _snapshot_graph_state(self) -> _GraphStateSnapshot: - graph = self._graph - if graph is None: - if self._pending_graph_node_states is None and self._pending_graph_edge_states is None: - return _GraphStateSnapshot() - return _GraphStateSnapshot( - nodes=self._pending_graph_node_states or {}, - edges=self._pending_graph_edge_states or {}, - ) - - nodes = graph.nodes - edges = graph.edges - if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping): - return _GraphStateSnapshot() - - node_states = {} - for node_id, node in nodes.items(): - if not isinstance(node_id, str): - continue - node_states[node_id] = node.state - - edge_states = {} - for edge_id, edge in edges.items(): - if not isinstance(edge_id, str): - continue - edge_states[edge_id] = edge.state - - return _GraphStateSnapshot(nodes=node_states, edges=edge_states) - - def _apply_pending_graph_state(self) -> None: - if self._graph is None: - return - if self._pending_graph_node_states: - for node_id, state in self._pending_graph_node_states.items(): - node = self._graph.nodes.get(node_id) - if node is None: - continue - node.state = state - if self._pending_graph_edge_states: - for edge_id, state in self._pending_graph_edge_states.items(): - edge = self._graph.edges.get(edge_id) - if edge is None: - continue - edge.state = state - - self._pending_graph_node_states = None - self._pending_graph_edge_states = None - - -def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]: - if not isinstance(payload, Mapping): - return {} - raw_map = payload.get(key, {}) - if not isinstance(raw_map, Mapping): - return {} - result: dict[str, NodeState] = {} - for node_id, raw_state in raw_map.items(): - if not isinstance(node_id, str): - continue - try: - result[node_id] = NodeState(str(raw_state)) - except ValueError: - continue - return result diff --git a/api/graphon/runtime/graph_runtime_state_protocol.py b/api/graphon/runtime/graph_runtime_state_protocol.py deleted file mode 100644 index 856625a5d30..00000000000 --- a/api/graphon/runtime/graph_runtime_state_protocol.py +++ /dev/null @@ -1,79 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Protocol - -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.variables.segments import Segment - - -class ReadOnlyVariablePool(Protocol): - """Read-only interface for VariablePool.""" - - def get(self, selector: Sequence[str], /) -> Segment | None: - """Get a variable value (read-only).""" - ... 
- - def get_all_by_node(self, node_id: str) -> Mapping[str, object]: - """Get all variables for a node (read-only).""" - ... - - def get_by_prefix(self, prefix: str) -> Mapping[str, object]: - """Get all variables stored under a given node prefix (read-only).""" - ... - - -class ReadOnlyGraphRuntimeState(Protocol): - """ - Read-only view of GraphRuntimeState for layers. - - This protocol defines a read-only interface that prevents layers from - modifying the graph runtime state while still allowing observation. - All methods return defensive copies to ensure immutability. - """ - - @property - def variable_pool(self) -> ReadOnlyVariablePool: - """Get read-only access to the variable pool.""" - ... - - @property - def start_at(self) -> float: - """Get the start time (read-only).""" - ... - - @property - def total_tokens(self) -> int: - """Get the total tokens count (read-only).""" - ... - - @property - def llm_usage(self) -> LLMUsage: - """Get a copy of LLM usage info (read-only).""" - ... - - @property - def outputs(self) -> dict[str, Any]: - """Get a defensive copy of outputs (read-only).""" - ... - - @property - def node_run_steps(self) -> int: - """Get the node run steps count (read-only).""" - ... - - @property - def ready_queue_size(self) -> int: - """Get the number of nodes currently in the ready queue.""" - ... - - @property - def exceptions_count(self) -> int: - """Get the number of node execution exceptions recorded.""" - ... - - def get_output(self, key: str, default: Any = None) -> Any: - """Get a single output value (returns a copy).""" - ... - - def dumps(self) -> str: - """Serialize the runtime state into a JSON snapshot (read-only).""" - ... diff --git a/api/graphon/runtime/read_only_wrappers.py b/api/graphon/runtime/read_only_wrappers.py deleted file mode 100644 index aaef2552041..00000000000 --- a/api/graphon/runtime/read_only_wrappers.py +++ /dev/null @@ -1,82 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from copy import deepcopy -from typing import Any - -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.variables.segments import Segment - -from .graph_runtime_state import GraphRuntimeState -from .variable_pool import VariablePool - - -class ReadOnlyVariablePoolWrapper: - """Provide defensive, read-only access to ``VariablePool``.""" - - def __init__(self, variable_pool: VariablePool) -> None: - self._variable_pool = variable_pool - - def get(self, selector: Sequence[str], /) -> Segment | None: - """Return a copy of a variable value if present.""" - value = self._variable_pool.get(selector) - return deepcopy(value) if value is not None else None - - def get_all_by_node(self, node_id: str) -> Mapping[str, object]: - """Return a copy of all variables for the specified node.""" - variables: dict[str, object] = {} - if node_id in self._variable_pool.variable_dictionary: - for key, variable in self._variable_pool.variable_dictionary[node_id].items(): - variables[key] = deepcopy(variable.value) - return variables - - def get_by_prefix(self, prefix: str) -> Mapping[str, object]: - """Return a copy of all variables stored under the given prefix.""" - return self._variable_pool.get_by_prefix(prefix) - - -class ReadOnlyGraphRuntimeStateWrapper: - """Expose a defensive, read-only view of ``GraphRuntimeState``.""" - - def __init__(self, state: GraphRuntimeState) -> None: - self._state = state - self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool) - - @property - def 
variable_pool(self) -> ReadOnlyVariablePoolWrapper: - return self._variable_pool_wrapper - - @property - def start_at(self) -> float: - return self._state.start_at - - @property - def total_tokens(self) -> int: - return self._state.total_tokens - - @property - def llm_usage(self) -> LLMUsage: - return self._state.llm_usage.model_copy() - - @property - def outputs(self) -> dict[str, Any]: - return deepcopy(self._state.outputs) - - @property - def node_run_steps(self) -> int: - return self._state.node_run_steps - - @property - def ready_queue_size(self) -> int: - return self._state.ready_queue.qsize() - - @property - def exceptions_count(self) -> int: - return self._state.graph_execution.exceptions_count - - def get_output(self, key: str, default: Any = None) -> Any: - return self._state.get_output(key, default) - - def dumps(self) -> str: - """Serialize the underlying runtime state for external persistence.""" - return self._state.dumps() diff --git a/api/graphon/runtime/variable_pool.py b/api/graphon/runtime/variable_pool.py deleted file mode 100644 index b44d1a8abeb..00000000000 --- a/api/graphon/runtime/variable_pool.py +++ /dev/null @@ -1,279 +0,0 @@ -from __future__ import annotations - -import re -from collections import defaultdict -from collections.abc import Mapping, Sequence -from copy import deepcopy -from typing import Annotated, Any, Union, cast - -from pydantic import BaseModel, Field, model_validator - -from graphon.file import File, FileAttribute, file_manager -from graphon.variables import Segment, SegmentGroup, VariableBase, build_segment, segment_to_variable -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.segments import FileSegment, ObjectSegment -from graphon.variables.variables import RAGPipelineVariableInput, Variable - -VariableValue = Union[str, int, float, dict[str, object], list[object], File] - -VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") - - -def _default_variable_dictionary() -> defaultdict[str, dict[str, Variable]]: - return defaultdict(dict) - - -class VariablePool(BaseModel): - _SYSTEM_VARIABLE_NODE_ID = "sys" - _ENVIRONMENT_VARIABLE_NODE_ID = "env" - _CONVERSATION_VARIABLE_NODE_ID = "conversation" - _RAG_PIPELINE_VARIABLE_NODE_ID = "rag" - - # Variable dictionary is a dictionary for looking up variables by their selector. - # The first element of the selector is the node id, it's the first-level key in the dictionary. - # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the - # elements of the selector except the first one. - variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field( - description="Variables mapping", - default_factory=_default_variable_dictionary, - ) - system_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) - environment_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) - conversation_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) - rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = Field(default_factory=tuple, exclude=True) - user_inputs: Mapping[str, Any] = Field(default_factory=dict, exclude=True) - - @model_validator(mode="after") - def _load_legacy_bootstrap_inputs(self) -> VariablePool: - """ - Accept legacy constructor kwargs that still appear throughout the workflow - layer while keeping serialized state focused on `variable_dictionary`. 
- """ - - self._ingest_legacy_variables(self.system_variables, node_id=self._SYSTEM_VARIABLE_NODE_ID) - self._ingest_legacy_variables(self.environment_variables, node_id=self._ENVIRONMENT_VARIABLE_NODE_ID) - self._ingest_legacy_variables(self.conversation_variables, node_id=self._CONVERSATION_VARIABLE_NODE_ID) - self._ingest_legacy_rag_variables(self.rag_pipeline_variables) - - # These kwargs are accepted for compatibility but should not affect the - # stable serialized form or model equality. - self.system_variables = () - self.environment_variables = () - self.conversation_variables = () - self.rag_pipeline_variables = () - self.user_inputs = {} - return self - - def _ingest_legacy_variables(self, variables: Sequence[Variable], *, node_id: str) -> None: - for variable in variables: - selector = [node_id, variable.name] - normalized_variable = variable - if list(variable.selector) != selector: - normalized_variable = variable.model_copy(update={"selector": selector}) - self.add(normalized_variable.selector, normalized_variable) - - def _ingest_legacy_rag_variables(self, rag_pipeline_variables: Sequence[RAGPipelineVariableInput]) -> None: - if not rag_pipeline_variables: - return - - values_by_node_id: defaultdict[str, dict[str, Any]] = defaultdict(dict) - for rag_variable_input in rag_pipeline_variables: - values_by_node_id[rag_variable_input.variable.belong_to_node_id][rag_variable_input.variable.variable] = ( - rag_variable_input.value - ) - - for node_id, value in values_by_node_id.items(): - self.add((self._RAG_PIPELINE_VARIABLE_NODE_ID, node_id), value) - - def add(self, selector: Sequence[str], value: Any, /): - """ - Add a variable to the variable pool. - - This method accepts a selector path and a value, converting the value - to a Variable object if necessary before storing it in the pool. - - Args: - selector: A two-element sequence containing [node_id, variable_name]. - The selector must have exactly 2 elements to be valid. - value: The value to store. Can be a Variable, Segment, or any value - that can be converted to a Segment (str, int, float, dict, list, File). - - Raises: - ValueError: If selector length is not exactly 2 elements. - - Note: - While non-Segment values are currently accepted and automatically - converted, it's recommended to pass Segment or Variable objects directly. - """ - if len(selector) != SELECTORS_LENGTH: - raise ValueError( - f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), " - f"got {len(selector)} elements" - ) - - if isinstance(value, VariableBase): - variable = value - elif isinstance(value, Segment): - variable = segment_to_variable(segment=value, selector=selector) - else: - segment = build_segment(value) - variable = segment_to_variable(segment=segment, selector=selector) - - node_id, name = self._selector_to_keys(selector) - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. 
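A short usage sketch of the add/get contract documented above (illustrative; it
assumes the graphon package as it existed before this removal, and selector values
are invented for the example):

    from graphon.runtime import VariablePool

    pool = VariablePool()
    pool.add(["node_1", "greeting"], "hello")

    segment = pool.get(["node_1", "greeting"])
    assert segment is not None and segment.value == "hello"

    # Selectors must be exactly [node_id, variable_name].
    try:
        pool.add(["node_1"], "oops")
    except ValueError as exc:
        print(exc)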
- self.variable_dictionary[node_id][name] = cast(Variable, variable) - - @classmethod - def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: - return selector[0], selector[1] - - def _has(self, selector: Sequence[str]) -> bool: - node_id, name = self._selector_to_keys(selector) - if node_id not in self.variable_dictionary: - return False - if name not in self.variable_dictionary[node_id]: - return False - return True - - def get(self, selector: Sequence[str], /) -> Segment | None: - """ - Retrieve a variable's value from the pool as a Segment. - - This method supports both simple selectors [node_id, variable_name] and - extended selectors that include attribute access for FileSegment and - ObjectSegment types. - - Args: - selector: A sequence with at least 2 elements: - - [node_id, variable_name]: Returns the full segment - - [node_id, variable_name, attr, ...]: Returns a nested value - from FileSegment (e.g., 'url', 'name') or ObjectSegment - - Returns: - The Segment associated with the selector, or None if not found. - Returns None if selector has fewer than 2 elements. - - Raises: - ValueError: If attempting to access an invalid FileAttribute. - """ - if len(selector) < SELECTORS_LENGTH: - return None - - node_id, name = self._selector_to_keys(selector) - node_map = self.variable_dictionary.get(node_id) - if node_map is None: - return None - - segment: Segment | None = node_map.get(name) - - if segment is None: - return None - - if len(selector) == 2: - return segment - - if isinstance(segment, FileSegment): - attr = selector[2] - # Python support `attr in FileAttribute` after 3.12 - if attr not in {item.value for item in FileAttribute}: - return None - attr = FileAttribute(attr) - attr_value = file_manager.get_attr(file=segment.value, attr=attr) - return build_segment(attr_value) - - # Navigate through nested attributes - result: Any = segment - for attr in selector[2:]: - result = self._extract_value(result) - result = self._get_nested_attribute(result, attr) - if result is None: - return None - - # Return result as Segment - return result if isinstance(result, Segment) else build_segment(result) - - def _extract_value(self, obj: Any): - """Extract the actual value from an ObjectSegment.""" - return obj.value if isinstance(obj, ObjectSegment) else obj - - def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None: - """ - Get a nested attribute from a dictionary-like object. - - Args: - obj: The dictionary-like object to search. - attr: The key to look up. - - Returns: - Segment | None: - The corresponding Segment built from the attribute value if the key exists, - otherwise None. - """ - if not isinstance(obj, dict) or attr not in obj: - return None - return build_segment(obj.get(attr)) - - def remove(self, selector: Sequence[str], /): - """ - Remove variables from the variable pool based on the given selector. - - Args: - selector (Sequence[str]): A sequence of strings representing the selector. - - Returns: - None - """ - if not selector: - return - if len(selector) == 1: - self.variable_dictionary[selector[0]] = {} - return - key, hash_key = self._selector_to_keys(selector) - self.variable_dictionary[key].pop(hash_key, None) - - def convert_template(self, template: str, /): - parts = VARIABLE_PATTERN.split(template) - segments: list[Segment] = [] - for part in filter(lambda x: x, parts): - if "." 
in part and (variable := self.get(part.split("."))): - segments.append(variable) - else: - segments.append(build_segment(part)) - return SegmentGroup(value=segments) - - def get_file(self, selector: Sequence[str], /) -> FileSegment | None: - segment = self.get(selector) - if isinstance(segment, FileSegment): - return segment - return None - - def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: - """Return a copy of all variables stored under the given node prefix.""" - - nodes = self.variable_dictionary.get(prefix) - if not nodes: - return {} - - result: dict[str, object] = {} - for key, variable in nodes.items(): - value = variable.value - result[key] = deepcopy(value) - - return result - - def flatten(self, *, unprefixed_node_id: str | None = None) -> Mapping[str, object]: - """Return a selector-style snapshot of the entire variable pool.""" - - result: dict[str, object] = {} - for node_id, variables in self.variable_dictionary.items(): - for name, variable in variables.items(): - output_name = name if node_id == unprefixed_node_id else f"{node_id}.{name}" - result[output_name] = deepcopy(variable.value) - - return result - - @classmethod - def empty(cls) -> VariablePool: - """Create an empty variable pool.""" - return cls() diff --git a/api/graphon/template_rendering.py b/api/graphon/template_rendering.py deleted file mode 100644 index 0527e58f6df..00000000000 --- a/api/graphon/template_rendering.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from collections.abc import Mapping -from typing import Any - - -class TemplateRenderError(ValueError): - """Raised when rendering a template fails.""" - - -class Jinja2TemplateRenderer(ABC): - """Nominal renderer contract for Jinja2 template rendering in graph nodes.""" - - @abstractmethod - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - """Render the template into plain text.""" - raise NotImplementedError diff --git a/api/graphon/utils/__init__.py b/api/graphon/utils/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/graphon/utils/condition/__init__.py b/api/graphon/utils/condition/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/graphon/utils/condition/entities.py b/api/graphon/utils/condition/entities.py deleted file mode 100644 index 77a214571a1..00000000000 --- a/api/graphon/utils/condition/entities.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections.abc import Sequence -from typing import Literal - -from pydantic import BaseModel, Field - -SupportedComparisonOperator = Literal[ - # for string or array - "contains", - "not contains", - "start with", - "end with", - "is", - "is not", - "empty", - "not empty", - "in", - "not in", - "all of", - # for number - "=", - "โ‰ ", - ">", - "<", - "โ‰ฅ", - "โ‰ค", - "null", - "not null", - # for file - "exists", - "not exists", -] - - -class SubCondition(BaseModel): - key: str - comparison_operator: SupportedComparisonOperator - value: str | Sequence[str] | None = None - - -class SubVariableCondition(BaseModel): - logical_operator: Literal["and", "or"] - conditions: list[SubCondition] = Field(default_factory=list) - - -class Condition(BaseModel): - variable_selector: list[str] - comparison_operator: SupportedComparisonOperator - value: str | Sequence[str] | bool | None = None - sub_variable_condition: SubVariableCondition | None = None diff --git a/api/graphon/utils/condition/processor.py 
b/api/graphon/utils/condition/processor.py
deleted file mode 100644
index 03535927cb5..00000000000
--- a/api/graphon/utils/condition/processor.py
+++ /dev/null
@@ -1,504 +0,0 @@
-import json
-from collections.abc import Mapping, Sequence
-from typing import Literal, NamedTuple
-
-from graphon.file import FileAttribute, file_manager
-from graphon.runtime import VariablePool
-from graphon.variables import ArrayFileSegment
-from graphon.variables.segments import ArrayBooleanSegment, BooleanSegment
-
-from .entities import Condition, SubCondition, SupportedComparisonOperator
-
-
-def _convert_to_bool(value: object) -> bool:
-    if isinstance(value, int):
-        return bool(value)
-
-    if isinstance(value, str):
-        loaded = json.loads(value)
-        if isinstance(loaded, (int, bool)):
-            return bool(loaded)
-
-    raise TypeError(f"unexpected value: type={type(value)}, value={value}")
-
-
-class ConditionCheckResult(NamedTuple):
-    inputs: Sequence[Mapping[str, object]]
-    group_results: Sequence[bool]
-    final_result: bool
-
-
-class ConditionProcessor:
-    def process_conditions(
-        self,
-        *,
-        variable_pool: VariablePool,
-        conditions: Sequence[Condition],
-        operator: Literal["and", "or"],
-    ) -> ConditionCheckResult:
-        input_conditions: list[Mapping[str, object]] = []
-        group_results: list[bool] = []
-
-        for condition in conditions:
-            variable = variable_pool.get(condition.variable_selector)
-            if variable is None:
-                raise ValueError(f"Variable {condition.variable_selector} not found")
-
-            if isinstance(variable, ArrayFileSegment) and condition.comparison_operator in {
-                "contains",
-                "not contains",
-                "all of",
-            }:
-                # check sub conditions
-                if not condition.sub_variable_condition:
-                    raise ValueError("Sub variable is required")
-                result = _process_sub_conditions(
-                    variable=variable,
-                    sub_conditions=condition.sub_variable_condition.conditions,
-                    operator=condition.sub_variable_condition.logical_operator,
-                )
-            elif condition.comparison_operator in {
-                "exists",
-                "not exists",
-            }:
-                result = _evaluate_condition(
-                    value=variable.value,
-                    operator=condition.comparison_operator,
-                    expected=None,
-                )
-            else:
-                actual_value = variable.value if variable else None
-                expected_value: str | Sequence[str] | bool | list[bool] | None = condition.value
-                if isinstance(expected_value, str):
-                    expected_value = variable_pool.convert_template(expected_value).text
-                # Here we need to explicitly convert the input string to a boolean.
-                if isinstance(variable, (BooleanSegment, ArrayBooleanSegment)) and expected_value is not None:
-                    # The following two lines are for compatibility with existing workflows.
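For reference, the coercion rules implemented by _convert_to_bool above: plain ints
(and bools, which are ints in Python) pass through bool(), and strings must
JSON-decode to an int or bool. A standalone copy demonstrating the accepted inputs:

    import json

    def convert_to_bool(value: object) -> bool:
        # Same logic as _convert_to_bool: ints coerce directly, strings must
        # JSON-decode to an int or bool; everything else is rejected.
        if isinstance(value, int):
            return bool(value)
        if isinstance(value, str):
            loaded = json.loads(value)
            if isinstance(loaded, (int, bool)):
                return bool(loaded)
        raise TypeError(f"unexpected value: type={type(value)}, value={value}")

    print(convert_to_bool("true"), convert_to_bool(0), convert_to_bool("1"))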
- if isinstance(expected_value, list): - expected_value = [_convert_to_bool(i) for i in expected_value] - else: - expected_value = _convert_to_bool(expected_value) - input_conditions.append( - { - "actual_value": actual_value, - "expected_value": expected_value, - "comparison_operator": condition.comparison_operator, - } - ) - result = _evaluate_condition( - value=actual_value, - operator=condition.comparison_operator, - expected=expected_value, - ) - group_results.append(result) - # Short-circuit evaluation for logical conditions - if (operator == "and" and not result) or (operator == "or" and result): - final_result = result - return ConditionCheckResult(input_conditions, group_results, final_result) - - final_result = all(group_results) if operator == "and" else any(group_results) - return ConditionCheckResult(input_conditions, group_results, final_result) - - -def _evaluate_condition( - *, - operator: SupportedComparisonOperator, - value: object, - expected: str | Sequence[str] | bool | Sequence[bool] | None, -) -> bool: - match operator: - case "contains": - return _assert_contains(value=value, expected=expected) - case "not contains": - return _assert_not_contains(value=value, expected=expected) - case "start with": - return _assert_start_with(value=value, expected=expected) - case "end with": - return _assert_end_with(value=value, expected=expected) - case "is": - return _assert_is(value=value, expected=expected) - case "is not": - return _assert_is_not(value=value, expected=expected) - case "empty": - return _assert_empty(value=value) - case "not empty": - return _assert_not_empty(value=value) - case "=": - return _assert_equal(value=value, expected=expected) - case "≠": - return _assert_not_equal(value=value, expected=expected) - case ">": - return _assert_greater_than(value=value, expected=expected) - case "<": - return _assert_less_than(value=value, expected=expected) - case "≥": - return _assert_greater_than_or_equal(value=value, expected=expected) - case "≤": - return _assert_less_than_or_equal(value=value, expected=expected) - case "null": - return _assert_null(value=value) - case "not null": - return _assert_not_null(value=value) - case "in": - return _assert_in(value=value, expected=expected) - case "not in": - return _assert_not_in(value=value, expected=expected) - case "all of" if isinstance(expected, list): - # Type narrowing: at this point expected is a list; it could be list[str] or list[bool] - if all(isinstance(item, str) for item in expected): - # Create a new typed list to satisfy type checker - str_list: list[str] = [item for item in expected if isinstance(item, str)] - return _assert_all_of(value=value, expected=str_list) - elif all(isinstance(item, bool) for item in expected): - # Create a new typed list to satisfy type checker - bool_list: list[bool] = [item for item in expected if isinstance(item, bool)] - return _assert_all_of_bool(value=value, expected=bool_list) - else: - raise ValueError("all of operator expects homogeneous list of strings or booleans") - case "exists": - return _assert_exists(value=value) - case "not exists": - return _assert_not_exists(value=value) - case _: - raise ValueError(f"Unsupported operator: {operator}") - - -def _assert_contains(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, (str, list)): - raise ValueError("Invalid actual value type: string or array") - - # Type checking ensures value is str or list at this point - if isinstance(value, str): - if not
isinstance(expected, str): - expected = str(expected) - if expected not in value: - return False - else: # value is list - if expected not in value: - return False - return True - - -def _assert_not_contains(*, value: object, expected: object) -> bool: - if not value: - return True - - if not isinstance(value, (str, list)): - raise ValueError("Invalid actual value type: string or array") - - # Type checking ensures value is str or list at this point - if isinstance(value, str): - if not isinstance(expected, str): - expected = str(expected) - if expected in value: - return False - else: # value is list - if expected in value: - return False - return True - - -def _assert_start_with(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, str): - raise ValueError("Invalid actual value type: string") - - if not isinstance(expected, str): - raise ValueError("Expected value must be a string for startswith") - if not value.startswith(expected): - return False - return True - - -def _assert_end_with(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, str): - raise ValueError("Invalid actual value type: string") - - if not isinstance(expected, str): - raise ValueError("Expected value must be a string for endswith") - if not value.endswith(expected): - return False - return True - - -def _assert_is(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (str, bool)): - raise ValueError("Invalid actual value type: string or boolean") - - if value != expected: - return False - return True - - -def _assert_is_not(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (str, bool)): - raise ValueError("Invalid actual value type: string or boolean") - - if value == expected: - return False - return True - - -def _assert_empty(*, value: object) -> bool: - if not value: - return True - return False - - -def _assert_not_empty(*, value: object) -> bool: - if value: - return True - return False - - -def _normalize_numeric_values(value: int | float, expected: object) -> tuple[int | float, int | float]: - """ - Normalize value and expected to compatible numeric types for comparison. 
- - Args: - value: The actual numeric value (int or float) - expected: The expected value (int, float, or str) - - Returns: - A tuple of (normalized_value, normalized_expected) with compatible types - - Raises: - ValueError: If expected cannot be converted to a number - """ - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to number") - - # Convert expected to appropriate numeric type - if isinstance(expected, str): - # Try to convert to float first to handle decimal strings - try: - expected_float = float(expected) - except ValueError as e: - raise ValueError(f"Cannot convert '{expected}' to number") from e - - # If value is int and expected is a whole number, keep as int comparison - if isinstance(value, int) and expected_float.is_integer(): - return value, int(expected_float) - else: - # Otherwise convert value to float for comparison - return float(value) if isinstance(value, int) else value, expected_float - elif isinstance(expected, float): - # If expected is already float, convert int value to float - return float(value) if isinstance(value, int) else value, expected - else: - # expected is int - return value, expected - - -def _assert_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float, bool)): - raise ValueError("Invalid actual value type: number or boolean") - - # Handle boolean comparison - if isinstance(value, bool): - if not isinstance(expected, (bool, int, str)): - raise ValueError(f"Cannot convert {type(expected)} to bool") - expected = bool(expected) - elif isinstance(value, int): - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to int") - expected = int(expected) - else: - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to float") - expected = float(expected) - - if value != expected: - return False - return True - - -def _assert_not_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float, bool)): - raise ValueError("Invalid actual value type: number or boolean") - - # Handle boolean comparison - if isinstance(value, bool): - if not isinstance(expected, (bool, int, str)): - raise ValueError(f"Cannot convert {type(expected)} to bool") - expected = bool(expected) - elif isinstance(value, int): - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to int") - expected = int(expected) - else: - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to float") - expected = float(expected) - - if value == expected: - return False - return True - - -def _assert_greater_than(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value > expected - - -def _assert_less_than(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value < expected - - -def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, 
float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value >= expected - - -def _assert_less_than_or_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value <= expected - - -def _assert_null(*, value: object) -> bool: - if value is None: - return True - return False - - -def _assert_not_null(*, value: object) -> bool: - if value is not None: - return True - return False - - -def _assert_in(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(expected, list): - raise ValueError("Invalid expected value type: array") - - if value not in expected: - return False - return True - - -def _assert_not_in(*, value: object, expected: object) -> bool: - if not value: - return True - - if not isinstance(expected, list): - raise ValueError("Invalid expected value type: array") - - if value in expected: - return False - return True - - -def _assert_all_of(*, value: object, expected: Sequence[str]) -> bool: - if not value: - return False - - # Ensure value is a container that supports 'in' operator - if not isinstance(value, (list, tuple, set, str)): - return False - - return all(item in value for item in expected) - - -def _assert_all_of_bool(*, value: object, expected: Sequence[bool]) -> bool: - if not value: - return False - - # Ensure value is a container that supports 'in' operator - if not isinstance(value, (list, tuple, set)): - return False - - return all(item in value for item in expected) - - -def _assert_exists(*, value: object) -> bool: - return value is not None - - -def _assert_not_exists(*, value: object) -> bool: - return value is None - - -def _process_sub_conditions( - variable: ArrayFileSegment, - sub_conditions: Sequence[SubCondition], - operator: Literal["and", "or"], -) -> bool: - files = variable.value - group_results: list[bool] = [] - for condition in sub_conditions: - key = FileAttribute(condition.key) - values = [file_manager.get_attr(file=file, attr=key) for file in files] - expected_value = condition.value - if key == FileAttribute.EXTENSION: - if not isinstance(expected_value, str): - raise TypeError("Expected value must be a string when key is FileAttribute.EXTENSION") - if expected_value and not expected_value.startswith("."): - expected_value = "." + expected_value - - normalized_values: list[object] = [] - for value in values: - if value and isinstance(value, str): - if not value.startswith("."): - value = "." 
+ value - normalized_values.append(value) - values = normalized_values - sub_group_results: list[bool] = [ - _evaluate_condition( - value=value, - operator=condition.comparison_operator, - expected=expected_value, - ) - for value in values - ] - # Determine the result based on the presence of "not" in the comparison operator - result = all(sub_group_results) if "not" in condition.comparison_operator else any(sub_group_results) - group_results.append(result) - return all(group_results) if operator == "and" else any(group_results) diff --git a/api/graphon/utils/json_in_md_parser.py b/api/graphon/utils/json_in_md_parser.py deleted file mode 100644 index 4416b4774bb..00000000000 --- a/api/graphon/utils/json_in_md_parser.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -import json - - -class OutputParserError(ValueError): - """Raised when a markdown-wrapped JSON payload cannot be parsed or validated.""" - - -def parse_json_markdown(json_string: str) -> dict | list: - """Extract and parse the first JSON object or array embedded in markdown text.""" - json_string = json_string.strip() - starts = ["```json", "```", "``", "`", "{", "["] - ends = ["```", "``", "`", "}", "]"] - end_index = -1 - start_index = 0 - - for start_marker in starts: - start_index = json_string.find(start_marker) - if start_index != -1: - if json_string[start_index] not in ("{", "["): - start_index += len(start_marker) - break - - if start_index != -1: - for end_marker in ends: - end_index = json_string.rfind(end_marker, start_index) - if end_index != -1: - if json_string[end_index] in ("}", "]"): - end_index += 1 - break - - if start_index == -1 or end_index == -1 or start_index >= end_index: - raise ValueError("could not find json block in the output.") - - extracted_content = json_string[start_index:end_index].strip() - return json.loads(extracted_content) - - -def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: - try: - json_obj = parse_json_markdown(text) - except json.JSONDecodeError as exc: - raise OutputParserError(f"got invalid json object. error: {exc}") from exc - - if isinstance(json_obj, list): - if len(json_obj) == 1 and isinstance(json_obj[0], dict): - json_obj = json_obj[0] - else: - raise OutputParserError(f"got invalid return object. obj:{json_obj}") - - for key in expected_keys: - if key not in json_obj: - raise OutputParserError( - f"got invalid return object. expected key `{key}` to be present, but got {json_obj}" - ) - - return json_obj diff --git a/api/graphon/variable_loader.py b/api/graphon/variable_loader.py deleted file mode 100644 index 03db920d3d0..00000000000 --- a/api/graphon/variable_loader.py +++ /dev/null @@ -1,75 +0,0 @@ -import abc -from collections.abc import Mapping, Sequence -from typing import Any, Protocol - -from graphon.runtime import VariablePool -from graphon.variables import VariableBase -from graphon.variables.consts import SELECTORS_LENGTH - - -class VariableLoader(Protocol): - """Interface for loading variables based on selectors. - - A `VariableLoader` is responsible for retrieving additional variables required during the execution - of a single node, which are not provided as user inputs. - - TODO(QuantumGhost): this is a temporary workaround. If we can move the creation of the node instance into - `WorkflowService.single_step_run`, we may get rid of this interface. - """ - - @abc.abstractmethod - def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: - """Load variables based on the provided selectors.
If the selectors are empty, - this method should return an empty list. - - The order of the returned variables is not guaranteed. If the caller wants to ensure - a specific order, they should sort the returned list themselves. - - :param selectors: a list of string lists; each inner list should have at least two elements: - - the first element is the node ID, - - the second element is the variable name. - :return: a list of VariableBase objects that match the provided selectors. - """ - pass - - -class _DummyVariableLoader(VariableLoader): - """A dummy implementation of VariableLoader that does not load any variables. - Serves as a placeholder when no variable loading is needed. - """ - - def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: - return [] - - -DUMMY_VARIABLE_LOADER = _DummyVariableLoader() - - -def load_into_variable_pool( - variable_loader: VariableLoader, - variable_pool: VariablePool, - variable_mapping: Mapping[str, Sequence[str]], - user_inputs: Mapping[str, Any], -): - # Load missing variables from the draft variables here, and set them into - # the variable_pool. - variables_to_load: list[list[str]] = [] - for key, selector in variable_mapping.items(): - # NOTE(QuantumGhost): this logic needs to be in sync with - # `WorkflowEntry.mapping_user_inputs_to_variable_pool`. - node_variable_list = key.split(".") - if len(node_variable_list) < 2: - raise ValueError(f"Invalid variable key: {key}. It should have at least two elements.") - if key in user_inputs: - continue - node_variable_key = ".".join(node_variable_list[1:]) - if node_variable_key in user_inputs: - continue - if variable_pool.get(selector) is None: - variables_to_load.append(list(selector)) - loaded = variable_loader.load_variables(variables_to_load) - for var in loaded: - assert len(var.selector) >= SELECTORS_LENGTH, f"Invalid variable {var}" - # Add variable directly to the pool - # The variable pool expects 2-element selectors [node_id, variable_name] - variable_pool.add([var.selector[0], var.selector[1]], var) diff --git a/api/graphon/variables/__init__.py b/api/graphon/variables/__init__.py deleted file mode 100644 index e9beb6cb951..00000000000 --- a/api/graphon/variables/__init__.py +++ /dev/null @@ -1,82 +0,0 @@ -from .factory import ( - TypeMismatchError, - UnsupportedSegmentTypeError, - build_segment, - build_segment_with_type, - segment_to_variable, -) -from .input_entities import VariableEntity, VariableEntityType -from .segment_group import SegmentGroup -from .segments import ( - ArrayAnySegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, -) -from .types import SegmentType -from .variables import ( - ArrayAnyVariable, - ArrayFileVariable, - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - ArrayVariable, - FileVariable, - FloatVariable, - IntegerVariable, - NoneVariable, - ObjectVariable, - SecretVariable, - StringVariable, - Variable, - VariableBase, -) - -__all__ = [ - "ArrayAnySegment", - "ArrayAnyVariable", - "ArrayFileSegment", - "ArrayFileVariable", - "ArrayNumberSegment", - "ArrayNumberVariable", - "ArrayObjectSegment", - "ArrayObjectVariable", - "ArraySegment", - "ArrayStringSegment", - "ArrayStringVariable", - "ArrayVariable", - "FileSegment", - "FileVariable", - "FloatSegment", - "FloatVariable", - "IntegerSegment", - "IntegerVariable", - "NoneSegment", - "NoneVariable", - "ObjectSegment", -
"ObjectVariable", - "SecretVariable", - "Segment", - "SegmentGroup", - "SegmentType", - "StringSegment", - "StringVariable", - "TypeMismatchError", - "UnsupportedSegmentTypeError", - "Variable", - "VariableBase", - "VariableEntity", - "VariableEntityType", - "build_segment", - "build_segment_with_type", - "segment_to_variable", -] diff --git a/api/graphon/variables/consts.py b/api/graphon/variables/consts.py deleted file mode 100644 index 8f3f78f740f..00000000000 --- a/api/graphon/variables/consts.py +++ /dev/null @@ -1,7 +0,0 @@ -# The minimal selector length for valid variables. -# -# The first element of the selector is the node id, and the second element is the variable name. -# -# If the selector length is more than 2, the remaining parts are the keys / indexes paths used -# to extract part of the variable value. -SELECTORS_LENGTH = 2 diff --git a/api/graphon/variables/exc.py b/api/graphon/variables/exc.py deleted file mode 100644 index 5cf67c3bacc..00000000000 --- a/api/graphon/variables/exc.py +++ /dev/null @@ -1,2 +0,0 @@ -class VariableError(ValueError): - pass diff --git a/api/graphon/variables/factory.py b/api/graphon/variables/factory.py deleted file mode 100644 index ac693914a70..00000000000 --- a/api/graphon/variables/factory.py +++ /dev/null @@ -1,202 +0,0 @@ -"""Graph-owned helpers for converting runtime values, segments, and variables. - -These conversions are part of the `graphon` runtime model and must stay -independent from top-level API factory modules so graph nodes and state -containers can operate without importing application-layer packages. -""" - -from collections.abc import Mapping, Sequence -from typing import Any, cast -from uuid import uuid4 - -from graphon.file import File - -from .segments import ( - ArrayAnySegment, - ArrayBooleanSegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - BooleanSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, -) -from .types import SegmentType -from .variables import ( - ArrayAnyVariable, - ArrayBooleanVariable, - ArrayFileVariable, - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - BooleanVariable, - FileVariable, - FloatVariable, - IntegerVariable, - NoneVariable, - ObjectVariable, - StringVariable, - VariableBase, -) - - -class UnsupportedSegmentTypeError(Exception): - pass - - -class TypeMismatchError(Exception): - pass - - -SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[Any]] = { - ArrayAnySegment: ArrayAnyVariable, - ArrayBooleanSegment: ArrayBooleanVariable, - ArrayFileSegment: ArrayFileVariable, - ArrayNumberSegment: ArrayNumberVariable, - ArrayObjectSegment: ArrayObjectVariable, - ArrayStringSegment: ArrayStringVariable, - BooleanSegment: BooleanVariable, - FileSegment: FileVariable, - FloatSegment: FloatVariable, - IntegerSegment: IntegerVariable, - NoneSegment: NoneVariable, - ObjectSegment: ObjectVariable, - StringSegment: StringVariable, -} - - -def build_segment(value: Any, /) -> Segment: - """Build a runtime segment from a Python value.""" - if value is None: - return NoneSegment() - if isinstance(value, Segment): - return value - if isinstance(value, str): - return StringSegment(value=value) - if isinstance(value, bool): - return BooleanSegment(value=value) - if isinstance(value, int): - return IntegerSegment(value=value) - if isinstance(value, float): - return FloatSegment(value=value) - if isinstance(value, dict): - return ObjectSegment(value=value) - if 
isinstance(value, File): - return FileSegment(value=value) - if isinstance(value, list): - items = [build_segment(item) for item in value] - types = {item.value_type for item in items} - if all(isinstance(item, ArraySegment) for item in items): - return ArrayAnySegment(value=value) - if len(types) != 1: - if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): - return ArrayNumberSegment(value=value) - return ArrayAnySegment(value=value) - - match types.pop(): - case SegmentType.STRING: - return ArrayStringSegment(value=value) - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return ArrayNumberSegment(value=value) - case SegmentType.BOOLEAN: - return ArrayBooleanSegment(value=value) - case SegmentType.OBJECT: - return ArrayObjectSegment(value=value) - case SegmentType.FILE: - return ArrayFileSegment(value=value) - case SegmentType.NONE: - return ArrayAnySegment(value=value) - case _: - raise ValueError(f"not supported value {value}") - raise ValueError(f"not supported value {value}") - - -_SEGMENT_FACTORY: Mapping[SegmentType, type[Segment]] = { - SegmentType.NONE: NoneSegment, - SegmentType.STRING: StringSegment, - SegmentType.INTEGER: IntegerSegment, - SegmentType.FLOAT: FloatSegment, - SegmentType.FILE: FileSegment, - SegmentType.BOOLEAN: BooleanSegment, - SegmentType.OBJECT: ObjectSegment, - SegmentType.ARRAY_ANY: ArrayAnySegment, - SegmentType.ARRAY_STRING: ArrayStringSegment, - SegmentType.ARRAY_NUMBER: ArrayNumberSegment, - SegmentType.ARRAY_OBJECT: ArrayObjectSegment, - SegmentType.ARRAY_FILE: ArrayFileSegment, - SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment, -} - - -def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: - """Build a segment while enforcing compatibility with the expected runtime type.""" - if value is None: - if segment_type == SegmentType.NONE: - return NoneSegment() - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") - - if isinstance(value, list) and len(value) == 0: - if segment_type == SegmentType.ARRAY_ANY: - return ArrayAnySegment(value=value) - if segment_type == SegmentType.ARRAY_STRING: - return ArrayStringSegment(value=value) - if segment_type == SegmentType.ARRAY_BOOLEAN: - return ArrayBooleanSegment(value=value) - if segment_type == SegmentType.ARRAY_NUMBER: - return ArrayNumberSegment(value=value) - if segment_type == SegmentType.ARRAY_OBJECT: - return ArrayObjectSegment(value=value) - if segment_type == SegmentType.ARRAY_FILE: - return ArrayFileSegment(value=value) - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") - - inferred_type = SegmentType.infer_segment_type(value) - if inferred_type is None: - raise TypeMismatchError( - f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" - ) - if inferred_type == segment_type: - segment_class = _SEGMENT_FACTORY[segment_type] - return segment_class(value_type=segment_type, value=value) - if segment_type == SegmentType.NUMBER and inferred_type in (SegmentType.INTEGER, SegmentType.FLOAT): - segment_class = _SEGMENT_FACTORY[inferred_type] - return segment_class(value_type=inferred_type, value=value) - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") - - -def segment_to_variable( - *, - segment: Segment, - selector: Sequence[str], - id: str | None = None, - name: str | None = None, - description: str = "", -) -> VariableBase: - """Convert a runtime segment into a runtime 
variable for storage in the pool.""" - if isinstance(segment, VariableBase): - return segment - name = name or selector[-1] - id = id or str(uuid4()) - - segment_type = type(segment) - if segment_type not in SEGMENT_TO_VARIABLE_MAP: - raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") - - variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return cast( - VariableBase, - variable_class( - id=id, - name=name, - description=description, - value=segment.value, - selector=list(selector), - ), - ) diff --git a/api/graphon/variables/input_entities.py b/api/graphon/variables/input_entities.py deleted file mode 100644 index c46ee47714c..00000000000 --- a/api/graphon/variables/input_entities.py +++ /dev/null @@ -1,62 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum -from typing import Any - -from jsonschema import Draft7Validator, SchemaError -from pydantic import BaseModel, Field, field_validator - -from graphon.file import FileTransferMethod, FileType - - -class VariableEntityType(StrEnum): - TEXT_INPUT = "text-input" - SELECT = "select" - PARAGRAPH = "paragraph" - NUMBER = "number" - EXTERNAL_DATA_TOOL = "external_data_tool" - FILE = "file" - FILE_LIST = "file-list" - CHECKBOX = "checkbox" - JSON_OBJECT = "json_object" - - -class VariableEntity(BaseModel): - """ - Shared variable entity used by workflow runtime and app configuration. - """ - - # `variable` records the name of the variable in user inputs. - variable: str - label: str - description: str = "" - type: VariableEntityType - required: bool = False - hide: bool = False - default: Any = None - max_length: int | None = None - options: Sequence[str] = Field(default_factory=list) - allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) - allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) - allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) - json_schema: dict[str, Any] | None = Field(default=None) - - @field_validator("description", mode="before") - @classmethod - def convert_none_description(cls, value: Any) -> str: - return value or "" - - @field_validator("options", mode="before") - @classmethod - def convert_none_options(cls, value: Any) -> Sequence[str]: - return value or [] - - @field_validator("json_schema") - @classmethod - def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None: - if schema is None: - return None - try: - Draft7Validator.check_schema(schema) - except SchemaError as error: - raise ValueError(f"Invalid JSON schema: {error.message}") - return schema diff --git a/api/graphon/variables/segment_group.py b/api/graphon/variables/segment_group.py deleted file mode 100644 index b363255b2ca..00000000000 --- a/api/graphon/variables/segment_group.py +++ /dev/null @@ -1,22 +0,0 @@ -from .segments import Segment -from .types import SegmentType - - -class SegmentGroup(Segment): - value_type: SegmentType = SegmentType.GROUP - value: list[Segment] - - @property - def text(self): - return "".join([segment.text for segment in self.value]) - - @property - def log(self): - return "".join([segment.log for segment in self.value]) - - @property - def markdown(self): - return "".join([segment.markdown for segment in self.value]) - - def to_object(self): - return [segment.to_object() for segment in self.value] diff --git a/api/graphon/variables/segments.py b/api/graphon/variables/segments.py deleted file mode 100644 index 8902ddc7e9c..00000000000 --- 
a/api/graphon/variables/segments.py +++ /dev/null @@ -1,253 +0,0 @@ -import json -import sys -from collections.abc import Mapping, Sequence -from typing import Annotated, Any, TypeAlias - -from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator - -from graphon.file import File - -from .types import SegmentType - - -class Segment(BaseModel): - """Segment is a runtime type used during the execution of a workflow. - - Note: this class is abstract, you should use subclasses of this class instead. - """ - - model_config = ConfigDict(frozen=True) - - value_type: SegmentType - value: Any - - @field_validator("value_type") - @classmethod - def validate_value_type(cls, value): - """ - This validator checks if the provided value is equal to the default value of the 'value_type' field. - If the value is different, a ValueError is raised. - """ - if value != cls.model_fields["value_type"].default: - raise ValueError("Cannot modify 'value_type'") - return value - - @property - def text(self) -> str: - return str(self.value) - - @property - def log(self) -> str: - return str(self.value) - - @property - def markdown(self) -> str: - return str(self.value) - - @property - def size(self) -> int: - """ - Return the size of the value in bytes. - """ - return sys.getsizeof(self.value) - - def to_object(self): - return self.value - - -class NoneSegment(Segment): - value_type: SegmentType = SegmentType.NONE - value: None = None - - @property - def text(self) -> str: - return "" - - @property - def log(self) -> str: - return "" - - @property - def markdown(self) -> str: - return "" - - -class StringSegment(Segment): - value_type: SegmentType = SegmentType.STRING - value: str - - -class FloatSegment(Segment): - value_type: SegmentType = SegmentType.FLOAT - value: float - # NOTE(QuantumGhost): it seems that the equality for FloatSegment with `NaN` value has some problems. - # The following tests cannot pass.
- # - # def test_float_segment_and_nan(): - # nan = float("nan") - # assert nan != nan - # - # f1 = FloatSegment(value=float("nan")) - # f2 = FloatSegment(value=float("nan")) - # assert f1 != f2 - # - # f3 = FloatSegment(value=nan) - # f4 = FloatSegment(value=nan) - # assert f3 != f4 - - -class IntegerSegment(Segment): - value_type: SegmentType = SegmentType.INTEGER - value: int - - -class ObjectSegment(Segment): - value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Any] - - @property - def text(self) -> str: - return json.dumps(self.model_dump()["value"], ensure_ascii=False) - - @property - def log(self) -> str: - return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) - - @property - def markdown(self) -> str: - return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) - - -class ArraySegment(Segment): - @property - def text(self) -> str: - # Return empty string for empty arrays instead of "[]" - if not self.value: - return "" - return super().text - - @property - def markdown(self) -> str: - items = [] - for item in self.value: - items.append(f"- {item}") - return "\n".join(items) - - -class FileSegment(Segment): - value_type: SegmentType = SegmentType.FILE - value: File - - @property - def markdown(self) -> str: - return self.value.markdown - - @property - def log(self) -> str: - return "" - - @property - def text(self) -> str: - return "" - - -class BooleanSegment(Segment): - value_type: SegmentType = SegmentType.BOOLEAN - value: bool - - -class ArrayAnySegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Any] - - -class ArrayStringSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[str] - - @property - def text(self) -> str: - # Return empty string for empty arrays instead of "[]" - if not self.value: - return "" - return json.dumps(self.value, ensure_ascii=False) - - -class ArrayNumberSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[float | int] - - -class ArrayObjectSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[Mapping[str, Any]] - - -class ArrayFileSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[File] - - @property - def markdown(self) -> str: - items = [] - for item in self.value: - items.append(item.markdown) - return "\n".join(items) - - @property - def log(self) -> str: - return "" - - @property - def text(self) -> str: - return "" - - -class ArrayBooleanSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_BOOLEAN - value: Sequence[bool] - - -def get_segment_discriminator(v: Any) -> SegmentType | None: - if isinstance(v, Segment): - return v.value_type - elif isinstance(v, dict): - value_type = v.get("value_type") - if value_type is None: - return None - try: - seg_type = SegmentType(value_type) - except ValueError: - return None - return seg_type - else: - # return None if the discriminator value isn't found - return None - - -# The `SegmentUnion` type is used to enable serialization and deserialization with Pydantic. -# Use `Segment` for type hinting when serialization is not required. -# -# Note: -# - All variants in `SegmentUnion` must inherit from the `Segment` class. -# - The union must include all non-abstract subclasses of `Segment`, except: -# - `SegmentGroup`, which is not added to the variable pool. -# - `VariableBase` and its subclasses, which are handled by `Variable`.
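As a rough illustration of the round-trip this union enables (not part of the patch; it assumes the external `graphon>=0.1.2` package this series switches to keeps the same names and import paths as the deleted module):

    from pydantic import TypeAdapter

    from graphon.variables.segments import SegmentUnion, StringSegment  # assumed import path

    adapter = TypeAdapter(SegmentUnion)
    payload = adapter.dump_python(StringSegment(value="hello"), mode="json")
    # The discriminator travels with the payload:
    # payload == {"value_type": "string", "value": "hello"}
    restored = adapter.validate_python(payload)
    assert isinstance(restored, StringSegment)
    assert restored.value == "hello"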
-SegmentUnion: TypeAlias = Annotated[ - ( - Annotated[NoneSegment, Tag(SegmentType.NONE)] - | Annotated[StringSegment, Tag(SegmentType.STRING)] - | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] - | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] - | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] - | Annotated[FileSegment, Tag(SegmentType.FILE)] - | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)] - | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] - | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] - | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] - | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] - | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] - | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)] - ), - Discriminator(get_segment_discriminator), -] diff --git a/api/graphon/variables/types.py b/api/graphon/variables/types.py deleted file mode 100644 index 949a693ad2b..00000000000 --- a/api/graphon/variables/types.py +++ /dev/null @@ -1,273 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from enum import StrEnum -from typing import TYPE_CHECKING, Any - -from graphon.file.models import File - -if TYPE_CHECKING: - from graphon.variables.segments import Segment - - -class ArrayValidation(StrEnum): - """Strategy for validating array elements. - - Note: - The `NONE` and `FIRST` strategies are primarily for compatibility purposes. - Avoid using them in new code whenever possible. - """ - - # Skip element validation (only check array container) - NONE = "none" - - # Validate the first element (if array is non-empty) - FIRST = "first" - - # Validate all elements in the array. - ALL = "all" - - -class SegmentType(StrEnum): - NUMBER = "number" - INTEGER = "integer" - FLOAT = "float" - STRING = "string" - OBJECT = "object" - SECRET = "secret" - - FILE = "file" - BOOLEAN = "boolean" - - ARRAY_ANY = "array[any]" - ARRAY_STRING = "array[string]" - ARRAY_NUMBER = "array[number]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILE = "array[file]" - ARRAY_BOOLEAN = "array[boolean]" - - NONE = "none" - - GROUP = "group" - - def is_array_type(self) -> bool: - return self in _ARRAY_TYPES - - @classmethod - def infer_segment_type(cls, value: Any) -> SegmentType | None: - """ - Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. - - Returns `None` if no appropriate `SegmentType` can be determined for the given `value`. - For example, this may occur if the input is a generic Python object of type `object`. - """ - - if isinstance(value, list): - elem_types: set[SegmentType] = set() - for i in value: - segment_type = cls.infer_segment_type(i) - if segment_type is None: - return None - - elem_types.add(segment_type) - - if len(elem_types) != 1: - if elem_types.issubset(_NUMERICAL_TYPES): - return SegmentType.ARRAY_NUMBER - return SegmentType.ARRAY_ANY - elif all(i.is_array_type() for i in elem_types): - return SegmentType.ARRAY_ANY - match elem_types.pop(): - case SegmentType.STRING: - return SegmentType.ARRAY_STRING - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return SegmentType.ARRAY_NUMBER - case SegmentType.OBJECT: - return SegmentType.ARRAY_OBJECT - case SegmentType.FILE: - return SegmentType.ARRAY_FILE - case SegmentType.NONE: - return SegmentType.ARRAY_ANY - case SegmentType.BOOLEAN: - return SegmentType.ARRAY_BOOLEAN - case _: - # This should be unreachable. 
- raise ValueError(f"not supported value {value}") - if value is None: - return SegmentType.NONE - # Important: The check for `bool` must precede the check for `int`, - # as `bool` is a subclass of `int` in Python's type hierarchy. - elif isinstance(value, bool): - return SegmentType.BOOLEAN - elif isinstance(value, int): - return SegmentType.INTEGER - elif isinstance(value, float): - return SegmentType.FLOAT - elif isinstance(value, str): - return SegmentType.STRING - elif isinstance(value, dict): - return SegmentType.OBJECT - elif isinstance(value, File): - return SegmentType.FILE - else: - return None - - def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool: - if not isinstance(value, list): - return False - # Skip element validation if array is empty - if len(value) == 0: - return True - if self == SegmentType.ARRAY_ANY: - return True - element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self] - - if array_validation == ArrayValidation.NONE: - return True - elif array_validation == ArrayValidation.FIRST: - return element_type.is_valid(value[0]) - else: - return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value) - - def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.ALL) -> bool: - """ - Check if a value matches the segment type. - Users of `SegmentType` should call this method, instead of using - `isinstance` manually. - - Args: - value: The value to validate - array_validation: Validation strategy for array types (ignored for non-array types) - - Returns: - True if the value matches the type under the given validation strategy - """ - if self.is_array_type(): - return self._validate_array(value, array_validation) - # Important: The check for `bool` must precede the check for `int`, - # as `bool` is a subclass of `int` in Python's type hierarchy. - elif self == SegmentType.BOOLEAN: - return isinstance(value, bool) - elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]: - return isinstance(value, (int, float)) - elif self == SegmentType.STRING: - return isinstance(value, str) - elif self == SegmentType.OBJECT: - return isinstance(value, dict) - elif self == SegmentType.SECRET: - return isinstance(value, str) - elif self == SegmentType.FILE: - return isinstance(value, File) - elif self == SegmentType.NONE: - return value is None - elif self == SegmentType.GROUP: - from .segment_group import SegmentGroup - from .segments import Segment - - if isinstance(value, SegmentGroup): - return all(isinstance(item, Segment) for item in value.value) - - if isinstance(value, list): - return all(isinstance(item, Segment) for item in value) - - return False - else: - raise AssertionError("this statement should be unreachable.") - - @staticmethod - def cast_value(value: Any, type_: SegmentType): - # Cast Python's `bool` type to `int` when the runtime type requires - # an integer or number. - # - # This ensures compatibility with existing workflows that may use `bool` as - # `int`, since in Python's type system, `bool` is a subtype of `int`. - # - # This function exists solely to maintain compatibility with existing workflows. - # It should not be used to compromise the integrity of the runtime type system. - # No additional casting rules should be introduced to this function. 
- - if type_ in ( - SegmentType.INTEGER, - SegmentType.NUMBER, - ) and isinstance(value, bool): - return int(value) - if type_ == SegmentType.ARRAY_NUMBER and all(isinstance(i, bool) for i in value): - return [int(i) for i in value] - return value - - def exposed_type(self) -> SegmentType: - """Returns the type exposed to the frontend. - - The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here. - """ - if self in (SegmentType.INTEGER, SegmentType.FLOAT): - return SegmentType.NUMBER - return self - - def element_type(self) -> SegmentType | None: - """Return the element type of the current segment type, or `None` if the element type is undefined. - - Raises: - ValueError: If the current segment type is not an array type. - - Note: - For certain array types, such as `SegmentType.ARRAY_ANY`, their element types are not defined - by the runtime system. In such cases, this method will return `None`. - """ - if not self.is_array_type(): - raise ValueError(f"element_type is only supported by array type, got {self}") - return _ARRAY_ELEMENT_TYPES_MAPPING.get(self) - - @staticmethod - def get_zero_value(t: SegmentType) -> Segment: - # Lazy import to avoid circular dependency between segment types and factory helpers. - from graphon.variables.factory import build_segment, build_segment_with_type - - match t: - case ( - SegmentType.ARRAY_OBJECT - | SegmentType.ARRAY_ANY - | SegmentType.ARRAY_STRING - | SegmentType.ARRAY_NUMBER - | SegmentType.ARRAY_BOOLEAN - ): - return build_segment_with_type(t, []) - case SegmentType.OBJECT: - return build_segment({}) - case SegmentType.STRING: - return build_segment("") - case SegmentType.INTEGER: - return build_segment(0) - case SegmentType.FLOAT: - return build_segment(0.0) - case SegmentType.NUMBER: - return build_segment(0) - case SegmentType.BOOLEAN: - return build_segment(False) - case _: - raise ValueError(f"unsupported variable type: {t}") - - -_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = { - # ARRAY_ANY does not have corresponding element type. 
- SegmentType.ARRAY_STRING: SegmentType.STRING, - SegmentType.ARRAY_NUMBER: SegmentType.NUMBER, - SegmentType.ARRAY_OBJECT: SegmentType.OBJECT, - SegmentType.ARRAY_FILE: SegmentType.FILE, - SegmentType.ARRAY_BOOLEAN: SegmentType.BOOLEAN, -} - -_ARRAY_TYPES = frozenset( - list(_ARRAY_ELEMENT_TYPES_MAPPING.keys()) - + [ - SegmentType.ARRAY_ANY, - ] -) - -_NUMERICAL_TYPES = frozenset( - [ - SegmentType.NUMBER, - SegmentType.INTEGER, - SegmentType.FLOAT, - ] -) diff --git a/api/graphon/variables/utils.py b/api/graphon/variables/utils.py deleted file mode 100644 index 8e738f8fd5f..00000000000 --- a/api/graphon/variables/utils.py +++ /dev/null @@ -1,33 +0,0 @@ -from collections.abc import Iterable, Sequence -from typing import Any - -import orjson - -from .segment_group import SegmentGroup -from .segments import ArrayFileSegment, FileSegment, Segment - - -def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: - selectors = [node_id, name] - if paths: - selectors.extend(paths) - return selectors - - -def segment_orjson_default(o: Any): - """Default function for orjson serialization of Segment types""" - if isinstance(o, ArrayFileSegment): - return [v.model_dump() for v in o.value] - elif isinstance(o, FileSegment): - return o.value.model_dump() - elif isinstance(o, SegmentGroup): - return [segment_orjson_default(seg) for seg in o.value] - elif isinstance(o, Segment): - return o.value - raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable") - - -def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str: - """JSON dumps with segment support using orjson""" - option = orjson.OPT_NON_STR_KEYS - return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8") diff --git a/api/graphon/variables/variables.py b/api/graphon/variables/variables.py deleted file mode 100644 index af866283dac..00000000000 --- a/api/graphon/variables/variables.py +++ /dev/null @@ -1,172 +0,0 @@ -from collections.abc import Sequence -from typing import Annotated, Any, TypeAlias -from uuid import uuid4 - -from pydantic import BaseModel, Discriminator, Field, Tag - -from .segments import ( - ArrayAnySegment, - ArrayBooleanSegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - BooleanSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, - get_segment_discriminator, -) -from .types import SegmentType - - -def _obfuscated_token(token: str) -> str: - if not token: - return token - if len(token) <= 8: - return "*" * 20 - return token[:6] + "*" * 12 + token[-2:] - - -class VariableBase(Segment): - """ - A variable is a segment that has a name. - - It is mainly used to store segments and their selector in VariablePool. - - Note: this class is abstract, you should use subclasses of this class instead. 
- """ - - id: str = Field( - default_factory=lambda: str(uuid4()), - description="Unique identity for variable.", - ) - name: str - description: str = Field(default="", description="Description of the variable.") - selector: Sequence[str] = Field(default_factory=list) - - -class StringVariable(StringSegment, VariableBase): - pass - - -class FloatVariable(FloatSegment, VariableBase): - pass - - -class IntegerVariable(IntegerSegment, VariableBase): - pass - - -class ObjectVariable(ObjectSegment, VariableBase): - pass - - -class ArrayVariable(ArraySegment, VariableBase): - pass - - -class ArrayAnyVariable(ArrayAnySegment, ArrayVariable): - pass - - -class ArrayStringVariable(ArrayStringSegment, ArrayVariable): - pass - - -class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable): - pass - - -class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable): - pass - - -class SecretVariable(StringVariable): - value_type: SegmentType = SegmentType.SECRET - - @property - def log(self) -> str: - return _obfuscated_token(self.value) - - -class NoneVariable(NoneSegment, VariableBase): - value_type: SegmentType = SegmentType.NONE - value: None = None - - -class FileVariable(FileSegment, VariableBase): - pass - - -class BooleanVariable(BooleanSegment, VariableBase): - pass - - -class ArrayFileVariable(ArrayFileSegment, ArrayVariable): - pass - - -class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable): - pass - - -class RAGPipelineVariable(BaseModel): - belong_to_node_id: str = Field(description="belong to which node id, shared means public") - type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list") - label: str = Field(description="label") - description: str | None = Field(description="description", default="") - variable: str = Field(description="variable key", default="") - max_length: int | None = Field( - description="max length, applicable to text-input, paragraph, and file-list", default=0 - ) - default_value: Any = Field(description="default value", default="") - placeholder: str | None = Field(description="placeholder", default="") - unit: str | None = Field(description="unit, applicable to Number", default="") - tooltips: str | None = Field(description="helpful text", default="") - allowed_file_types: list[str] | None = Field( - description="image, document, audio, video, custom.", default_factory=list - ) - allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list) - allowed_file_upload_methods: list[str] | None = Field( - description="remote_url, local_file, tool_file.", default_factory=list - ) - required: bool = Field(description="optional, default false", default=False) - options: list[str] | None = Field(default_factory=list) - - -class RAGPipelineVariableInput(BaseModel): - variable: RAGPipelineVariable - value: Any - - -# The `Variable` type is used to enable serialization and deserialization with Pydantic. -# Use `VariableBase` for type hinting when serialization is not required. -# -# Note: -# - All variants in `Variable` must inherit from the `VariableBase` class. -# - The union must include all non-abstract subclasses of `VariableBase`. 
-Variable: TypeAlias = Annotated[ - ( - Annotated[NoneVariable, Tag(SegmentType.NONE)] - | Annotated[StringVariable, Tag(SegmentType.STRING)] - | Annotated[FloatVariable, Tag(SegmentType.FLOAT)] - | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)] - | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)] - | Annotated[FileVariable, Tag(SegmentType.FILE)] - | Annotated[BooleanVariable, Tag(SegmentType.BOOLEAN)] - | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)] - | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)] - | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)] - | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)] - | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)] - | Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)] - | Annotated[SecretVariable, Tag(SegmentType.SECRET)] - ), - Discriminator(get_segment_discriminator), -] diff --git a/api/graphon/workflow_type_encoder.py b/api/graphon/workflow_type_encoder.py deleted file mode 100644 index 7cdc83ebdbb..00000000000 --- a/api/graphon/workflow_type_encoder.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections.abc import Mapping -from decimal import Decimal -from typing import Any, overload - -from pydantic import BaseModel - -from graphon.file.models import File -from graphon.variables import Segment - - -class WorkflowRuntimeTypeConverter: - @overload - def to_json_encodable(self, value: Mapping[str, Any]) -> Mapping[str, Any]: ... - @overload - def to_json_encodable(self, value: None) -> None: ... - - def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: - """Convert runtime values to JSON-serializable structures.""" - - result = self.value_to_json_encodable_recursive(value) - if isinstance(result, Mapping) or result is None: - return result - return {} - - def value_to_json_encodable_recursive(self, value: Any): - if value is None: - return value - if isinstance(value, (bool, int, str, float)): - return value - if isinstance(value, Decimal): - # Convert Decimal to float for JSON serialization - return float(value) - if isinstance(value, Segment): - return self.value_to_json_encodable_recursive(value.value) - if isinstance(value, File): - return value.to_dict() - if isinstance(value, BaseModel): - return value.model_dump(mode="json") - if isinstance(value, dict): - res = {} - for k, v in value.items(): - res[k] = self.value_to_json_encodable_recursive(v) - return res - if isinstance(value, list): - res_list = [] - for item in value: - res_list.append(self.value_to_json_encodable_recursive(item)) - return res_list - return value diff --git a/api/libs/helper.py b/api/libs/helper.py index b1815859a54..a7b3da77ff8 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -16,14 +16,14 @@ from zoneinfo import available_timezones from flask import Response, stream_with_context from flask_restx import fields +from graphon.file import helpers as file_helpers +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel from pydantic.functional_validators import AfterValidator from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from extensions.ext_redis import redis_client -from graphon.file import helpers as file_helpers -from graphon.model_runtime.utils.encoders import jsonable_encoder if TYPE_CHECKING: from models import Account diff --git a/api/models/human_input.py b/api/models/human_input.py index b4c7a634b65..79c5d62f6a8 100644 --- 
a/api/models/human_input.py +++ b/api/models/human_input.py @@ -3,11 +3,11 @@ from enum import StrEnum from typing import Annotated, Literal, Self, final import sqlalchemy as sa +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship from core.workflow.human_input_compat import DeliveryMethodType -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string from .base import Base, DefaultFieldsMixin diff --git a/api/models/model.py b/api/models/model.py index bcb142db56d..b03cb7711fd 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -14,6 +14,9 @@ from uuid import uuid4 import sqlalchemy as sa from flask import request from flask_login import UserMixin # type: ignore[import-untyped] +from graphon.enums import WorkflowExecutionStatus +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.file import helpers as file_helpers from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import TypedDict @@ -22,9 +25,6 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file from extensions.storage.storage_type import StorageType -from graphon.enums import WorkflowExecutionStatus -from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from graphon.file import helpers as file_helpers from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 from models.utils.file_input_compat import build_file_from_input_mapping diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py index dee1cc507a7..f71583c1cde 100644 --- a/api/models/utils/file_input_compat.py +++ b/api/models/utils/file_input_compat.py @@ -4,9 +4,10 @@ from collections.abc import Callable, Mapping from functools import lru_cache from typing import Any -from core.workflow.file_reference import parse_file_reference from graphon.file import File, FileTransferMethod +from core.workflow.file_reference import parse_file_reference + @lru_cache(maxsize=1) def _get_file_access_controller(): diff --git a/api/models/workflow.py b/api/models/workflow.py index 0557e2e890d..f8868cb73cc 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -8,6 +8,19 @@ from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast from uuid import uuid4 import sqlalchemy as sa +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import ( + BuiltinNodeTypes, + NodeType, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import File +from graphon.file.constants import maybe_file_object +from graphon.variables import utils as variable_utils +from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from sqlalchemy import ( DateTime, Index, @@ -31,19 +44,6 @@ from core.workflow.variable_prefixes import ( ) from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type -from 
graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import ( - BuiltinNodeTypes, - NodeType, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file.constants import maybe_file_object -from graphon.file.models import File -from graphon.variables import utils as variable_utils -from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -53,10 +53,11 @@ if TYPE_CHECKING: from .model import AppMode, UploadFile +from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase + from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter from factories import variable_factory -from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase from libs import helper from .account import Account @@ -1466,8 +1467,6 @@ class WorkflowDraftVariable(Base): # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than # 80 chars. - # - # ref: api/graphon/entities/variable_pool.py:18 name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column( sa.String(255), diff --git a/api/pyproject.toml b/api/pyproject.toml index b1f1f4bb2e8..9c94474cdf8 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -28,11 +28,11 @@ dependencies = [ "google-auth-httplib2==0.3.0", "google-cloud-aiplatform>=1.123.0", "googleapis-common-protos>=1.65.0", + "graphon>=0.1.2", "gunicorn~=25.1.0", "httpx[socks]~=0.28.0", "jieba==0.42.1", "json-repair>=0.55.1", - "jsonschema>=4.25.1", "langfuse~=2.51.3", "langsmith~=0.7.16", "markdown~=3.10.2", @@ -63,7 +63,6 @@ dependencies = [ "psycopg2-binary~=2.9.6", "pycryptodome==3.23.0", "pydantic~=2.12.5", - "pydantic-extra-types~=2.11.0", "pydantic-settings~=2.13.1", "pyjwt~=2.12.0", "pypdfium2==5.6.0", @@ -81,7 +80,6 @@ dependencies = [ "unstructured[docx,epub,md,ppt,pptx]~=0.21.5", "pypandoc~=1.13", "yarl~=1.23.0", - "webvtt-py~=0.5.1", "sseclient-py~=1.9.0", "httpx-sse~=0.4.0", "sendgrid~=6.12.3", @@ -130,7 +128,6 @@ dev = [ "types-defusedxml~=0.7.0", "types-deprecated~=1.3.1", "types-docutils~=0.22.3", - "types-jsonschema~=4.26.0", "types-flask-cors~=6.0.0", "types-flask-migrate~=4.1.0", "types-gevent~=25.9.0", diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index cf002df2a93..43f604c2de9 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -118,34 +118,7 @@ enterprise/telemetry/exporter.py enterprise/telemetry/id_generator.py enterprise/telemetry/metric_handler.py enterprise/telemetry/telemetry_log.py -graphon/entities/workflow_execution.py -graphon/file/file_manager.py -graphon/graph_engine/error_handler.py -graphon/graph_engine/layers/execution_limits.py -graphon/nodes/agent/agent_node.py -graphon/nodes/base/node.py -graphon/nodes/code/code_node.py -graphon/nodes/datasource/datasource_node.py -graphon/nodes/document_extractor/node.py -graphon/nodes/human_input/human_input_node.py -graphon/nodes/if_else/if_else_node.py -graphon/nodes/iteration/iteration_node.py -graphon/nodes/knowledge_index/knowledge_index_node.py core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py -graphon/nodes/list_operator/node.py 
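[Reviewer note: these graphon/... exclude entries disappear because the modules themselves left the repository. Likewise, the api/graphon/workflow_type_encoder.py deletion near the top of this patch is a move into the external graphon package, not a removal — later hunks re-import WorkflowRuntimeTypeConverter from graphon.workflow_type_encoder. A minimal usage sketch, assuming the packaged class keeps the behaviour of the deleted source shown above:

    from decimal import Decimal

    from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter

    converter = WorkflowRuntimeTypeConverter()
    # Per the deleted source: Decimal -> float, Segment -> its unwrapped value,
    # File/BaseModel -> plain dicts; dicts and lists are walked recursively,
    # and scalars (bool, int, str, float, None) pass through unchanged.
    print(converter.to_json_encodable({"price": Decimal("9.99"), "tags": ["a", "b"], "count": 3}))
    # {'price': 9.99, 'tags': ['a', 'b'], 'count': 3}
]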
-graphon/nodes/llm/node.py -graphon/nodes/loop/loop_node.py -graphon/nodes/parameter_extractor/parameter_extractor_node.py -graphon/nodes/question_classifier/question_classifier_node.py -graphon/nodes/start/start_node.py -graphon/nodes/template_transform/template_transform_node.py -graphon/nodes/tool/tool_node.py -graphon/nodes/trigger_plugin/trigger_event_node.py -graphon/nodes/trigger_schedule/trigger_schedule_node.py -graphon/nodes/trigger_webhook/node.py -graphon/nodes/variable_aggregator/variable_aggregator_node.py -graphon/nodes/variable_assigner/v1/node.py -graphon/nodes/variable_assigner/v2/node.py extensions/logstore/repositories/logstore_api_workflow_run_repository.py extensions/otel/instrumentation.py extensions/otel/runtime.py diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index ffc17e92cfb..1a2a539c802 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -38,11 +38,11 @@ from collections.abc import Callable, Sequence from datetime import datetime from typing import Protocol +from graphon.entities.pause_reason import PauseReason +from graphon.enums import WorkflowType from sqlalchemy.orm import Session from core.repositories.factory import WorkflowExecutionRepository -from graphon.entities.pause_reason import PauseReason -from graphon.enums import WorkflowType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 44735eb7699..d5c6a203b14 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -10,11 +10,11 @@ from collections.abc import Sequence from datetime import datetime from typing import Protocol, cast +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( DifyAPIWorkflowNodeExecutionRepository, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 5bb0c74ada7..413936b542b 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -28,15 +28,15 @@ from decimal import Decimal from typing import Any, cast import sqlalchemy as sa +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.nodes.human_input.entities import FormDefinition from pydantic import ValidationError from sqlalchemy import and_, delete, func, null, or_, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage -from 
graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.nodes.human_input.entities import FormDefinition from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date from libs.infinite_scroll_pagination import InfiniteScrollPagination diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index 67f8795d3fd..feba5f7eb65 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -7,6 +7,9 @@ from collections import defaultdict from collections.abc import Sequence from typing import Any +from graphon.nodes.human_input.entities import FormDefinition +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode from sqlalchemy import select from sqlalchemy.orm import Session, selectinload, sessionmaker @@ -18,9 +21,6 @@ from core.entities.execution_extra_content import ( from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) -from graphon.nodes.human_input.entities import FormDefinition -from graphon.nodes.human_input.enums import HumanInputFormStatus -from graphon.nodes.human_input.human_input_node import HumanInputNode from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 643a2a2a84d..dd73e103746 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -11,6 +11,12 @@ from uuid import uuid4 import yaml from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from packaging import version from packaging.version import parse as parse_version from pydantic import BaseModel, Field @@ -30,12 +36,6 @@ from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerSc from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory -from graphon.enums import BuiltinNodeTypes -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData -from graphon.nodes.tool.entities import ToolNodeData from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode from models.model import AppModelConfig, AppModelConfigDict, IconType diff --git a/api/services/app_service.py b/api/services/app_service.py index 9413a93fc48..e9aeb6c43d0 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -4,6 +4,8 @@ from typing import Any, TypedDict, cast import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination +from 
graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from configs import dify_config from constants.model_template import default_app_templates @@ -14,8 +16,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created, app_was_deleted, app_was_updated from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index 6e9d6b1c737..0842e9d3e7f 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -5,10 +5,11 @@ like stopping tasks, handling both legacy Redis flag mechanism and new GraphEngine command channel mechanism. """ +from graphon.graph_engine.manager import GraphEngineManager + from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_redis import redis_client -from graphon.graph_engine.manager import GraphEngineManager from models.model import AppMode diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 9e743bf7b13..90e72d5f34f 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,12 +5,12 @@ from collections.abc import Generator from typing import cast from flask import Response, stream_with_context +from graphon.model_runtime.entities.model_entities import ModelType from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.enums import MessageStatus from models.model import App, AppMode, Message from services.errors.audio import ( diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index c6b32b373e2..1c128524ad4 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app +from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -13,7 +14,6 @@ from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_storage import storage -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.account import Tenant from models.model import ( App, diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 545c5048d54..ba1e7bb8266 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -3,6 +3,7 @@ import logging from collections.abc import Callable, Sequence from typing import Any, Union +from graphon.variables.types import SegmentType from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session @@ -12,7 +13,6 @@ from 
core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db from factories import variable_factory -from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index 287d513f480..95a8951951c 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ +from graphon.variables.variables import VariableBase from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker -from graphon.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 3e2342b1a7b..83363125c38 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -10,6 +10,9 @@ from collections.abc import Sequence from typing import Any, Literal, cast import sqlalchemy as sa +from graphon.file import helpers as file_helpers +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from redis.exceptions import LockNotOwnedError from sqlalchemy import exists, func, select from sqlalchemy.orm import Session @@ -28,9 +31,6 @@ from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.file import helpers as file_helpers -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 2b7bebb01e6..06f83a18f7a 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -3,6 +3,7 @@ import time from collections.abc import Mapping from typing import Any +from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy.orm import Session from configs import dify_config @@ -16,7 +17,6 @@ from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.provider_entities import FormType from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 6679c08ebd7..a944ef6acdd 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,6 +1,15 @@ from collections.abc import Sequence from enum import StrEnum +from graphon.model_runtime.entities.common_entities import I18nObject +from 
graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderHelpEntity, + SimpleProviderEntity, +) from pydantic import BaseModel, ConfigDict, model_validator from configs import dify_config @@ -15,15 +24,6 @@ from core.entities.provider_entities import ( QuotaConfiguration, UnaddedModelConfiguration, ) -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderHelpEntity, - SimpleProviderEntity, -) from models.provider import ProviderType diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index d2fa98f5e26..64852c222f3 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -4,13 +4,13 @@ from typing import Any, Union, cast from urllib.parse import urlparse import httpx +from graphon.nodes.http_request.exc import InvalidHttpMethodError from sqlalchemy import select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy from core.rag.entities.metadata_entities import MetadataCondition from extensions.ext_database import db -from graphon.nodes.http_request.exc import InvalidHttpMethodError from libs.datetime_utils import naive_utc_now from models.dataset import ( Dataset, diff --git a/api/services/file_service.py b/api/services/file_service.py index c11f018f527..50a326d8138 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -8,6 +8,7 @@ from tempfile import NamedTemporaryFile from typing import Literal, Union from zipfile import ZIP_DEFLATED, ZipFile +from graphon.file import helpers as file_helpers from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import NotFound @@ -23,7 +24,6 @@ from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType -from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id from models import Account diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index d490ad15619..82e0b0f8b1f 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -3,6 +3,8 @@ import logging import time from typing import Any +from graphon.model_runtime.entities import LLMMode + from core.app.app_config.entities import ModelConfig from core.rag.datasource.retrieval_service import RetrievalService from core.rag.index_processor.constant.query_type import QueryType @@ -10,7 +12,6 @@ from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db -from graphon.model_runtime.entities import LLMMode from models import Account from models.dataset import Dataset, DatasetQuery from models.enums import CreatorUserRole, DatasetQuerySource diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 68ef67dec1c..77576fa4c0d 100644 
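[Reviewer note: nearly every hunk from here on is the same mechanical change. With graphon now declared in pyproject.toml as graphon>=0.1.2, isort-style sorting treats it as a third-party package, so its imports move out of the first-party group (alongside configs, core, models) and up next to sqlalchemy, pydantic, and celery. A sketch of the regrouping, using imports that appear in the surrounding hunks:

    # Before: graphon sorted as first-party, beneath the third-party group.
    #
    #     from sqlalchemy.orm import Session
    #
    #     from configs import dify_config
    #     from graphon.model_runtime.entities.model_entities import ModelType
    #
    # After: graphon is an external dependency, so it joins the third-party
    # group above the local imports.
    from graphon.model_runtime.entities.model_entities import ModelType
    from sqlalchemy.orm import Session

    from configs import dify_config
]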
--- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from enum import StrEnum from typing import Protocol +from graphon.runtime import VariablePool from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker @@ -17,7 +18,6 @@ from core.workflow.human_input_compat import ( ) from extensions.ext_database import db from extensions.ext_mail import mail -from graphon.runtime import VariablePool from libs.email_template_renderer import render_email_template from models import Account, TenantAccountJoin from services.feature_service import FeatureService diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 76598d31ace..02a6620fc74 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -3,6 +3,12 @@ from collections.abc import Mapping from datetime import datetime, timedelta from typing import Any +from graphon.nodes.human_input.entities import ( + FormDefinition, + HumanInputSubmissionValidationError, + validate_human_input_submission, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -11,12 +17,6 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from graphon.nodes.human_input.entities import ( - FormDefinition, - HumanInputSubmissionValidationError, - validate_human_input_submission, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.exception import BaseHTTPException from models.human_input import RecipientType diff --git a/api/services/message_service.py b/api/services/message_service.py index 0c4a334b47b..e5389ef659a 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -2,6 +2,7 @@ import json from collections.abc import Sequence from typing import Union +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import sessionmaker from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -13,7 +14,6 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.enums import FeedbackFromSource, FeedbackRating diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 469357d6e0c..91cca5cb6d0 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -3,6 +3,12 @@ import logging from json import JSONDecodeError from typing import Union +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, + ProviderCredentialSchema, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import or_, select from constants import HIDDEN_VALUE @@ -13,12 +19,6 @@ from core.model_manager import LBModelManager from 
core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ModelCredentialSchema, - ProviderCredentialSchema, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index e634f906036..3f37c9b176d 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,9 +1,10 @@ import logging +from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule + from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager from core.provider_manager import ProviderManager -from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 46a6221fcce..bcf5973d7b2 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -9,6 +9,15 @@ from typing import Any, Union, cast from uuid import uuid4 from flask_login import current_user +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable, VariableBase from sqlalchemy import func, select from sqlalchemy.orm import Session, sessionmaker @@ -48,19 +57,6 @@ from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from extensions.ext_database import db -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType -from graphon.errors import WorkflowNodeRunFailedError -from graphon.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent -from graphon.graph_events.base import GraphNodeEventBase -from graphon.node_events.base import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from graphon.runtime import VariablePool -from graphon.variables.variables import Variable, VariableBase from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from 
models.dataset import ( # type: ignore diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 1b8207cc310..04156713f4f 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -14,6 +14,12 @@ import yaml # type: ignore from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from flask_login import current_user +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from packaging import version from pydantic import BaseModel, Field from sqlalchemy import select @@ -28,12 +34,6 @@ from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_redis import redis_client from factories import variable_factory -from graphon.enums import BuiltinNodeTypes -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData -from graphon.nodes.tool.entities import ToolNodeData from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.enums import CollectionBindingType, DatasetRuntimeMode diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py index c91f621ffb7..2c1f99a3bc9 100644 --- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -27,13 +27,13 @@ from dataclasses import dataclass, field from typing import Any import click +from graphon.enums import WorkflowType from sqlalchemy import inspect from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db -from graphon.enums import WorkflowType from libs.archive_storage import ( ArchiveStorage, ArchiveStorageNotConfiguredError, diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 4334412c8b3..12053377e24 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -6,6 +6,8 @@ import uuid from datetime import UTC, datetime from typing import Any +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.db.session_factory import session_factory @@ -15,8 +17,6 @@ from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from 
models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 9190a672498..2a56bc0c71e 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -2,6 +2,7 @@ import json import logging from typing import Any, cast +from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx import get from sqlalchemy import select from typing_extensions import TypedDict @@ -21,7 +22,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider from services.tools.tools_transform_service import ToolTransformService diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 931ca5021a9..fb6b5bea24d 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -2,6 +2,7 @@ import json import logging from datetime import datetime +from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import or_, select from sqlalchemy.orm import Session @@ -13,7 +14,6 @@ from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurati from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.model import App from models.tools import WorkflowToolProvider from models.workflow import Workflow diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index a827222c1dc..25e80770b83 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -2,6 +2,7 @@ import json import logging from datetime import datetime +from graphon.entities.graph_config import NodeConfigDict from sqlalchemy import select from sqlalchemy.orm import Session @@ -13,7 +14,6 @@ from core.workflow.nodes.trigger_schedule.entities import ( VisualConfig, ) from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError -from graphon.entities.graph_config import NodeConfigDict from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index dca00a466b1..d72c0416092 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -5,6 +5,7 @@ from collections.abc import Mapping from typing import Any from flask import Request, Response +from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -20,7 +21,6 @@ from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_ from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from extensions.ext_database import db from extensions.ext_redis import 
redis_client -from graphon.entities.graph_config import NodeConfigDict from models.model import App from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription, WorkflowPluginTrigger diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 5d9be84c069..c03275497d7 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -7,6 +7,9 @@ from typing import Any import orjson from flask import request +from graphon.entities.graph_config import NodeConfigDict +from graphon.file import FileTransferMethod +from graphon.variables.types import ArrayValidation, SegmentType from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -28,9 +31,6 @@ from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory -from graphon.entities.graph_config import NodeConfigDict -from graphon.file.models import FileTransferMethod -from graphon.variables.types import ArrayValidation, SegmentType from models.enums import AppTriggerStatus, AppTriggerType from models.model import App from models.trigger import AppTrigger, WorkflowWebhookTrigger diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index d0a4317065a..62916cc2c93 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -5,8 +5,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from typing import Any, Generic, TypeAlias, TypeVar, overload -from configs import dify_config -from graphon.file.models import File +from graphon.file import File from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable from graphon.variables.segments import ( ArrayFileSegment, @@ -22,6 +21,8 @@ from graphon.variables.segments import ( ) from graphon.variables.utils import dumps_with_segments +from configs import dify_config + _MAX_DEPTH = 100 diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 5fd310b689a..3f78b823a63 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,5 +1,7 @@ import logging +from graphon.model_runtime.entities.model_entities import ModelType + from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector @@ -9,7 +11,6 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 1f3993505c9..31367f72fab 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,6 +1,11 @@ import json from typing import Any +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.utils.encoders import jsonable_encoder +from 
graphon.nodes import BuiltinNodeTypes +from graphon.variables.input_entities import VariableEntity from typing_extensions import TypedDict from core.app.app_config.entities import ( @@ -19,11 +24,6 @@ from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.app_event import app_was_created from extensions.ext_database import db -from graphon.file.models import FileUploadConfig -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.input_entities import VariableEntity from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig, IconType diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index fa26f507eed..bf178e8a44d 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -3,11 +3,11 @@ import uuid from datetime import datetime from typing import Any +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from typing_extensions import TypedDict -from graphon.enums import WorkflowExecutionStatus from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 0b5c89e5740..98e338a2d4a 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -6,6 +6,19 @@ from concurrent.futures import ThreadPoolExecutor from enum import StrEnum from typing import Any, ClassVar +from graphon.enums import NodeType +from graphon.file import File +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.variable_assigner.common.helpers import get_updated_variables +from graphon.variable_loader import VariableLoader +from graphon.variables import Segment, StringSegment, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import ( + ArrayFileSegment, + FileSegment, +) +from graphon.variables.types import SegmentType +from graphon.variables.utils import dumps_with_segments from sqlalchemy import Engine, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -26,19 +39,6 @@ from core.workflow.variable_prefixes import ( from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable -from graphon.enums import NodeType -from graphon.file.models import File -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.variable_assigner.common.helpers import get_updated_variables -from graphon.variable_loader import VariableLoader -from graphon.variables import Segment, StringSegment, VariableBase -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.segments import ( - ArrayFileSegment, - FileSegment, -) -from graphon.variables.types import SegmentType -from graphon.variables.utils import dumps_with_segments from libs.datetime_utils import 
naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 5fca4447230..601e9261fc6 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -9,6 +9,10 @@ from collections.abc import Generator, Mapping, Sequence from dataclasses import dataclass from typing import Any +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import desc, select from sqlalchemy.orm import Session, sessionmaker @@ -22,10 +26,6 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index bef99458bef..b555676704e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,6 +5,31 @@ import uuid from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast +from graphon.entities import GraphInitParams, WorkflowNodeExecution +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + ErrorStrategy, + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission +from graphon.nodes.human_input.enums import HumanInputFormKind +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import load_into_variable_pool +from graphon.variables import VariableBase +from graphon.variables.input_entities import VariableEntityType +from graphon.variables.variables import Variable from sqlalchemy import exists, select from sqlalchemy.orm import Session, sessionmaker @@ -33,31 +58,6 @@ from events.app_event import app_draft_workflow_was_synced, app_published_workfl from extensions.ext_database import db from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings -from graphon.entities import GraphInitParams, WorkflowNodeExecution -from graphon.entities.graph_config import NodeConfigDict -from 
graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import ( - ErrorStrategy, - NodeType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.errors import WorkflowNodeRunFailedError -from graphon.file import File -from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from graphon.node_events import NodeRunResult -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.base.node import Node -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from graphon.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission -from graphon.nodes.human_input.enums import HumanInputFormKind -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.start.entities import StartNodeData -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import load_into_variable_pool -from graphon.variables import VariableBase -from graphon.variables.input_entities import VariableEntityType -from graphon.variables.variables import Variable from libs.datetime_utils import naive_utc_now from models import Account from models.human_input import HumanInputFormRecipient, RecipientType diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 458099d99eb..489467651d0 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -7,6 +7,7 @@ from typing import Annotated, Any, TypeAlias, Union from celery import shared_task from flask import current_app, json +from graphon.runtime import GraphRuntimeState from pydantic import BaseModel, Discriminator, Field, Tag from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -22,7 +23,6 @@ from core.app.entities.app_invoke_entities import ( from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.runtime import GraphRuntimeState from libs.flask_utils import set_login_user from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index 6365400dd13..0a73c912798 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -10,6 +10,7 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task +from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -22,7 +23,6 @@ from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.runtime import GraphRuntimeState from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus from models.model import App, EndUser, Tenant diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index ed8a24b336f..20335d9b9f9 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -7,6 +7,7 @@ from pathlib import Path import click 
import pandas as pd from celery import shared_task +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import func from core.db.session_factory import session_factory @@ -14,7 +15,6 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index fd743205a1b..ca73b4d3745 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -2,6 +2,8 @@ import logging from datetime import timedelta from celery import shared_task +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import or_, select from sqlalchemy.orm import sessionmaker @@ -9,8 +11,6 @@ from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository from extensions.ext_database import db from extensions.ext_storage import storage -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from models.human_input import HumanInputForm from models.workflow import WorkflowPause, WorkflowRun diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index f8ae3f4b6e7..a316eec7b95 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -6,6 +6,7 @@ from typing import Any import click from celery import shared_task +from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -14,7 +15,6 @@ from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail -from graphon.runtime import GraphRuntimeState, VariablePool from models.human_input import ( DeliveryMethodType, HumanInputDelivery, diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 25ea53dfac5..56626e372ea 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -12,6 +12,7 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -28,7 +29,6 @@ from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from enums.quota_type import QuotaType, unlimited -from graphon.enums import WorkflowExecutionStatus from models.enums import ( AppTriggerType, CreatorUserRole, diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index ae1c2991c9a..0c7f74c180a 100644 --- a/api/tasks/workflow_execution_tasks.py +++ 
b/api/tasks/workflow_execution_tasks.py
@@ -9,11 +9,11 @@ import json
 import logging
 
 from celery import shared_task
+from graphon.entities import WorkflowExecution
+from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from sqlalchemy import select
 
 from core.db.session_factory import session_factory
-from graphon.entities.workflow_execution import WorkflowExecution
-from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from models import CreatorUserRole, WorkflowRun
 from models.enums import WorkflowRunTriggeredFrom
 
diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py
index b823ce3961d..f25ebe3bae4 100644
--- a/api/tasks/workflow_node_execution_tasks.py
+++ b/api/tasks/workflow_node_execution_tasks.py
@@ -9,13 +9,13 @@ import json
 import logging
 
 from celery import shared_task
-from sqlalchemy import select
-
-from core.db.session_factory import session_factory
 from graphon.entities.workflow_node_execution import (
     WorkflowNodeExecution,
 )
 from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
+from sqlalchemy import select
+
+from core.db.session_factory import session_factory
 from models import CreatorUserRole, WorkflowNodeExecutionModel
 from models.workflow import WorkflowNodeExecutionTriggeredFrom
 
diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py
index a876b0c4aae..91245e879e4 100644
--- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py
+++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py
@@ -1,8 +1,9 @@
 from collections.abc import Generator
 
+from graphon.node_events import StreamCompletedEvent
+
 from core.datasource.datasource_manager import DatasourceManager
 from core.datasource.entities.datasource_entities import DatasourceMessage
-from graphon.node_events import StreamCompletedEvent
 
 
 def _gen_var_stream() -> Generator[DatasourceMessage, None, None]:
diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
index b2de11b0680..3fdea109762 100644
--- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
+++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
@@ -1,7 +1,8 @@
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.node_events import NodeRunResult, StreamCompletedEvent
+
 from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
 from core.workflow.nodes.datasource.datasource_node import DatasourceNode
-from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from graphon.node_events import NodeRunResult, StreamCompletedEvent
 
 
 class _Seg:
diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py
index 878d9b24df0..c1bb8e12453 100644
--- a/api/tests/integration_tests/factories/test_storage_key_loader.py
+++ b/api/tests/integration_tests/factories/test_storage_key_loader.py
@@ -4,13 +4,13 @@ from unittest.mock import patch
 from uuid import uuid4
 
 import pytest
+from graphon.file import File, FileTransferMethod, FileType
 from sqlalchemy.orm import Session
 
 from core.app.file_access import DatabaseFileAccessController
 from extensions.ext_database import db
 from extensions.storage.storage_type import StorageType
 from factories.file_factory import StorageKeyLoader
-from graphon.file import File, FileTransferMethod, FileType
 from models import ToolFile, UploadFile
 from models.enums import CreatorUserRole
 
diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py
index c4146d5ccdd..ce04a158a82 100644
--- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py
+++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py
@@ -4,9 +4,6 @@
 from collections.abc import Generator, Sequence
 from decimal import Decimal
 from json import dumps
-from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
-from core.plugin.impl.model import PluginModelClient
-
 # import monkeypatch
 from graphon.model_runtime.entities.common_entities import I18nObject
 from graphon.model_runtime.entities.llm_entities import (
@@ -26,6 +23,9 @@ from graphon.model_runtime.entities.model_entities import (
 )
 from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
 
+from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
+from core.plugin.impl.model import PluginModelClient
+
 
 class MockModelClass(PluginModelClient):
     def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py
index 0b21ff1d2a6..5c6636f31ec 100644
--- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py
+++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py
@@ -3,6 +3,10 @@ import unittest
 import uuid
 
 import pytest
+from graphon.nodes import BuiltinNodeTypes
+from graphon.variables.segments import StringSegment
+from graphon.variables.types import SegmentType
+from graphon.variables.variables import StringVariable
 from sqlalchemy import delete
 from sqlalchemy.orm import Session
 
@@ -11,10 +15,6 @@ from extensions.ext_database import db
 from extensions.ext_storage import storage
 from extensions.storage.storage_type import StorageType
 from factories.variable_factory import build_segment
-from graphon.nodes import BuiltinNodeTypes
-from graphon.variables.segments import StringSegment
-from graphon.variables.types import SegmentType
-from graphon.variables.variables import StringVariable
 from libs import datetime_utils
 from models.enums import CreatorUserRole
 from models.model import UploadFile
diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py
index f6f4cf260bc..38dc8bbb281 100644
--- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py
+++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py
@@ -2,11 +2,11 @@ import uuid
 from unittest.mock import patch
 
 import pytest
+from graphon.variables.segments import StringSegment
 from sqlalchemy import delete
 
 from core.db.session_factory import session_factory
 from extensions.storage.storage_type import StorageType
-from graphon.variables.segments import StringSegment
 from models import Tenant
 from models.enums import CreatorUserRole
 from models.model import App, UploadFile
@@ -193,6 +193,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
     def setup_offload_test_data(self, app_and_tenant):
         tenant, app = app_and_tenant
         from graphon.variables.types import SegmentType
+
         from libs.datetime_utils import naive_utc_now
 
         with session_factory.create_session() as session:
@@ -424,6 +425,7 @@
     def setup_offload_test_data(self, app_and_tenant):
         """Create test data with offload files for session commit tests."""
         from graphon.variables.types import SegmentType
+
         from libs.datetime_utils import naive_utc_now
 
         tenant, app = app_and_tenant
diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py
index a9a2617baed..c0143faa853 100644
--- a/api/tests/integration_tests/workflow/nodes/__mock/model.py
+++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py
@@ -1,11 +1,12 @@
 from unittest.mock import MagicMock
 
+from graphon.model_runtime.entities.model_entities import ModelType
+
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
 from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
 from core.model_manager import ModelInstance
 from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
-from graphon.model_runtime.entities.model_entities import ModelType
 from models.provider import ProviderType
 
 
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
index 7573e00872c..ce0c8bf8ca8 100644
--- a/api/tests/integration_tests/workflow/nodes/test_code.py
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -2,17 +2,17 @@ import time
 import uuid
 
 import pytest
-
-from configs import dify_config
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from core.workflow.node_factory import DifyNodeFactory
-from core.workflow.system_variables import build_system_variables
 from graphon.enums import WorkflowNodeExecutionStatus
 from graphon.graph import Graph
 from graphon.node_events import NodeRunResult
 from graphon.nodes.code.code_node import CodeNode
 from graphon.nodes.code.limits import CodeNodeLimits
 from graphon.runtime import GraphRuntimeState, VariablePool
+
+from configs import dify_config
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
+from core.workflow.node_factory import DifyNodeFactory
+from core.workflow.system_variables import build_system_variables
 from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
 from tests.workflow_test_utils import build_test_graph_init_params
 
@@ -172,7 +172,7 @@ def test_execute_code_output_validator(setup_code_executor_mock):
     result = node._run()
     assert isinstance(result, NodeRunResult)
     assert result.status == WorkflowNodeExecutionStatus.FAILED
-    assert result.error == "Output result must be a string, got int instead"
+    assert result.error == "Output result must be a string, got int instead."
 
 
 def test_execute_code_output_validator_depth():
diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py
index 17ea7de8810..ce18486fafc 100644
--- a/api/tests/integration_tests/workflow/nodes/test_http.py
+++ b/api/tests/integration_tests/workflow/nodes/test_http.py
@@ -3,6 +3,11 @@ import uuid
 from urllib.parse import urlencode
 
 import pytest
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.file.file_manager import file_manager
+from graphon.graph import Graph
+from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig
+from graphon.runtime import GraphRuntimeState, VariablePool
 
 from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
@@ -11,11 +16,6 @@ from core.tools.tool_file_manager import ToolFileManager
 from core.workflow.node_factory import DifyNodeFactory
 from core.workflow.node_runtime import DifyFileReferenceFactory
 from core.workflow.system_variables import build_system_variables
-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.file.file_manager import file_manager
-from graphon.graph import Graph
-from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig
-from graphon.runtime import GraphRuntimeState, VariablePool
 from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
 from tests.workflow_test_utils import build_test_graph_init_params
 
@@ -191,7 +191,6 @@ def test_custom_authorization_header(setup_http_mock):
 @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
 def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
     """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised."""
-    from core.workflow.system_variables import build_system_variables
     from graphon.enums import BuiltinNodeTypes
     from graphon.nodes.http_request.entities import (
         HttpRequestNodeAuthorization,
@@ -202,6 +201,8 @@
     from graphon.nodes.http_request.executor import Executor
     from graphon.runtime import VariablePool
 
+    from core.workflow.system_variables import build_system_variables
+
     # Create variable pool
     variable_pool = VariablePool(
         system_variables=build_system_variables(user_id="test", files=[]),
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index fa5d63cfbfd..f0f3fcead19 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -4,11 +4,6 @@ import uuid
 from collections.abc import Generator
 from unittest.mock import MagicMock, patch
 
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from core.llm_generator.output_parser.structured_output import _parse_structured_output
-from core.model_manager import ModelInstance
-from core.workflow.system_variables import build_system_variables
-from extensions.ext_database import db
 from graphon.enums import WorkflowNodeExecutionStatus
 from graphon.node_events import StreamCompletedEvent
 from graphon.nodes.llm.file_saver import LLMFileSaver
@@ -17,6 +12,12 @@ from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol
 from graphon.nodes.protocols import HttpClientProtocol
 from graphon.runtime import GraphRuntimeState, VariablePool
+
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
+from core.llm_generator.output_parser.structured_output import _parse_structured_output
+from core.model_manager import ModelInstance
+from core.workflow.system_variables import build_system_variables
+from extensions.ext_database import db
 from tests.workflow_test_utils import build_test_graph_init_params
 
 """FOR MOCK FIXTURES, DO NOT REMOVE"""
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index 367b5bbc110..3bf44df349c 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -3,16 +3,17 @@ import time
 import uuid
 from unittest.mock import MagicMock
 
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
+from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
+from graphon.runtime import GraphRuntimeState, VariablePool
+
 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.model_manager import ModelInstance
 from core.workflow.node_runtime import DifyPromptMessageSerializer
 from core.workflow.system_variables import build_system_variables
 from extensions.ext_database import db
-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
-from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
-from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
-from graphon.runtime import GraphRuntimeState, VariablePool
 from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
 from tests.workflow_test_utils import build_test_graph_init_params
 
diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
index 9e3e1a47e31..2d728569bee 100644
--- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py
+++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
@@ -1,14 +1,15 @@
 import time
 import uuid
 
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from core.workflow.node_factory import DifyNodeFactory
-from core.workflow.system_variables import build_system_variables
 from graphon.enums import WorkflowNodeExecutionStatus
 from graphon.graph import Graph
 from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
 from graphon.runtime import GraphRuntimeState, VariablePool
 from graphon.template_rendering import TemplateRenderError
+
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
+from core.workflow.node_factory import DifyNodeFactory
+from core.workflow.system_variables import build_system_variables
 from tests.workflow_test_utils import build_test_graph_init_params
 
 
diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py
index f9ec51ee10d..750ced7075e 100644
--- a/api/tests/integration_tests/workflow/nodes/test_tool.py
+++ b/api/tests/integration_tests/workflow/nodes/test_tool.py
@@ -2,17 +2,18 @@ import time
 import uuid
 from unittest.mock import MagicMock, patch
 
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from core.tools.utils.configuration import ToolParameterConfigurationManager
-from core.workflow.node_factory import DifyNodeFactory
-from core.workflow.node_runtime import DifyToolNodeRuntime
-from core.workflow.system_variables import build_system_variables
 from graphon.enums import WorkflowNodeExecutionStatus
 from graphon.graph import Graph
 from graphon.node_events import StreamCompletedEvent
 from graphon.nodes.protocols import ToolFileManagerProtocol
 from graphon.nodes.tool.tool_node import ToolNode
 from graphon.runtime import GraphRuntimeState, VariablePool
+
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
+from core.tools.utils.configuration import ToolParameterConfigurationManager
+from core.workflow.node_factory import DifyNodeFactory
+from core.workflow.node_runtime import DifyToolNodeRuntime
+from core.workflow.system_variables import build_system_variables
 from tests.workflow_test_utils import build_test_graph_init_params
 
 
diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py
index 5b515103881..5cc458fe2ef 100644
--- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py
+++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py
@@ -4,11 +4,11 @@ import json
 import uuid
 
 from flask.testing import FlaskClient
+from graphon.enums import WorkflowExecutionStatus
 from sqlalchemy.orm import Session
 
 from configs import dify_config
 from constants import HEADER_NAME_CSRF_TOKEN
-from graphon.enums import WorkflowExecutionStatus
 from libs.datetime_utils import naive_utc_now
 from libs.token import _real_cookie_name, generate_csrf_token
 from models import Account, DifySetup, Tenant, TenantAccountJoin
diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py
index 290be876974..8ddf867370e 100644
--- a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py
+++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py
@@ -3,12 +3,12 @@
 import uuid
 
 from flask.testing import FlaskClient
+from graphon.variables.segments import StringSegment
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
 from factories.variable_factory import segment_to_variable
-from graphon.variables.segments import StringSegment
 from models import Workflow
 from models.model import AppMode
 from models.workflow import WorkflowDraftVariable
diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py
index b8840c4ba8c..2b4c1b59abf 100644
--- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py
+++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py
@@ -22,6 +22,13 @@ import uuid
 from time import time
 
 import pytest
+from graphon.entities.pause_reason import SchedulingPause
+from graphon.enums import WorkflowExecutionStatus
+from graphon.graph_engine.entities.commands import GraphEngineCommand
+from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError
+from graphon.graph_events import GraphRunPausedEvent
+from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
 from sqlalchemy import Engine, delete, select
 from sqlalchemy.orm import Session
 
@@ -33,16 +40,6 @@ from core.app.layers.pause_state_persist_layer import (
 )
 from core.workflow.system_variables import build_system_variables
 from extensions.ext_storage import storage
-from graphon.entities.pause_reason import SchedulingPause
-from graphon.enums import WorkflowExecutionStatus
-from graphon.graph_engine.entities.commands import GraphEngineCommand
-from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError
-from graphon.graph_events.graph import GraphRunPausedEvent
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.runtime.graph_runtime_state import GraphRuntimeState
-from graphon.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
-from graphon.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper
-from graphon.runtime.variable_pool import VariablePool
 from libs.datetime_utils import naive_utc_now
 from models import Account
 from models import WorkflowPause as WorkflowPauseModel
@@ -545,7 +542,7 @@ class TestPauseStatePersistenceLayerTestContainers:
         layer.initialize(graph_runtime_state, command_channel)
 
         # Import other event types
-        from graphon.graph_events.graph import (
+        from graphon.graph_events import (
             GraphRunFailedEvent,
             GraphRunStartedEvent,
             GraphRunSucceededEvent,
diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
index e0c58f0f5c7..13caad799eb 100644
--- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
+++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 from uuid import uuid4
 
+from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction
 from sqlalchemy import Engine, select
 from sqlalchemy.orm import Session
 
@@ -17,7 +18,6 @@ from core.workflow.human_input_compat import (
     MemberRecipient,
     WebAppDeliveryMethod,
 )
-from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction
 from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.human_input import (
     EmailExternalRecipientPayload,
diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
index ae8c0716a42..0a9b476afc5 100644
--- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
+++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
@@ -4,6 +4,18 @@ from datetime import timedelta
 from unittest.mock import MagicMock
 
 import pytest
+from graphon.enums import WorkflowType
+from graphon.graph import Graph
+from graphon.graph_engine import GraphEngine
+from graphon.graph_engine.command_channels import InMemoryChannel
+from graphon.nodes.end.end_node import EndNode
+from graphon.nodes.end.entities import EndNodeData
+from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction
+from graphon.nodes.human_input.enums import HumanInputFormStatus
+from graphon.nodes.human_input.human_input_node import HumanInputNode
+from graphon.nodes.start.entities import StartNodeData
+from graphon.nodes.start.start_node import StartNode
+from graphon.runtime import GraphRuntimeState, VariablePool
 from sqlalchemy import delete, select
 from sqlalchemy.orm import Session
 
@@ -15,18 +27,6 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
 from core.workflow.node_runtime import DifyHumanInputNodeRuntime
 from core.workflow.system_variables import build_system_variables
-from graphon.enums import WorkflowType
-from graphon.graph import Graph
-from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel
-from graphon.graph_engine.graph_engine import GraphEngine
-from graphon.nodes.end.end_node import EndNode
-from graphon.nodes.end.entities import EndNodeData
-from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction
-from graphon.nodes.human_input.enums import HumanInputFormStatus
-from graphon.nodes.human_input.human_input_node import HumanInputNode
-from graphon.nodes.start.entities import StartNodeData
-from graphon.nodes.start.start_node import StartNode
-from graphon.runtime import GraphRuntimeState, VariablePool
 from libs.datetime_utils import naive_utc_now
 from models import Account
 from models.account import Tenant, TenantAccountJoin, TenantAccountRole
diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
index 2e207ddc674..cc72dc1cf39 100644
--- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
+++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
@@ -4,13 +4,13 @@ from unittest.mock import patch
 from uuid import uuid4
 
 import pytest
+from graphon.file import File, FileTransferMethod, FileType
 from sqlalchemy.orm import Session
 
 from core.app.file_access import DatabaseFileAccessController
 from extensions.ext_database import db
 from extensions.storage.storage_type import StorageType
 from factories.file_factory import StorageKeyLoader
-from graphon.file import File, FileTransferMethod, FileType
 from models import ToolFile, UploadFile
 from models.enums import CreatorUserRole
 
diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py
index 2fd289dfbca..b745aed1417 100644
--- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py
+++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py
@@ -6,6 +6,7 @@ from decimal import Decimal
 from uuid import uuid4
 
 from graphon.nodes.human_input.entities import FormDefinition, UserAction
+
 from libs.datetime_utils import naive_utc_now
 from models.account import Account, Tenant, TenantAccountJoin
 from models.enums import ConversationFromSource, InvokeFrom
diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py
index 641399c7f94..a68b3a08c7a 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py
@@ -5,10 +5,10 @@ from __future__ import annotations
 from datetime import timedelta
 from uuid import uuid4
 
+from graphon.enums import WorkflowNodeExecutionStatus
 from sqlalchemy import Engine, delete
 from sqlalchemy.orm import Session, sessionmaker
 
-from graphon.enums import WorkflowNodeExecutionStatus
 from libs.datetime_utils import naive_utc_now
 from models.enums import CreatorUserRole
 from models.workflow import WorkflowNodeExecutionModel
diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
index cb00752b35e..d28cfda1598 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
@@ -8,15 +8,15 @@ from unittest.mock import Mock
 from uuid import uuid4
 
 import pytest
-from sqlalchemy import Engine, delete, select
-from sqlalchemy.orm import Session, sessionmaker
-
-from extensions.ext_storage import storage
 from graphon.entities import WorkflowExecution
 from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
 from graphon.enums import WorkflowExecutionStatus
 from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction
 from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus
+from sqlalchemy import Engine, delete, select
+from sqlalchemy.orm import Session, sessionmaker
+
+from extensions.ext_storage import storage
 from libs.datetime_utils import naive_utc_now
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
 from models.human_input import (
diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
index aaf9a85d601..7f44eb6ca37 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
@@ -12,11 +12,11 @@ from decimal import Decimal
 from uuid import uuid4
 
 import pytest
+from graphon.nodes.human_input.entities import FormDefinition, UserAction
+from graphon.nodes.human_input.enums import HumanInputFormStatus
 from sqlalchemy import Engine, delete, select
 from sqlalchemy.orm import Session, sessionmaker
 
-from graphon.nodes.human_input.entities import FormDefinition, UserAction
-from graphon.nodes.human_input.enums import HumanInputFormStatus
 from libs.datetime_utils import naive_utc_now
 from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.enums import ConversationFromSource, InvokeFrom
diff --git a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py
index d6f0657380b..c5e9201ee37 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py
@@ -7,12 +7,12 @@ from datetime import timedelta
 from uuid import uuid4
 
 import pytest
+from graphon.entities import WorkflowExecution
+from graphon.enums import WorkflowExecutionStatus
 from sqlalchemy import Engine, delete
 from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import Session, sessionmaker
 
-from graphon.entities import WorkflowExecution
-from graphon.enums import WorkflowExecutionStatus
 from libs.datetime_utils import naive_utc_now
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
 from models.workflow import WorkflowRun, WorkflowType
diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py
index 00a2f9a59f4..4f3c0e42000 100644
--- a/api/tests/test_containers_integration_tests/services/test_agent_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py
@@ -842,6 +842,7 @@ class TestAgentService:
         conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
 
         from graphon.file import FileTransferMethod, FileType
+
         from models.enums import CreatorUserRole
 
         # Add files to message
diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py
index 02ab3f83146..fb0adbbcc26 100644
--- a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py
+++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py
@@ -3,10 +3,10 @@
 from uuid import uuid4
 
 import pytest
+from graphon.variables import StringVariable
 from sqlalchemy.orm import sessionmaker
 
 from extensions.ext_database import db
-from graphon.variables import StringVariable
 from models.workflow import ConversationVariable
 from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater
 
diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py
index 0de3c64c4f7..f9bfa570cbb 100644
--- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py
@@ -9,11 +9,11 @@ from unittest.mock import Mock, patch
 from uuid import uuid4
 
 import pytest
+from graphon.model_runtime.entities.model_entities import ModelType
 from sqlalchemy.orm import Session
 
 from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from graphon.model_runtime.entities.model_entities import ModelType
 from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline
 from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus
diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py
index 883c3c3febb..a814466e14f 100644
--- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py
+++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py
@@ -2,10 +2,10 @@ from unittest.mock import Mock, patch
 from uuid import uuid4
 
 import pytest
+from graphon.model_runtime.entities.model_entities import ModelType
 from sqlalchemy.orm import Session
 
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
-from graphon.model_runtime.entities.model_entities import ModelType
 from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.dataset import Dataset, ExternalKnowledgeBindings
 from models.enums import DataSourceType
diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py
index fe426ae5161..c8f04e92159 100644
--- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py
+++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py
@@ -5,9 +5,9 @@ Testcontainers integration tests for archived workflow run deletion service.
 from datetime import UTC, datetime, timedelta
 from uuid import uuid4
 
+from graphon.enums import WorkflowExecutionStatus
 from sqlalchemy import select
 
-from graphon.enums import WorkflowExecutionStatus
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
 from models.workflow import WorkflowArchiveLog, WorkflowRun
 from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
index 18c5320d0a1..c46b8fba0bd 100644
--- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
+++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
@@ -3,6 +3,8 @@ import uuid
 from unittest.mock import MagicMock
 
 import pytest
+from graphon.enums import BuiltinNodeTypes
+from graphon.nodes.human_input.entities import HumanInputNodeData
 
 from core.workflow.human_input_compat import (
     EmailDeliveryConfig,
@@ -10,8 +12,6 @@ from core.workflow.human_input_compat import (
     EmailRecipients,
     ExternalRecipient,
 )
-from graphon.enums import BuiltinNodeTypes
-from graphon.nodes.human_input.entities import HumanInputNodeData
 from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.model import App, AppMode
 from models.workflow import Workflow, WorkflowType
diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
index 21a54e909ec..0f252515f72 100644
--- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch
 from uuid import uuid4
 
 import pytest
+from graphon.runtime import VariablePool
 from sqlalchemy.engine import Engine
 
 from configs import dify_config
@@ -15,7 +16,6 @@ from core.workflow.human_input_compat import (
     ExternalRecipient,
     MemberRecipient,
 )
-from graphon.runtime import VariablePool
 from models.account import Account, TenantAccountJoin
 from services import human_input_delivery_test_service as service_module
 from services.human_input_delivery_test_service import (
diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py
index c0c1c25f1eb..9528257963e 100644
--- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py
@@ -6,11 +6,11 @@ from unittest.mock import patch
 
 import pytest
 from faker import Faker
+from graphon.file import FileType
 from sqlalchemy.orm import Session
 
 from enums.cloud_plan import CloudPlan
 from extensions.ext_redis import redis_client
-from graphon.file.enums import FileType
 from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.enums import (
     ConversationFromSource,
diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py
index 8955a3b5f23..ba926bf6758 100644
--- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py
@@ -2,10 +2,10 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from faker import Faker
+from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType
 from sqlalchemy.orm import Session
 
 from core.entities.model_entities import ModelStatus
-from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType
 from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType
 from services.model_provider_service import ModelProviderService
@@ -405,10 +405,11 @@ class TestModelProviderService:
         mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value
 
         # Create mock models
-        from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity
         from graphon.model_runtime.entities.common_entities import I18nObject
         from graphon.model_runtime.entities.provider_entities import ProviderEntity
 
+        from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity
+
         # Create real model objects instead of mocks
         provider_entity_1 = SimpleModelProviderEntity(
             ProviderEntity(
@@ -643,9 +644,10 @@ class TestModelProviderService:
         mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value
 
         # Create mock default model response
-        from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
        from graphon.model_runtime.entities.common_entities import I18nObject
 
+        from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
+
         mock_default_model = DefaultModelEntity(
             model="gpt-3.5-turbo",
             model_type=ModelType.LLM,
diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py
index 2a18345c875..749c6fff5bc 100644
--- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py
@@ -8,9 +8,9 @@ from unittest.mock import patch
 
 import pytest
 from faker import Faker
+from graphon.enums import WorkflowExecutionStatus
 from sqlalchemy.orm import Session
 
-from graphon.entities.workflow_execution import WorkflowExecutionStatus
 from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
 from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom
 from models.workflow import WorkflowAppLogCreatedFrom
diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py
index 86cf2327c7f..0c281c8c33b 100644
--- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py
@@ -1,9 +1,9 @@
 import pytest
 from faker import Faker
+from graphon.variables.segments import StringSegment
 from sqlalchemy.orm import Session
 
 from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
-from graphon.variables.segments import StringSegment
 from models import App, Workflow
 from models.enums import DraftVariableType
 from models.workflow import WorkflowDraftVariable
diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py
index ce5c2bd162f..ce2fd2eeb1a 100644
--- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py
+++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py
@@ -5,6 +5,9 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from faker import Faker
+from graphon.model_runtime.entities.llm_entities import LLMMode
+from graphon.model_runtime.entities.message_entities import PromptMessageRole
+from graphon.variables.input_entities import VariableEntity, VariableEntityType
 from sqlalchemy.orm import Session
 
 from core.app.app_config.entities import (
@@ -18,9 +21,6 @@ from core.app.app_config.entities import (
     PromptTemplateEntity,
 )
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
-from graphon.model_runtime.entities.llm_entities import LLMMode
-from graphon.model_runtime.entities.message_entities import PromptMessageRole
-from graphon.variables.input_entities import VariableEntity, VariableEntityType
 from models import Account, Tenant
 from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
 from models.model import App, AppMode, AppModelConfig
diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py
index 4dab895135b..7c43bf676b0 100644
--- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py
+++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py
@@ -1,10 +1,10 @@
 from datetime import datetime, timedelta
 from uuid import uuid4
 
+from graphon.enums import WorkflowNodeExecutionStatus
 from sqlalchemy import Engine, select
 from sqlalchemy.orm import Session, sessionmaker
 
-from graphon.enums import WorkflowNodeExecutionStatus
 from libs.datetime_utils import naive_utc_now
 from models.enums import CreatorUserRole
 from models.workflow import WorkflowNodeExecutionModel
diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
index d341c5ce99a..a16f3ff773b 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
@@ -3,6 +3,9 @@ from datetime import UTC, datetime
 from unittest.mock import patch
 
 import pytest
+from graphon.enums import WorkflowExecutionStatus
+from graphon.nodes.human_input.entities import HumanInputNodeData
+from graphon.runtime import GraphRuntimeState, VariablePool
 
 from configs import dify_config
 from core.app.app_config.entities import WorkflowUIBasedAppConfig
@@ -17,9 +20,6 @@ from core.workflow.human_input_compat import (
     MemberRecipient,
 )
 from extensions.ext_storage import storage
-from graphon.enums import WorkflowExecutionStatus
-from graphon.nodes.human_input.entities import HumanInputNodeData
-from graphon.runtime import GraphRuntimeState, VariablePool
 from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
 from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient
diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py
index 9a7507a2f9f..96cf9cebf5e 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py
@@ -2,11 +2,11 @@ import uuid
 from unittest.mock import ANY, call, patch
 
 import pytest
+from graphon.variables.segments import StringSegment
+from graphon.variables.types import SegmentType
 
 from core.db.session_factory import session_factory
 from extensions.storage.storage_type import StorageType
-from graphon.variables.segments import StringSegment
-from graphon.variables.types import SegmentType
 from libs.datetime_utils import naive_utc_now
 from models import Tenant
 from models.enums import CreatorUserRole
diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py
index b9f513a6d04..159ab51304a 100644
--- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py
+++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py
@@ -24,12 +24,12 @@ from dataclasses import dataclass
 from datetime import timedelta
 
 import pytest
+from graphon.entities import WorkflowExecution
+from graphon.enums import WorkflowExecutionStatus
 from sqlalchemy import delete, select
 from sqlalchemy.orm import Session, selectinload, sessionmaker
 
 from extensions.ext_storage import storage
-from graphon.entities import WorkflowExecution
-from graphon.enums import WorkflowExecutionStatus
 from libs.datetime_utils import naive_utc_now
 from models import Account
 from models import WorkflowPause as WorkflowPauseModel
diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py
index 8854ef5e04b..7539bae6855 100644
--- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py
+++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py
@@ -10,6 +10,7 @@ from typing import Any
 import pytest
 from flask import Flask, Response
 from flask.testing import FlaskClient
+from graphon.enums import BuiltinNodeTypes
 from sqlalchemy.orm import Session
 
 from configs import dify_config
@@ -23,7 +24,6 @@ from core.trigger.debug import event_selectors
 from core.trigger.debug.event_bus import TriggerDebugEventBus
 from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller
 from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key
-from graphon.enums import BuiltinNodeTypes
 from libs.datetime_utils import naive_utc_now
 from models.account import Account, Tenant
 from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus
diff --git a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py
index 2d218dac7e5..c52bc02420e 100644
--- a/api/tests/unit_tests/controllers/console/app/test_audio.py
+++ b/api/tests/unit_tests/controllers/console/app/test_audio.py
@@ -4,6 +4,7 @@ import io
 from types import SimpleNamespace
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import InternalServerError
 
@@ -20,7 +21,6 @@ from controllers.console.app.error import (
     UnsupportedAudioTypeError,
 )
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.app_model_config import AppModelConfigBrokenError
 from services.errors.audio import (
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py
index 341efc05caf..36076368805 100644
--- a/api/tests/unit_tests/controllers/console/app/test_workflow.py
+++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py
@@ -5,12 +5,11 @@ from types import SimpleNamespace
 from unittest.mock import Mock
 
 import pytest
+from graphon.file import File, FileTransferMethod, FileType
 from werkzeug.exceptions import HTTPException, NotFound
 
 from controllers.console.app import workflow as workflow_module
 from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
-from graphon.file.enums import FileTransferMethod, FileType
-from graphon.file.models import File
 
 
 def _unwrap(func):
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py
index c4a81484466..e11102acb1e 100644
--- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py
+++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py
@@ -6,14 +6,14 @@ from unittest.mock import Mock
 
 import pytest
 from flask import Flask
-
-from controllers.console import wraps as console_wraps
-from controllers.console.app import workflow_run as workflow_run_module
-from controllers.web.error import NotFoundError
 from graphon.entities.pause_reason import HumanInputRequired
 from graphon.enums import WorkflowExecutionStatus
 from graphon.nodes.human_input.entities import FormInput, UserAction
 from graphon.nodes.human_input.enums import FormInputType
+
+from controllers.console import wraps as console_wraps
+from controllers.console.app import workflow_run as workflow_run_module
+from controllers.web.error import NotFoundError
 from libs import login as login_lib
 from models.account import Account, AccountStatus, TenantAccountRole
 from models.workflow import WorkflowRun
diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
index 559b5fea09c..740da1f1df1 100644
--- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
+++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from flask_restx import marshal
+from graphon.variables.types import SegmentType
 
 from controllers.console.app.workflow_draft_variable import (
     _WORKFLOW_DRAFT_VARIABLE_FIELDS,
@@ -15,7 +16,6 @@ from controllers.console.app.workflow_draft_variable import (
 )
 from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from factories.variable_factory import build_segment
-from graphon.variables.types import SegmentType
 from libs.datetime_utils import naive_utc_now
 from libs.uuid_utils import uuidv7
 from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
@@ -310,8 +310,7 @@ def test_workflow_node_variables_fields():
 
 def test_workflow_file_variable_with_signed_url():
     """Test that File type variables include signed URLs in API responses."""
-    from graphon.file.enums import FileTransferMethod, FileType
-    from graphon.file.models import File
+    from graphon.file import File, FileTransferMethod, FileType
 
     # Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
     test_file = File(
@@ -367,8 +366,7 @@
 
 def test_workflow_file_variable_remote_url():
     """Test that File type variables with REMOTE_URL transfer method return the remote URL."""
-    from graphon.file.enums import FileTransferMethod, FileType
-    from graphon.file.models import File
+    from graphon.file import File, FileTransferMethod, FileType
 
     # Create a File object with REMOTE_URL transfer method
     test_file = File(
diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py
index 5136922e88e..9c9f8da87c1 100644
--- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py
+++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 from werkzeug.exceptions import Forbidden, NotFound
 
 from controllers.console import console_ns
@@ -17,7 +18,6 @@ from controllers.console.datasets.rag_pipeline.datasource_auth import (
     DatasourceUpdateProviderNameApi,
 )
 from core.plugin.impl.oauth import OAuthHandler
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 from services.datasource_provider_service import DatasourceProviderService
 from services.plugin.oauth_service import OAuthProxyService
 
diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py
index 63950736c54..6ef8ccfdbd3 100644
--- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py
+++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from flask import Response
+from graphon.variables.types import SegmentType
 
 from controllers.console import console_ns
 from controllers.console.app.error import DraftWorkflowNotExist
@@ -15,7 +16,6 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor
 )
 from controllers.web.error import InvalidArgumentError, NotFoundError
 from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
-from graphon.variables.types import SegmentType
 from models.account import Account
 
 
diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py
index e4acd91b76c..710c9be684c 100644
--- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py
+++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
 import services
@@ -20,7 +21,6 @@ from core.errors.error import (
     ProviderTokenNotInitError,
     QuotaExceededError,
 )
-from graphon.model_runtime.errors.invoke import InvokeError
 from models.account import Account
 from services.dataset_service import DatasetService
 from services.hit_testing_service import HitTestingService
diff --git a/api/tests/unit_tests/controllers/console/explore/test_audio.py b/api/tests/unit_tests/controllers/console/explore/test_audio.py
index b4b57022e28..66c9ba48c59 100644
--- a/api/tests/unit_tests/controllers/console/explore/test_audio.py
+++ b/api/tests/unit_tests/controllers/console/explore/test_audio.py
@@ -2,6 +2,7 @@ from io import BytesIO
 from unittest.mock import MagicMock, patch
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.exceptions import InternalServerError
 
 import controllers.console.explore.audio as audio_module
@@ -19,7 +20,6 @@ from core.errors.error import (
     ProviderTokenNotInitError,
     QuotaExceededError,
 )
-from graphon.model_runtime.errors.invoke import InvokeError
 from services.errors.audio import (
     AudioTooLargeServiceError,
     NoAudioUploadedServiceError,
diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py
index 145cc9cdd7e..2e4ca4f2a44 100644
--- a/api/tests/unit_tests/controllers/console/explore/test_message.py
+++ b/api/tests/unit_tests/controllers/console/explore/test_message.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.exceptions import InternalServerError, NotFound
 
 import controllers.console.explore.message as module
@@ -21,7 +22,6 @@ from core.errors.error import (
     ProviderTokenNotInitError,
     QuotaExceededError,
 )
-from graphon.model_runtime.errors.invoke import InvokeError
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import (
     FirstMessageNotExistsError,
diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py
index 03eadcdb4e8..04beb31389c 100644
--- a/api/tests/unit_tests/controllers/console/explore/test_trial.py
+++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
 from uuid import uuid4
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
 import controllers.console.explore.trial as module
@@ -25,7 +26,6 @@ from core.errors.error import (
     ProviderTokenNotInitError,
     QuotaExceededError,
 )
-from graphon.model_runtime.errors.invoke import InvokeError
 from models import Account
 from models.account import TenantStatus
 from models.model import AppMode
diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py
index b2f949c6e2b..9c42ee9529a 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py
@@ -11,10 +11,9 @@ from unittest.mock import MagicMock
 import pytest
 from flask import Flask
 from flask.views import MethodView
-from werkzeug.exceptions import Forbidden
-
 from graphon.model_runtime.entities.model_entities import ModelType
 from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
+from werkzeug.exceptions import Forbidden
 
 if not hasattr(builtins, "MethodView"):
     builtins.MethodView = MethodView  # type: ignore[attr-defined]
diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py
index 168479af1ea..fb9eec98cb5 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 from pydantic_core import ValidationError
 from werkzeug.exceptions import Forbidden
 
@@ -13,7 +14,6 @@ from controllers.console.workspace.model_providers import (
     ModelProviderValidateApi,
     PreferredProviderTypeUpdateApi,
 )
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 
 VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
 INVALID_UUID = "123"
diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py
index f0d32f81fb5..c829327bc7a 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_models.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py
@@ -2,6 +2,8 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from flask import Flask
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 
 from controllers.console.workspace.models import (
     DefaultModelApi,
@@ -14,8 +16,6 @@ from controllers.console.workspace.models import (
     ModelProviderModelParameterRuleApi,
     ModelProviderModelValidateApi,
 )
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 
 
 def unwrap(func):
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py
index e81e612803b..5a8cb4619f4 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py
@@ -13,6 +13,7 @@ from types import SimpleNamespace
 from unittest.mock import Mock, patch
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import InternalServerError
 
@@ -29,7 +30,6 @@ from controllers.service_api.app.error import (
     UnsupportedAudioTypeError,
 )
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.app_model_config import AppModelConfigBrokenError
 from services.errors.audio import (
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py
index 3364c07e623..57681d8f5bc 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py
@@ -16,6 +16,7 @@ from types import SimpleNamespace
 from unittest.mock import Mock, patch
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import ValidationError
 from werkzeug.exceptions import BadRequest, NotFound
 
@@ -34,7 +35,6 @@ from controllers.service_api.app.error import (
     NotChatAppError,
 )
 from core.errors.error import QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 from models.model import App, AppMode, EndUser
 from services.app_generate_service import AppGenerateService
 from services.app_task_service import AppTaskService
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py
index 6543c270373..b1f036c6f36 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py
@@ -19,6 +19,7 @@ from types import SimpleNamespace
 from unittest.mock import Mock, patch
 
 import pytest
+from graphon.enums import WorkflowExecutionStatus
 from werkzeug.exceptions import BadRequest, NotFound
 
 from controllers.service_api.app.error import NotWorkflowAppError
@@ -35,7 +36,6 @@ from controllers.service_api.app.workflow import (
     WorkflowTaskStopApi,
 )
 from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
-from graphon.enums import WorkflowExecutionStatus
 from models.model import App, AppMode
 from services.app_generate_service import AppGenerateService
 from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py
index eda270258d5..4b8e3a738cb 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py
@@ -1,8 +1,9 @@
 from types import SimpleNamespace
 
-from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField
 from graphon.enums import WorkflowExecutionStatus
 
+from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField
+
 
 def test_workflow_run_status_field_with_enum() -> None:
     field = WorkflowRunStatusField()
diff --git a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py
index a6ca441801b..cbfc8fa6130 100644
--- a/api/tests/unit_tests/controllers/web/test_audio.py
+++ b/api/tests/unit_tests/controllers/web/test_audio.py
@@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from flask import Flask
+from graphon.model_runtime.errors.invoke import InvokeError
 
 from controllers.web.audio import AudioApi, TextApi
 from controllers.web.error import (
@@ -21,7 +22,6 @@ from controllers.web.error import (
     UnsupportedAudioTypeError,
 )
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 from services.errors.audio import (
     AudioTooLargeServiceError,
     NoAudioUploadedServiceError,
diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py
index 4f8d848637d..49039d03fe1 100644
--- a/api/tests/unit_tests/controllers/web/test_completion.py
+++ b/api/tests/unit_tests/controllers/web/test_completion.py
@@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from flask import Flask
+from graphon.model_runtime.errors.invoke import InvokeError
 
 from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
 from controllers.web.error import (
@@ -18,7 +19,6 @@ from controllers.web.error import (
     ProviderQuotaExceededError,
 )
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 
 
 def _completion_app() -> SimpleNamespace:
diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py
index cde8820e006..bc7aea0ef92 100644
--- a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py
+++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py
@@ -2,11 +2,11 @@ import json
 from unittest.mock import MagicMock
 
 import pytest
+from graphon.model_runtime.entities.llm_entities import LLMUsage
 
 from core.agent.cot_agent_runner import CotAgentRunner
 from core.agent.entities import AgentScratchpadUnit
 from core.agent.errors import AgentMaxIterationError
-from graphon.model_runtime.entities.llm_entities import LLMUsage
 
 
 class DummyRunner(CotAgentRunner):
diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py
index ea8cc8aa864..97206019b9f 100644
--- a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py
+++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py
@@ -1,9 +1,9 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
 
 from core.agent.cot_chat_agent_runner import CotChatAgentRunner
-from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
 from tests.unit_tests.core.agent.conftest import (
     DummyAgentConfig,
     DummyAppConfig,
diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py
index 2f5873d865f..defc8b4b642 100644
--- a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py
+++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py
@@ -1,8 +1,6 @@
 import json
 
 import pytest
-
-from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
 from graphon.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
@@ -10,6 +8,8 @@ from graphon.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 
+from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
+
 # -----------------------------
 # Fixtures
 # -----------------------------
diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py
index 17ab5babcbd..a44a0650eb4 100644
--- a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py
+++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py
@@ -3,11 +3,6 @@ from typing import Any
 from unittest.mock import MagicMock
 
 import pytest
-
-from core.agent.errors import AgentMaxIterationError
-from core.agent.fc_agent_runner import FunctionCallAgentRunner
-from core.app.apps.base_app_queue_manager import PublishFrom
-from core.app.entities.queue_entities import QueueMessageFileEvent
 from graphon.model_runtime.entities.llm_entities import LLMUsage
 from graphon.model_runtime.entities.message_entities import (
     DocumentPromptMessageContent,
@@ -16,6 +11,11 @@ from graphon.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 
+from core.agent.errors import AgentMaxIterationError
+from core.agent.fc_agent_runner import FunctionCallAgentRunner
+from core.app.apps.base_app_queue_manager import PublishFrom
+from core.app.entities.queue_entities import QueueMessageFileEvent
+
 # ==============================
 # Dummy Helper Classes
 # ==============================
diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py
index 186b4a501d3..5ee66da94ab 100644
--- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py
+++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py
@@ -2,6 +2,8 @@ from types import SimpleNamespace
 from unittest.mock import MagicMock
 
 import pytest
+from graphon.model_runtime.entities.llm_entities import LLMMode
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey
 
 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
 from core.entities.model_entities import ModelStatus
@@ -10,8 +12,6 @@ from core.errors.error import (
     ProviderTokenNotInitError,
     QuotaExceededError,
 )
-from graphon.model_runtime.entities.llm_entities import LLMMode
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey
 
 
 class TestModelConfigConverter:
diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py
index d9fe7004ff7..e2f3c16335f 100644
--- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py
+++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py
@@ -1,9 +1,9 @@
 import pytest
+from graphon.variables.input_entities import
VariableEntityType from core.app.app_config.easy_ui_based_app.variables.manager import ( BasicVariablesConfigManager, ) -from graphon.variables.input_entities import VariableEntityType class TestBasicVariablesConfigManagerConvert: diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index 11fc15c94df..8bde9c1f979 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,7 +1,8 @@ -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from graphon.file.models import FileTransferMethod, FileUploadConfig, ImageConfig +from graphon.file import FileTransferMethod, FileUploadConfig, ImageConfig from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager + def test_convert_with_vision(): config = { diff --git a/api/tests/unit_tests/core/app/app_config/test_entities.py b/api/tests/unit_tests/core/app/app_config/test_entities.py index f2bc3076dac..000f83cd5a2 100644 --- a/api/tests/unit_tests/core/app/app_config/test_entities.py +++ b/api/tests/unit_tests/core/app/app_config/test_entities.py @@ -1,10 +1,10 @@ import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.app_config.entities import ( DatasetRetrieveConfigEntity, PromptTemplateEntity, ) -from graphon.variables.input_entities import VariableEntity, VariableEntityType class TestAppConfigEntities: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index ef7df5e1da7..061719d15a5 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -3,12 +3,12 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 +from graphon.variables import SegmentType from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from factories import variable_factory -from graphon.variables import SegmentType from models import ConversationVariable, Workflow MINIMAL_GRAPH = { diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py index f2df35d7d02..e9fdeefee4d 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator +from graphon.enums import WorkflowNodeExecutionStatus + from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.entities.task_entities import ( ChatbotAppBlockingResponse, @@ -10,7 +12,6 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, ) -from graphon.enums import WorkflowNodeExecutionStatus class TestAdvancedChatGenerateResponseConverter: diff --git 
a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index 99a386cd45e..a6d85989556 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -6,6 +6,8 @@ from types import SimpleNamespace from unittest import mock import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module from core.app.entities.app_invoke_entities import InvokeFrom @@ -17,8 +19,6 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import StreamEvent -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent from models.model import AppMode, EndUser diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 29fd63c063a..82b2e51019f 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -4,6 +4,8 @@ from contextlib import contextmanager from types import SimpleNamespace import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.advanced_chat.generate_task_pipeline import ( @@ -47,8 +49,6 @@ from core.app.entities.task_entities import ( ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables -from graphon.enums import BuiltinNodeTypes -from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models.enums import MessageStatus from models.model import AppMode, EndUser diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py index 80f7f94b1ac..7dc43581501 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -1,12 +1,12 @@ import contextlib import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class DummyAccount: diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py index 4567b354804..08250bc3b6f 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -1,10 +1,10 @@ import pytest +from 
graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.agent.entities import AgentEntity from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.moderation.base import ModerationError -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py index 8f3c41701b3..68bcffb0e83 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -2,6 +2,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.chat.app_runner import ChatAppRunner @@ -9,7 +10,6 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.moderation.base import ModerationError -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index f56ca8de994..f255d2c7df2 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent -from graphon.file.enums import FileTransferMethod, FileType -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index d6f7a05cdc8..4a94a2b4f1b 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -1,12 +1,11 @@ from types import SimpleNamespace import pytest +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.workflow.system_variables import build_system_variables from core.workflow.variable_pool_initializer import add_variables_to_pool -from graphon.runtime import GraphRuntimeState -from graphon.runtime.variable_pool import VariablePool def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py 
b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 3ab63aed254..328cd12f12e 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,9 +1,10 @@ from collections.abc import Mapping, Sequence -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from graphon.variables.segments import ArrayFileSegment, FileSegment +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter + class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Test class for WorkflowResponseConverter._fetch_files_from_variable_value method""" diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index e8946281ac7..bc11bf41744 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -1,12 +1,13 @@ from datetime import UTC, datetime from types import SimpleNamespace +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter(): diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index 492e11ee0f7..c9e146ff126 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -1,10 +1,11 @@ from types import SimpleNamespace +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter() -> WorkflowResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 7ee375d8846..0fde7565d24 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -10,6 +10,8 @@ from typing import Any from unittest.mock import Mock import pytest +from graphon.entities import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes from core.app.app_config.entities 
import WorkflowUIBasedAppConfig from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -25,8 +27,6 @@ from core.app.entities.queue_entities import ( QueueNodeSucceededEvent, ) from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py index aa2085177e1..619d66085a4 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent import core.app.apps.completion.app_runner as module from core.app.apps.completion.app_runner import CompletionAppRunner from core.moderation.base import ModerationError -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py index f2e35f9900b..96af9fbdeef 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -3,13 +3,13 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError import core.app.apps.completion.app_generator as module from core.app.apps.completion.app_generator import CompletionAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py index cfe797aa764..6cdcab29abe 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( AppStreamResponse, @@ -10,7 +12,6 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus def test_convert_blocking_full_and_simple_response(): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py index 9db83f5531e..4fe82efcb33 100644 --- 
a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py @@ -1,4 +1,5 @@ import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult import core.app.apps.pipeline.pipeline_queue_manager as module from core.app.apps.base_app_queue_manager import PublishFrom @@ -13,7 +14,6 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowSucceededEvent, ) -from graphon.model_runtime.entities.llm_entities import LLMResult def test_publish_sets_stop_listen_and_raises_on_stopped(mocker): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index fb19d6d7615..ab70996f0aa 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -22,11 +22,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.graph_events import GraphRunFailedEvent import core.app.apps.pipeline.pipeline_runner as module from core.app.apps.pipeline.pipeline_runner import PipelineRunner from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from graphon.graph_events import GraphRunFailedEvent def _build_app_generate_entity() -> SimpleNamespace: diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index b0f8b423e1e..6167be3bbdb 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,7 +1,7 @@ import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.apps.base_app_generator import BaseAppGenerator -from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): @@ -476,8 +476,9 @@ class TestBaseAppGeneratorExtras: assert converted[1] == "event: ping\n\n" def test_get_draft_var_saver_factory_debugger(self): - from core.app.entities.app_invoke_entities import InvokeFrom from graphon.enums import BuiltinNodeTypes + + from core.app.entities.app_invoke_entities import InvokeFrom from models import Account base_app_generator = BaseAppGenerator() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py index 17de39ca99f..1dee7fdab66 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -4,6 +4,15 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from core.app.app_config.entities import ( AdvancedChatMessageEntity, @@ -14,15 +23,6 @@ from core.app.app_config.entities import ( from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import 
QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessageRole, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 3673b7f68eb..a126bc85f75 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -4,19 +4,14 @@ from types import ModuleType, SimpleNamespace from typing import Any import graphon.nodes.human_input.entities # noqa: F401 -from core.app.apps.advanced_chat import app_generator as adv_app_gen_module -from core.app.apps.workflow import app_generator as wf_app_gen_module -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason from graphon.entities.base_node_data import BaseNodeData, RetryConfig from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.entities.pause_reason import SchedulingPause -from graphon.entities.workflow_start_reason import WorkflowStartReason from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.graph_engine import GraphEngine -from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_engine.command_channels import InMemoryChannel from graphon.graph_events import ( GraphEngineEvent, GraphRunPausedEvent, @@ -30,6 +25,12 @@ from graphon.nodes.base.node import Node from graphon.nodes.end.entities import EndNodeData from graphon.nodes.start.entities import StartNodeData from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.apps.advanced_chat import app_generator as adv_app_gen_module +from core.app.apps.workflow import app_generator as wf_app_gen_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index 58c7bfa4bc2..de5bca161c7 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -4,23 +4,6 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest - -from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.app.entities.queue_entities import ( - QueueAgentLogEvent, - QueueIterationCompletedEvent, - QueueLoopCompletedEvent, - QueueNodeExceptionEvent, - QueueNodeFailedEvent, - QueueNodeRetryEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowPausedEvent, - 
QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.system_variables import default_system_variables from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import BuiltinNodeTypes from graphon.graph_events import ( @@ -41,6 +24,23 @@ from graphon.node_events import NodeRunResult from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.variables import StringVariable +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueIterationCompletedEvent, + QueueLoopCompletedEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeSucceededEvent, + QueueTextChunkEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.workflow.system_variables import default_system_variables + class TestWorkflowBasedAppRunner: def test_resolve_user_from(self): diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index 38a947986f6..aa789d9ff3a 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -1,11 +1,11 @@ from unittest.mock import MagicMock import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events.graph import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 620a153204d..9e30faecf23 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -4,14 +4,14 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.workflow.system_variables import default_system_variables -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.runtime import GraphRuntimeState, VariablePool from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index ef0edf4096e..8a717e1dccd 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -3,6 +3,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import 
HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from core.app.apps.common import workflow_response_converter from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -11,11 +16,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse from core.workflow.system_variables import build_system_variables -from graphon.entities.pause_reason import HumanInputRequired -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.graph_events.graph import GraphRunPausedEvent -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from models.account import Account from models.human_input import RecipientType diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py index 7dd7ffd7277..b768e813bd7 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( ErrorStreamResponse, @@ -9,7 +11,6 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestWorkflowGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index a0a999cbc5a..29df903aa8b 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -2,14 +2,15 @@ import time from contextlib import contextmanager from unittest.mock import MagicMock +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState + from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.runtime import GraphRuntimeState from models.account import Account from models.model import AppMode from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index 115e35da8ad..dabd2594b43 100644 --- 
a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -4,6 +4,8 @@ from contextlib import contextmanager from types import SimpleNamespace import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline @@ -44,8 +46,6 @@ from core.app.entities.task_entities import ( ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables, system_variables_to_mapping -from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import AppMode, EndUser diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py index 7c797806411..014a0cba729 100644 --- a/api/tests/unit_tests/core/app/entities/test_task_entities.py +++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py @@ -1,10 +1,11 @@ +from graphon.enums import WorkflowNodeExecutionStatus + from core.app.entities.task_entities import ( NodeFinishStreamResponse, NodeRetryStreamResponse, NodeStartStreamResponse, StreamEvent, ) -from graphon.enums import WorkflowNodeExecutionStatus class TestTaskEntities: diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index 279e3159468..a78c1b428fb 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -1,16 +1,17 @@ from collections.abc import Sequence from unittest.mock import Mock +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.command_channels import CommandChannel +from graphon.graph_events import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent +from graphon.node_events import NodeRunResult +from graphon.runtime import ReadOnlyGraphRuntimeState +from graphon.variables import StringVariable +from graphon.variables.segments import Segment, StringSegment + from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer from core.workflow.system_variables import SystemVariableKey from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.graph_engine.protocols.command_channel import CommandChannel -from graphon.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent -from graphon.node_events import NodeRunResult -from graphon.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from graphon.variables import StringVariable -from graphon.variables.segments import Segment, StringSegment from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 92a7788f6ed..035e64325bb 100644 --- 
a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -4,6 +4,17 @@ from time import time from unittest.mock import Mock import pytest +from graphon.entities.pause_reason import SchedulingPause +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events import ( + GraphRunFailedEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, +) +from graphon.runtime import ReadOnlyVariablePool +from graphon.variables.segments import Segment from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity @@ -14,17 +25,6 @@ from core.app.layers.pause_state_persist_layer import ( _WorkflowGenerateEntityWrapper, ) from core.workflow.system_variables import SystemVariableKey -from graphon.entities.pause_reason import SchedulingPause -from graphon.graph_engine.entities.commands import GraphEngineCommand -from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from graphon.graph_events.graph import ( - GraphRunFailedEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) -from graphon.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool -from graphon.variables.segments import Segment from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py index 56705f1a7ee..95931f4f8ba 100644 --- a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py @@ -1,5 +1,6 @@ +from graphon.graph_events import GraphRunPausedEvent + from core.app.layers.suspend_layer import SuspendLayer -from graphon.graph_events.graph import GraphRunPausedEvent class TestSuspendLayer: diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py index 1ac9a4d8c0c..7cf6eb4f310 100644 --- a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py @@ -1,7 +1,8 @@ from unittest.mock import Mock, patch -from core.app.layers.timeslice_layer import TimeSliceLayer from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand + +from core.app.layers.timeslice_layer import TimeSliceLayer from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import SchedulerCommand diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py index ecc431936c6..aa9285789b3 100644 --- a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -2,10 +2,11 @@ from datetime import UTC, datetime, timedelta from types import SimpleNamespace from unittest.mock import Mock, patch +from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent +from graphon.runtime import VariablePool + from core.app.layers.trigger_post_layer import TriggerPostLayer from core.workflow.system_variables import build_system_variables -from graphon.graph_events.graph 
import GraphRunFailedEvent, GraphRunSucceededEvent -from graphon.runtime import VariablePool from models.enums import WorkflowTriggerStatus diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py index c246f7b7836..58aa7d74782 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.entities.queue_entities import QueueErrorEvent from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.errors.error import QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 1c1bf391d3e..4aaa10a81af 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -2,6 +2,8 @@ from types import SimpleNamespace from unittest.mock import ANY, Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity @@ -26,8 +28,6 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher from core.ops.ops_trace_manager import TraceQueueManager -from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py index ea000f3886a..f7e7b7e20ef 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -5,6 +5,9 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from graphon.file import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent from core.app.app_config.entities import ( AppAdditionalFeatures, @@ -38,9 +41,6 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AudioTrunk -from graphon.file.enums import FileTransferMethod -from graphon.model_runtime.entities.llm_entities import 
LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index abfbcdb9411..31b73130666 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -17,11 +17,11 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.file import FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline -from graphon.file.enums import FileTransferMethod, FileType from models.model import MessageFile, UploadFile diff --git a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py index 21c761c579b..29df7eea863 100644 --- a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py +++ b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py @@ -1,9 +1,10 @@ from types import SimpleNamespace from unittest.mock import patch +from graphon.model_runtime.entities.model_entities import ModelPropertyKey + from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager from core.app.app_config.entities import ModelConfigEntity -from graphon.model_runtime.entities.model_entities import ModelPropertyKey from models.provider_ids import ModelProviderID diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py index 5c50cb78dae..dc2d82ccd6c 100644 --- a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py +++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py @@ -2,14 +2,14 @@ from datetime import UTC, datetime from unittest.mock import Mock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType +from graphon.node_events import NodeRunResult from core.app.workflow.layers.persistence import ( PersistenceWorkflowInfo, WorkflowPersistenceLayer, _NodeRuntimeSnapshot, ) -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType -from graphon.node_events import NodeRunResult def _build_layer() -> WorkflowPersistenceLayer: diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py index cddd03f4b01..7be9d6ac1eb 100644 --- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -8,13 +8,13 @@ from unittest.mock import MagicMock, patch from urllib.parse import parse_qs, urlparse import pytest +from graphon.file import File, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope from core.app.workflow import file_runtime from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime from core.workflow.file_reference import build_file_reference -from graphon.file import File, 
FileTransferMethod, FileType from models import ToolFile, UploadFile diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py index c4bfb232729..8497261d453 100644 --- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -1,10 +1,10 @@ from types import SimpleNamespace import pytest +from graphon.enums import BuiltinNodeTypes from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.workflow.node_factory import DifyNodeFactory -from graphon.enums import BuiltinNodeTypes class DummyNode: diff --git a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py index 82552470a95..a47d3db6f5b 100644 --- a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py +++ b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py @@ -2,9 +2,10 @@ from __future__ import annotations from types import SimpleNamespace -from core.app.workflow.layers.observability import ObservabilityLayer from graphon.enums import BuiltinNodeTypes +from core.app.workflow.layers.observability import ObservabilityLayer + class TestObservabilityLayerExtras: def test_init_tracer_enabled_sets_tracer(self, monkeypatch): diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py index 9863f34aba2..d8a68f6d000 100644 --- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -4,27 +4,21 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest - -from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity -from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.workflow.system_variables import SystemVariableKey, build_system_variables +from graphon.entities import WorkflowNodeExecution from graphon.entities.pause_reason import SchedulingPause -from graphon.entities.workflow_node_execution import WorkflowNodeExecution from graphon.enums import ( BuiltinNodeTypes, WorkflowExecutionStatus, WorkflowNodeExecutionStatus, WorkflowType, ) -from graphon.graph_events.graph import ( +from graphon.graph_events import ( GraphRunAbortedEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, -) -from graphon.graph_events.node import ( NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunPauseRequestedEvent, @@ -35,6 +29,10 @@ from graphon.graph_events.node import ( from graphon.node_events import NodeRunResult from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity +from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.workflow.system_variables import SystemVariableKey, build_system_variables + class _RepoRecorder: def __init__(self) -> None: diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py index 7b433ab57b9..5ff9774b525 100644 --- a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py +++ 
b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -301,7 +301,6 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() - from core.app.entities.queue_entities import QueueAgentMessageEvent from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -309,6 +308,8 @@ class TestAppGeneratorTTSPublisher: TextPromptMessageContent, ) + from core.app.entities.queue_entities import QueueAgentMessageEvent + chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( @@ -336,10 +337,11 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() - from core.app.entities.queue_entities import QueueAgentMessageEvent from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from graphon.model_runtime.entities.message_entities import AssistantPromptMessage + from core.app.entities.queue_entities import QueueAgentMessageEvent + chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index af992e4e9ff..b0c72ee42f5 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -2,16 +2,15 @@ import types from collections.abc import Generator import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError from core.workflow.file_reference import parse_file_reference -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.file import File -from graphon.file.enums import FileTransferMethod, FileType -from graphon.node_events import StreamChunkEvent, StreamCompletedEvent def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py index 0b91d59953a..fbaf6d497d7 100644 --- a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py @@ -1,11 +1,10 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.file import File, FileTransferMethod, FileType from core.datasource.entities.datasource_entities import DatasourceMessage from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer -from graphon.file import File -from graphon.file.enums import FileTransferMethod, FileType from models.tools import ToolFile diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py index ef8f360dbfe..ff9fd0d8f3f 100644 --- a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py +++ 
b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py @@ -1,11 +1,12 @@ +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType + from core.entities.execution_extra_content import ( ExecutionExtraContentDomainModel, HumanInputContent, HumanInputFormDefinition, HumanInputFormSubmissionData, ) -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from models.execution_extra_content import ExecutionContentType diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py index a0b28201578..2acd278a31e 100644 --- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -8,6 +8,9 @@ drive provider mapping behavior. """ import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.entities.model_entities import ( DefaultModelEntity, @@ -16,9 +19,6 @@ from core.entities.model_entities import ( ProviderModelWithStatusEntity, SimpleModelProviderEntity, ) -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index fe2c226843b..8cf0409c4c2 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -6,6 +6,17 @@ from typing import Any from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus @@ -24,17 +35,6 @@ from core.entities.provider_entities import ( SystemConfiguration, SystemConfigurationStatus, ) -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FieldModelSchema, - FormType, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderEntity, -) from models.enums import CredentialSourceType from models.provider import ProviderType from models.provider_ids import ModelProviderID diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py index a159d3ad4d0..8685d162831 100644 --- 
a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -1,4 +1,5 @@ import pytest +from graphon.model_runtime.entities.model_entities import ModelType from core.entities.parameter_entities import AppSelectorScope from core.entities.provider_entities import ( @@ -8,7 +9,6 @@ from core.entities.provider_entities import ( ProviderQuotaType, ) from core.tools.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType def test_provider_quota_type_value_of_returns_enum_member() -> None: diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py index 6ed9ddb476c..b45f6fd9a77 100644 --- a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py @@ -2,20 +2,6 @@ import json from unittest.mock import MagicMock, patch import pytest - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import ( - ResponseFormat, - _handle_native_json_schema, - _handle_prompt_based_schema, - _parse_structured_output, - _prepare_schema_for_model, - _set_response_format, - convert_boolean_to_string, - invoke_llm_with_structured_output, - remove_additional_properties, -) -from core.model_manager import ModelInstance from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, @@ -31,6 +17,20 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import ( + ResponseFormat, + _handle_native_json_schema, + _handle_prompt_based_schema, + _parse_structured_output, + _prepare_schema_for_model, + _set_response_format, + convert_boolean_to_string, + invoke_llm_with_structured_output, + remove_additional_properties, +) +from core.model_manager import ModelInstance + class TestStructuredOutput: def test_remove_additional_properties(self): diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index b3a58858149..2c0a4411254 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -2,12 +2,12 @@ import json from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError class TestLLMGenerator: diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index bfb1fde5027..313d18c695d 100644 --- 
a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -3,6 +3,7 @@ from unittest.mock import Mock, patch import jsonschema import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types @@ -18,7 +19,6 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py index f459250b8ef..9a5fb319d7b 100644 --- a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -4,8 +4,6 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest - -from core.memory.token_buffer_memory import TokenBufferMemory from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -13,6 +11,8 @@ from graphon.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from core.memory.token_buffer_memory import TokenBufferMemory from models.model import AppMode # --------------------------------------------------------------------------- diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py index 249ecb50065..6a672fdfd57 100644 --- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -1,7 +1,6 @@ from unittest.mock import Mock import pytest - from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ( diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py index c2324fdec45..62d631a7541 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py @@ -5,6 +5,8 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module @@ -34,8 +36,6 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey class RecordingTraceClient: diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py index e4d8f2d5ea0..2d2be12f051 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py @@ -1,6 +1,8 @@ import json from unittest.mock import MagicMock +from graphon.entities import 
WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, StatusCode from core.ops.aliyun_trace.entities.semconv import ( @@ -24,8 +26,6 @@ from core.ops.aliyun_trace.utils import ( serialize_json_data, ) from core.rag.models.document import Document -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py index 8ebf4419211..97f7a163272 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( @@ -25,7 +26,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace -from graphon.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py index 34c64c54a1f..bfe916f0182 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -3,6 +3,7 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( @@ -21,7 +22,6 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py index afc5726ede2..f4c485a9fc5 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py @@ -9,6 +9,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( @@ -21,7 +22,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from graphon.enums import BuiltinNodeTypes # ── Helpers ────────────────────────────────────────────────────── diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py index c02ac413f27..1cb32f2ee02 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ 
b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( @@ -18,7 +19,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser from models.enums import MessageStatus diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py index 6113e5c6c8a..696f859b6f7 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py @@ -1,6 +1,8 @@ from datetime import datetime from unittest.mock import MagicMock, patch +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from opentelemetry.trace import StatusCode from core.ops.entities.trace_entity import ( @@ -25,8 +27,6 @@ from core.ops.tencent_trace.entities.semconv import ( ) from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.rag.models.document import Document -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class TestTencentSpanBuilder: diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py index 265652381ca..382e5dadc3b 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py @@ -2,6 +2,8 @@ import logging from unittest.mock import MagicMock, patch import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( @@ -14,8 +16,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.tencent_trace.tencent_trace import TencentDataTrace -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes from models import Account, App, TenantAccountJoin logger = logging.getLogger(__name__) diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py index 4b925390d91..6b5cb5b09a8 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py @@ -1,7 +1,7 @@ +from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes from openinference.semconv.trace import OpenInferenceSpanKindValues from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind -from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes class TestGetNodeSpanKind: diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py index 531c7de05f9..5014f40afca 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ 
b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from weave.trace_server.trace_server_interface import TraceStatus from core.ops.entities.config_entity import WeaveConfig @@ -22,7 +23,6 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.ops.weave_trace.weave_trace import WeaveDataTrace -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────── diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py index c24d3ac0120..543b278715d 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py @@ -1,9 +1,10 @@ from types import SimpleNamespace from unittest.mock import patch +from graphon.model_runtime.entities.message_entities import UserPromptMessage + from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation from core.plugin.entities.request import RequestInvokeSummary -from graphon.model_runtime.entities.message_entities import UserPromptMessage def test_system_model_helpers_forward_user_id(): diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index 68aa1305181..f8d0e127b1b 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -6,15 +6,15 @@ from types import SimpleNamespace from unittest.mock import Mock, sentinel import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl import model_runtime as model_runtime_module from core.plugin.impl.model import PluginModelClient from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_schema() -> AIModelEntity: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_entities.py b/api/tests/unit_tests/core/plugin/test_plugin_entities.py index f1c4c7e7009..a812b01c5bd 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_entities.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_entities.py @@ -4,6 +4,12 @@ from enum import StrEnum import pytest from flask import Response +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) from pydantic import 
ValidationError from core.plugin.entities.endpoint import EndpointEntityWithInstance @@ -25,12 +31,6 @@ from core.plugin.entities.request import ( ) from core.plugin.utils.http_parser import serialize_response from core.tools.entities.common_entities import I18nObject -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) class TestEndpointEntity: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index af86f917b12..3063ca01970 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -17,6 +17,14 @@ from unittest.mock import MagicMock, patch import httpx import pytest +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from core.plugin.entities.plugin_daemon import ( @@ -37,14 +45,6 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError class TestPluginRuntimeExecution: diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index 4d4313dd845..90730dff5a4 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -1,13 +1,12 @@ from collections.abc import Generator import pytest +from graphon.file import File, FileTransferMethod, FileType from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File class TestChunkMerger: diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 395d3921271..2b280dd6746 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,13 +2,6 @@ from typing import cast from unittest.mock import MagicMock, patch import pytest - -from configs import dify_config -from core.app.app_config.entities import ModelConfigEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -18,6 +11,13 @@ from 
graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from configs import dify_config +from core.app.app_config.entities import ModelConfigEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 803afa54d70..4a54649b289 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -1,11 +1,5 @@ from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import ( - ModelConfigWithCredentialsEntity, -) -from core.entities.provider_configuration import ProviderModelBundle -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, @@ -13,6 +7,13 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + +from core.app.entities.app_invoke_entities import ( + ModelConfigWithCredentialsEntity, +) +from core.entities.provider_configuration import ProviderModelBundle +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index 5d865d934cf..a4b3960b0a2 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,5 +1,3 @@ -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -9,6 +7,9 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil + def test_build_prompt_message_with_prompt_message_contents(): prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")]) diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 9f9ea33695f..e35ce2c48a9 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -2,9 +2,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.prompt.prompt_transform import PromptTransform -from graphon.model_runtime.entities.model_entities import ModelPropertyKey # from core.app.app_config.entities import ModelConfigEntity # from 
core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 0dc74b33dfb..3f188cfbb4b 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -2,6 +2,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -18,12 +24,6 @@ from core.prompt.prompt_templates.advanced_prompt_templates import ( CONTEXT, ) from core.prompt.simple_prompt_transform import SimplePromptTransform -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - TextPromptMessageContent, - UserPromptMessage, -) from models.model import AppMode, Conversation diff --git a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py index 1f3247590c4..006b4e7345e 100644 --- a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py +++ b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py @@ -1,12 +1,13 @@ from unittest.mock import MagicMock, patch +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError + from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError def _doc(content: str) -> Document: diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index bfa78fe5658..6fd44be4d4b 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -12,11 +12,11 @@ from unittest.mock import Mock, patch import numpy as np import pytest +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from sqlalchemy.exc import IntegrityError from core.rag.embedding.cached_embedding import CacheEmbedding -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 392f0b458b5..d7ba944e586 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -49,10 +49,6 @@ from 
unittest.mock import Mock, patch import numpy as np import pytest -from sqlalchemy.exc import IntegrityError - -from core.entities.embedding_type import EmbeddingInputType -from core.rag.embedding.cached_embedding import CacheEmbedding from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from graphon.model_runtime.errors.invoke import ( @@ -60,6 +56,10 @@ from graphon.model_runtime.errors.invoke import ( InvokeConnectionError, InvokeRateLimitError, ) +from sqlalchemy.exc import IntegrityError + +from core.entities.embedding_type import EmbeddingInputType +from core.rag.embedding.cached_embedding import CacheEmbedding from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index c861871f020..cc2873dd3f7 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -2,14 +2,14 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelFeature from core.entities.knowledge_entities import PreviewDetail from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.models.document import AttachmentDocument, Document -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent -from graphon.model_runtime.entities.model_entities import ModelFeature class TestParagraphIndexProcessor: diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index 059876d410d..450e7166360 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -53,6 +53,7 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm.exc import ObjectDeletedError from core.errors.error import ProviderTokenNotInitError @@ -63,7 +64,6 @@ from core.indexing_runner import ( ) from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document -from graphon.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 415597f336a..2ec7f0498e8 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -17,6 +17,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch 
import pytest +from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_manager import ModelInstance from core.rag.index_processor.constant.doc_type import DocType @@ -28,7 +29,6 @@ from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner -from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult def create_mock_model_instance() -> ModelInstance: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index a7e62e7b0a8..c11426163e3 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -6,6 +6,8 @@ from uuid import uuid4 import pytest from flask import Flask, current_app +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelFeature from sqlalchemy import column from core.app.app_config.entities import ( @@ -35,8 +37,6 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.entities.model_entities import ModelFeature from models.dataset import Dataset from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py index 43c521dcfd4..5a2ecb82204 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py @@ -1,8 +1,9 @@ from unittest.mock import Mock -from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter from graphon.model_runtime.entities.llm_entities import LLMUsage +from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter + class TestFunctionCallMultiDatasetRouter: def test_invoke_returns_none_when_no_tools(self) -> None: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py index c56528cf55e..539ac0f849f 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -1,12 +1,13 @@ from types import SimpleNamespace from unittest.mock import Mock, patch -from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish -from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.message_entities import PromptMessageRole from graphon.model_runtime.entities.model_entities import ModelType +from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish +from core.rag.retrieval.router.multi_dataset_react_route import 
ReactMultiDatasetRouter + class TestReactMultiDatasetRouter: def test_invoke_returns_none_when_no_tools(self) -> None: diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index 2735ec512f5..e229d5fc1a5 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -9,9 +9,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowType from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository -from graphon.entities.workflow_execution import WorkflowExecution, WorkflowType from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 05b4f3a053c..7dbf78d0f0c 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -9,14 +9,14 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest - -from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from core.repositories.factory import OrderConfig from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) from graphon.enums import BuiltinNodeTypes + +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 8be1ac318c7..0fc82dda530 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -7,6 +7,11 @@ from datetime import datetime from types import SimpleNamespace import pytest +from graphon.nodes.human_input.entities import ( + FormDefinition, + UserAction, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( HumanInputFormRecord, @@ -21,11 +26,6 @@ from core.workflow.human_input_compat import ( ExternalRecipient, MemberRecipient, ) -from graphon.nodes.human_input.entities import ( - FormDefinition, - UserAction, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py index 1297a95df14..8ff0e405874 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py +++ 
b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -9,6 +9,8 @@ from typing import Any from unittest.mock import MagicMock import pytest +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( FormCreateParams, @@ -29,8 +31,6 @@ from core.workflow.human_input_compat import ( MemberRecipient, WebAppDeliveryMethod, ) -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import HumanInputFormRecipient, RecipientType diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py index 6cb3c3c6acd..e5c3e854875 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -3,11 +3,12 @@ from unittest.mock import MagicMock from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from graphon.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType from models import Account, CreatorUserRole, EndUser, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py index 6af7b02d4cc..5b4d26b7808 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -10,6 +10,12 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) from sqlalchemy import Engine, create_engine from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker @@ -23,12 +29,6 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import ( _find_first, _replace_or_append_offload, ) -from graphon.entities import WorkflowNodeExecution -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) from models import Account, EndUser from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index abdbc72085f..84fe522388e 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ 
b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -4,17 +4,17 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index 5af1376a0a9..27729e7f06b 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -11,17 +11,17 @@ from datetime import UTC, datetime from typing import Any from unittest.mock import MagicMock +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from sqlalchemy import Engine from configs import dify_config from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionTriggeredFrom from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index f17927f16b2..ac65d0c02bc 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,6 +1,7 @@ import json from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig + from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index afea9144c06..f5efb78b614 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock, patch import pytest import redis +from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.model_manager import LBModelManager from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.model_entities import ModelType @pytest.fixture diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index b19a21d7f44..331166fe63c 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -1,6 +1,15 @@ from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from 
graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormOption, + FormType, + ProviderEntity, +) from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus from core.entities.provider_entities import ( @@ -12,15 +21,6 @@ from core.entities.provider_entities import ( RestrictModel, SystemConfiguration, ) -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormOption, - FormType, - ProviderEntity, -) from models.provider import Provider, ProviderType diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 7f6a50af996..259cb5fdd07 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -2,12 +2,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings from core.provider_manager import ProviderManager -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType from models.provider import LoadBalancingModelConfig, ProviderModelSetting, TenantDefaultModel from models.provider_ids import ModelProviderID diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py index 1ff81f61200..5d744f88c9b 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -6,13 +6,13 @@ from typing import Any from unittest.mock import patch import pytest +from graphon.model_runtime.entities.message_entities import UserPromptMessage from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType -from graphon.model_runtime.entities.message_entities import UserPromptMessage class _BuiltinDummyTool(BuiltinTool): diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py index 9ac280e31ae..ee0ce51eec9 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -6,6 +6,8 @@ from datetime import date from types import SimpleNamespace import pytest +from graphon.file import FileType +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -27,8 +29,6 @@ from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from 
core.tools.errors import ToolInvokeError -from graphon.file.enums import FileType -from graphon.model_runtime.entities.model_entities import ModelPropertyKey def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool: diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index b3442636b75..7fcebde3c55 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -12,9 +12,9 @@ from unittest.mock import MagicMock, Mock, patch import httpx import pytest +from graphon.file import FileTransferMethod from core.tools.tool_file_manager import ToolFileManager -from graphon.file import FileTransferMethod def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py index a4a563a4a1b..52f262e1cf1 100644 --- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -13,8 +13,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest - -from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, @@ -24,6 +22,8 @@ from graphon.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) +from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils + def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace: model_type_instance = Mock() diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py index 43f3fbd5c99..0e3a7e623a8 100644 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -1,9 +1,9 @@ import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils -from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_ensure_no_human_input_nodes_passes_for_non_human_input(): diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py index b147d7fcdb8..2607861b59d 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -4,6 +4,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( @@ -13,7 +14,6 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from graphon.variables.input_entities import 
VariableEntity, VariableEntityType def _controller() -> WorkflowToolProviderController: diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 72a73dd9368..c20edd74004 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -11,6 +11,7 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -24,7 +25,6 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool -from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType class StubScalars: diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py index ee7a3d9c96b..78622b78b6b 100644 --- a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py @@ -11,6 +11,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes, NodeType from core.plugin.entities.request import TriggerInvokeEventResponse from core.trigger.constants import ( @@ -26,7 +27,6 @@ from core.trigger.debug.event_selectors import ( select_trigger_debug_events, ) from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent -from graphon.enums import BuiltinNodeTypes, NodeType from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 72052c8c058..7406b88270b 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -2,11 +2,6 @@ import dataclasses import orjson import pytest -from pydantic import BaseModel - -from core.helper import encrypter -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool from graphon.file import File, FileTransferMethod, FileType from graphon.runtime import VariablePool from graphon.variables.segment_group import SegmentGroup @@ -47,6 +42,11 @@ from graphon.variables.variables import ( StringVariable, Variable, ) +from pydantic import BaseModel + +from core.helper import encrypter +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool def _build_variable_pool( diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index d4e862220ae..37ecd2890bb 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,5 +1,4 @@ import pytest - from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import StringSegment from graphon.variables.types import ArrayValidation, SegmentType diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py 
index 14f9b2991d7..09254e17a30 100644
--- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py
+++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py
@@ -9,9 +9,7 @@ from dataclasses import dataclass
 from typing import Any
 
 import pytest
-
-from graphon.file.enums import FileTransferMethod, FileType
-from graphon.file.models import File
+from graphon.file import File, FileTransferMethod, FileType
 from graphon.variables.segment_group import SegmentGroup
 from graphon.variables.segments import (
     ArrayFileSegment,
diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py
index dae5e1ce984..75b01bf42e9 100644
--- a/api/tests/unit_tests/core/variables/test_variables.py
+++ b/api/tests/unit_tests/core/variables/test_variables.py
@@ -1,6 +1,4 @@
 import pytest
-from pydantic import ValidationError
-
 from graphon.variables import (
     ArrayFileVariable,
     ArrayVariable,
@@ -12,6 +10,7 @@ from graphon.variables import (
     StringVariable,
 )
 from graphon.variables.variables import VariableBase
+from pydantic import ValidationError
 
 
 def test_frozen_variables():
diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py
deleted file mode 100644
index ef5500b72fa..00000000000
--- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py
+++ /dev/null
@@ -1,307 +0,0 @@
-import json
-from time import time
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
-from graphon.variables.variables import StringVariable
-
-
-class StubCoordinator:
-    def __init__(self) -> None:
-        self.state = "initial"
-
-    def dumps(self) -> str:
-        return json.dumps({"state": self.state})
-
-    def loads(self, data: str) -> None:
-        payload = json.loads(data)
-        self.state = payload["state"]
-
-
-class TestGraphRuntimeState:
-    def test_execution_context_defaults_to_empty_context(self):
-        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
-
-        with state.execution_context:
-            assert state.execution_context is not None
-
-        state.execution_context = None
-
-        with state.execution_context:
-            assert state.execution_context is not None
-
-    def test_property_getters_and_setters(self):
-        # FIXME(-LAN-): Mock VariablePool if needed
-        variable_pool = VariablePool()
-        start_time = time()
-
-        state = GraphRuntimeState(variable_pool=variable_pool, start_at=start_time)
-
-        # Test variable_pool property (read-only)
-        assert state.variable_pool == variable_pool
-
-        # Test start_at property
-        assert state.start_at == start_time
-        new_time = time() + 100
-        state.start_at = new_time
-        assert state.start_at == new_time
-
-        # Test total_tokens property
-        assert state.total_tokens == 0
-        state.total_tokens = 100
-        assert state.total_tokens == 100
-
-        # Test node_run_steps property
-        assert state.node_run_steps == 0
-        state.node_run_steps = 5
-        assert state.node_run_steps == 5
-
-    def test_outputs_immutability(self):
-        variable_pool = VariablePool()
-        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
-
-        # Test that getting outputs returns a copy
-        outputs1 = state.outputs
-        outputs2 = state.outputs
-        assert outputs1 == outputs2
-        assert outputs1 is not outputs2  # Different objects
-
-        # Test that modifying retrieved outputs doesn't affect internal state
-        outputs = state.outputs
-        outputs["test"] = "value"
-        assert "test" not in state.outputs
-
-        # Test set_output method
-        state.set_output("key1", "value1")
-        assert state.get_output("key1") == "value1"
-
-        # Test update_outputs method
-        state.update_outputs({"key2": "value2", "key3": "value3"})
-        assert state.get_output("key2") == "value2"
-        assert state.get_output("key3") == "value3"
-
-    def test_llm_usage_immutability(self):
-        variable_pool = VariablePool()
-        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
-
-        # Test that getting llm_usage returns a copy
-        usage1 = state.llm_usage
-        usage2 = state.llm_usage
-        assert usage1 is not usage2  # Different objects
-
-    def test_type_validation(self):
-        variable_pool = VariablePool()
-        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
-
-        # Test total_tokens validation
-        with pytest.raises(ValueError):
-            state.total_tokens = -1
-
-        # Test node_run_steps validation
-        with pytest.raises(ValueError):
-            state.node_run_steps = -1
-
-    def test_helper_methods(self):
-        variable_pool = VariablePool()
-        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
-
-        # Test increment_node_run_steps
-        initial_steps = state.node_run_steps
-        state.increment_node_run_steps()
-        assert state.node_run_steps == initial_steps + 1
-
-        # Test add_tokens
-        initial_tokens = state.total_tokens
-        state.add_tokens(50)
-        assert state.total_tokens == initial_tokens + 50
-
-        # Test add_tokens validation
-        with pytest.raises(ValueError):
-            state.add_tokens(-1)
-
-    def test_ready_queue_default_instantiation(self):
-        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
-
-        queue = state.ready_queue
-
-        from graphon.graph_engine.ready_queue import InMemoryReadyQueue
-
-        assert isinstance(queue, InMemoryReadyQueue)
-
-    def test_graph_execution_lazy_instantiation(self):
-        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
-
-        execution = state.graph_execution
-
-        from graphon.graph_engine.domain.graph_execution import GraphExecution
-
-        assert isinstance(execution, GraphExecution)
-        assert execution.workflow_id == ""
-        assert state.graph_execution is execution
-
-    def test_response_coordinator_configuration(self):
-        variable_pool = VariablePool()
-        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
-
-        with pytest.raises(ValueError):
-            _ = state.response_coordinator
-
-        mock_graph = MagicMock()
-        with patch(
-            "graphon.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True
-        ) as coordinator_cls:
-            coordinator_instance = coordinator_cls.return_value
-            state.configure(graph=mock_graph)
-
-            assert state.response_coordinator is coordinator_instance
-            coordinator_cls.assert_called_once_with(variable_pool=variable_pool, graph=mock_graph)
-
-            # Configure again with same graph should be idempotent
-            state.configure(graph=mock_graph)
-
-        other_graph = MagicMock()
-        with pytest.raises(ValueError):
-            state.attach_graph(other_graph)
-
-    def test_read_only_wrapper_exposes_additional_state(self):
-        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
-        state.configure()
-
-        wrapper = ReadOnlyGraphRuntimeStateWrapper(state)
-
-        assert wrapper.ready_queue_size == 0
-        assert wrapper.exceptions_count == 0
-
-    def test_read_only_wrapper_serializes_runtime_state(self):
-        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
-        state.total_tokens = 5
-        state.set_output("result", {"success": True})
-        state.ready_queue.put("node-1")
-
-        wrapper = ReadOnlyGraphRuntimeStateWrapper(state)
-
-        wrapper_snapshot = json.loads(wrapper.dumps())
-        state_snapshot = json.loads(state.dumps())
-
-        assert wrapper_snapshot == state_snapshot
-
-    def test_dumps_and_loads_roundtrip_with_response_coordinator(self):
-        variable_pool = VariablePool()
-        variable_pool.add(("node1", "value"), "payload")
-
-        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
-        state.total_tokens = 10
-        state.node_run_steps = 3
-        state.set_output("final", {"result": True})
-        usage = LLMUsage.from_metadata(
-            {
-                "prompt_tokens": 2,
-                "completion_tokens": 3,
-                "total_tokens": 5,
-                "total_price": "1.23",
-                "currency": "USD",
-                "latency": 0.5,
-            }
-        )
-        state.llm_usage = usage
-        state.ready_queue.put("node-A")
-
-        graph_execution = state.graph_execution
-        graph_execution.workflow_id = "wf-123"
-        graph_execution.exceptions_count = 4
-        graph_execution.started = True
-
-        mock_graph = MagicMock()
-        stub = StubCoordinator()
-        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub, autospec=True):
-            state.attach_graph(mock_graph)
-
-        stub.state = "configured"
-
-        snapshot = state.dumps()
-
-        restored = GraphRuntimeState.from_snapshot(snapshot)
-
-        assert restored.total_tokens == 10
-        assert restored.node_run_steps == 3
-        assert restored.get_output("final") == {"result": True}
-        assert restored.llm_usage.total_tokens == usage.total_tokens
-        assert restored.ready_queue.qsize() == 1
-        assert restored.ready_queue.get(timeout=0.01) == "node-A"
-
-        restored_segment = restored.variable_pool.get(("node1", "value"))
-        assert restored_segment is not None
-        assert restored_segment.value == "payload"
-
-        restored_execution = restored.graph_execution
-        assert restored_execution.workflow_id == "wf-123"
-        assert restored_execution.exceptions_count == 4
-        assert restored_execution.started is True
-
-        new_stub = StubCoordinator()
-        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True):
-            restored.attach_graph(mock_graph)
-
-        assert new_stub.state == "configured"
-
-    def test_loads_rehydrates_existing_instance(self):
-        variable_pool = VariablePool()
-        variable_pool.add(("node", "key"), "value")
-
-        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
-        state.total_tokens = 7
-        state.node_run_steps = 2
-        state.set_output("foo", "bar")
-        state.ready_queue.put("node-1")
-
-        execution = state.graph_execution
-        execution.workflow_id = "wf-456"
-        execution.started = True
-
-        mock_graph = MagicMock()
-        original_stub = StubCoordinator()
-        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub, autospec=True):
-            state.attach_graph(mock_graph)
-
-        original_stub.state = "configured"
-        snapshot = state.dumps()
-
-        new_stub = StubCoordinator()
-        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True):
-            restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
-            restored.attach_graph(mock_graph)
-            restored.loads(snapshot)
-
-        assert restored.total_tokens == 7
-        assert restored.node_run_steps == 2
-        assert restored.get_output("foo") == "bar"
-        assert restored.ready_queue.qsize() == 1
-        assert restored.ready_queue.get(timeout=0.01) == "node-1"
-
-        restored_segment = restored.variable_pool.get(("node", "key"))
-        assert restored_segment is not None
-        assert restored_segment.value == "value"
-
-        restored_execution = restored.graph_execution
-        assert restored_execution.workflow_id == "wf-456"
-        assert restored_execution.started is True
-
-        assert new_stub.state == "configured"
-
-    def test_snapshot_restore_preserves_updated_conversation_variable(self):
-        variable_pool = VariablePool(
-            conversation_variables=[StringVariable(name="session_name", value="before")],
-        )
-        variable_pool.add((CONVERSATION_VARIABLE_NODE_ID, "session_name"), "after")
-
-        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
-        snapshot = state.dumps()
-        restored = GraphRuntimeState.from_snapshot(snapshot)
-
-        restored_value = restored.variable_pool.get((CONVERSATION_VARIABLE_NODE_ID, "session_name"))
-        assert restored_value is not None
-        assert restored_value.value == "after"
diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py
deleted file mode 100644
index 856ec959b73..00000000000
--- a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py
+++ /dev/null
@@ -1,88 +0,0 @@
-"""
-Tests for PauseReason discriminated union serialization/deserialization.
-"""
-
-import pytest
-from pydantic import BaseModel, ValidationError
-
-from graphon.entities.pause_reason import (
-    HumanInputRequired,
-    PauseReason,
-    SchedulingPause,
-)
-
-
-class _Holder(BaseModel):
-    """Helper model that embeds PauseReason for union tests."""
-
-    reason: PauseReason
-
-
-class TestPauseReasonDiscriminator:
-    """Test suite for PauseReason union discriminator."""
-
-    @pytest.mark.parametrize(
-        ("dict_value", "expected"),
-        [
-            pytest.param(
-                {
-                    "reason": {
-                        "TYPE": "human_input_required",
-                        "form_id": "form_id",
-                        "form_content": "form_content",
-                        "node_id": "node_id",
-                        "node_title": "node_title",
-                    },
-                },
-                HumanInputRequired(
-                    form_id="form_id",
-                    form_content="form_content",
-                    node_id="node_id",
-                    node_title="node_title",
-                ),
-                id="HumanInputRequired",
-            ),
-            pytest.param(
-                {
-                    "reason": {
-                        "TYPE": "scheduled_pause",
-                        "message": "Hold on",
-                    }
-                },
-                SchedulingPause(message="Hold on"),
-                id="SchedulingPause",
-            ),
-        ],
-    )
-    def test_model_validate(self, dict_value, expected):
-        """Ensure scheduled pause payloads with lowercase TYPE deserialize."""
-        holder = _Holder.model_validate(dict_value)
-
-        assert type(holder.reason) == type(expected)
-        assert holder.reason == expected
-
-    @pytest.mark.parametrize(
-        "reason",
-        [
-            HumanInputRequired(
-                form_id="form_id",
-                form_content="form_content",
-                node_id="node_id",
-                node_title="node_title",
-            ),
-            SchedulingPause(message="Hold on"),
-        ],
-        ids=lambda x: type(x).__name__,
-    )
-    def test_model_construct(self, reason):
-        holder = _Holder(reason=reason)
-        assert holder.reason == reason
-
-    def test_model_construct_with_invalid_type(self):
-        with pytest.raises(ValidationError):
-            holder = _Holder(reason=object())  # type: ignore
-
-    def test_unknown_type_fails_validation(self):
-        """Unknown TYPE values should raise a validation error."""
-        with pytest.raises(ValidationError):
-            _Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}})
diff --git a/api/tests/unit_tests/core/workflow/entities/test_template.py b/api/tests/unit_tests/core/workflow/entities/test_template.py
deleted file mode 100644
index e8304b9bcdd..00000000000
--- a/api/tests/unit_tests/core/workflow/entities/test_template.py
+++ /dev/null
@@ -1,87 +0,0 @@
-"""Tests for template module."""
-
-from graphon.nodes.base.template import Template, TextSegment, VariableSegment
-
-
-class TestTemplate:
-    """Test Template class functionality."""
-
-    def test_from_answer_template_simple(self):
-        """Test parsing a simple answer template."""
-        template_str = "Hello, {{#node1.name#}}!"
-        template = Template.from_answer_template(template_str)
-
-        assert len(template.segments) == 3
-        assert isinstance(template.segments[0], TextSegment)
-        assert template.segments[0].text == "Hello, "
-        assert isinstance(template.segments[1], VariableSegment)
-        assert template.segments[1].selector == ["node1", "name"]
-        assert isinstance(template.segments[2], TextSegment)
-        assert template.segments[2].text == "!"
-
-    def test_from_answer_template_multiple_vars(self):
-        """Test parsing an answer template with multiple variables."""
-        template_str = "Hello {{#node1.name#}}, your age is {{#node2.age#}}."
-        template = Template.from_answer_template(template_str)
-
-        assert len(template.segments) == 5
-        assert isinstance(template.segments[0], TextSegment)
-        assert template.segments[0].text == "Hello "
-        assert isinstance(template.segments[1], VariableSegment)
-        assert template.segments[1].selector == ["node1", "name"]
-        assert isinstance(template.segments[2], TextSegment)
-        assert template.segments[2].text == ", your age is "
-        assert isinstance(template.segments[3], VariableSegment)
-        assert template.segments[3].selector == ["node2", "age"]
-        assert isinstance(template.segments[4], TextSegment)
-        assert template.segments[4].text == "."
-
-    def test_from_answer_template_no_vars(self):
-        """Test parsing an answer template with no variables."""
-        template_str = "Hello, world!"
-        template = Template.from_answer_template(template_str)
-
-        assert len(template.segments) == 1
-        assert isinstance(template.segments[0], TextSegment)
-        assert template.segments[0].text == "Hello, world!"
-
-    def test_from_end_outputs_single(self):
-        """Test creating template from End node outputs with single variable."""
-        outputs_config = [{"variable": "text", "value_selector": ["node1", "text"]}]
-        template = Template.from_end_outputs(outputs_config)
-
-        assert len(template.segments) == 1
-        assert isinstance(template.segments[0], VariableSegment)
-        assert template.segments[0].selector == ["node1", "text"]
-
-    def test_from_end_outputs_multiple(self):
-        """Test creating template from End node outputs with multiple variables."""
-        outputs_config = [
-            {"variable": "text", "value_selector": ["node1", "text"]},
-            {"variable": "result", "value_selector": ["node2", "result"]},
-        ]
-        template = Template.from_end_outputs(outputs_config)
-
-        assert len(template.segments) == 3
-        assert isinstance(template.segments[0], VariableSegment)
-        assert template.segments[0].selector == ["node1", "text"]
-        assert template.segments[0].variable_name == "text"
-        assert isinstance(template.segments[1], TextSegment)
-        assert template.segments[1].text == "\n"
-        assert isinstance(template.segments[2], VariableSegment)
-        assert template.segments[2].selector == ["node2", "result"]
-        assert template.segments[2].variable_name == "result"
-
-    def test_from_end_outputs_empty(self):
-        """Test creating template from empty End node outputs."""
-        outputs_config = []
-        template = Template.from_end_outputs(outputs_config)
-
-        assert len(template.segments) == 0
-
-    def test_template_str_representation(self):
-        """Test string representation of template."""
-        template_str = "Hello, {{#node1.name#}}!"
-        template = Template.from_answer_template(template_str)
-
-        assert str(template) == template_str
diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py
deleted file mode 100644
index 7e087516836..00000000000
--- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py
+++ /dev/null
@@ -1,136 +0,0 @@
-from graphon.runtime import VariablePool
-from graphon.variables.segments import (
-    BooleanSegment,
-    IntegerSegment,
-    NoneSegment,
-    StringSegment,
-)
-
-
-class TestVariablePoolGetAndNestedAttribute:
-    #
-    # _get_nested_attribute tests
-    #
-    def test__get_nested_attribute_existing_key(self):
-        pool = VariablePool.empty()
-        obj = {"a": 123}
-        segment = pool._get_nested_attribute(obj, "a")
-        assert segment is not None
-        assert segment.value == 123
-
-    def test__get_nested_attribute_missing_key(self):
-        pool = VariablePool.empty()
-        obj = {"a": 123}
-        segment = pool._get_nested_attribute(obj, "b")
-        assert segment is None
-
-    def test__get_nested_attribute_non_dict(self):
-        pool = VariablePool.empty()
-        obj = ["not", "a", "dict"]
-        segment = pool._get_nested_attribute(obj, "a")
-        assert segment is None
-
-    def test__get_nested_attribute_with_none_value(self):
-        pool = VariablePool.empty()
-        obj = {"a": None}
-        segment = pool._get_nested_attribute(obj, "a")
-        assert segment is not None
-        assert isinstance(segment, NoneSegment)
-
-    def test__get_nested_attribute_with_empty_string(self):
-        pool = VariablePool.empty()
-        obj = {"a": ""}
-        segment = pool._get_nested_attribute(obj, "a")
-        assert segment is not None
-        assert isinstance(segment, StringSegment)
-        assert segment.value == ""
-
-    #
-    # get tests
-    #
-    def test_get_simple_variable(self):
-        pool = VariablePool.empty()
-        pool.add(("node1", "var1"), "value1")
-        segment = pool.get(("node1", "var1"))
-        assert segment is not None
-        assert segment.value == "value1"
-
-    def test_get_missing_variable(self):
-        pool = VariablePool.empty()
-        result = pool.get(("node1", "unknown"))
-        assert result is None
-
-    def test_get_with_too_short_selector(self):
-        pool = VariablePool.empty()
-        result = pool.get(("only_node",))
-        assert result is None
-
-    def test_get_nested_object_attribute(self):
-        pool = VariablePool.empty()
-        obj_value = {"inner": "hello"}
-        pool.add(("node1", "obj"), obj_value)
-
-        # simulate selector with nested attr
-        segment = pool.get(("node1", "obj", "inner"))
-        assert segment is not None
-        assert segment.value == "hello"
-
-    def test_get_nested_object_missing_attribute(self):
-        pool = VariablePool.empty()
-        obj_value = {"inner": "hello"}
-        pool.add(("node1", "obj"), obj_value)
-
-        result = pool.get(("node1", "obj", "not_exist"))
-        assert result is None
-
-    def test_get_nested_object_attribute_with_falsy_values(self):
-        pool = VariablePool.empty()
-        obj_value = {
-            "inner_none": None,
-            "inner_empty": "",
-            "inner_zero": 0,
-            "inner_false": False,
-        }
-        pool.add(("node1", "obj"), obj_value)
-
-        segment_none = pool.get(("node1", "obj", "inner_none"))
-        assert segment_none is not None
-        assert isinstance(segment_none, NoneSegment)
-
-        segment_empty = pool.get(("node1", "obj", "inner_empty"))
-        assert segment_empty is not None
-        assert isinstance(segment_empty, StringSegment)
-        assert segment_empty.value == ""
-
-        segment_zero = pool.get(("node1", "obj", "inner_zero"))
-        assert segment_zero is not None
-        assert isinstance(segment_zero, IntegerSegment)
-        assert segment_zero.value == 0
-
-        segment_false = pool.get(("node1", "obj", "inner_false"))
"inner_false")) - assert segment_false is not None - assert isinstance(segment_false, BooleanSegment) - assert segment_false.value is False - - -class TestVariablePoolGetNotModifyVariableDictionary: - _NODE_ID = "start" - _VAR_NAME = "name" - - def test_convert_to_template_should_not_introduce_extra_keys(self): - pool = VariablePool.empty() - pool.add([self._NODE_ID, self._VAR_NAME], 0) - pool.convert_template("The start.name is {{#start.name#}}") - assert "The start" not in pool.variable_dictionary - - def test_get_should_not_modify_variable_dictionary(self): - pool = VariablePool.empty() - pool.get([self._NODE_ID, self._VAR_NAME]) - assert len(pool.variable_dictionary) == 0 - assert "start" not in pool.variable_dictionary - - pool = VariablePool.empty() - pool.add([self._NODE_ID, self._VAR_NAME], "Joe") - pool.get([self._NODE_ID, "count"]) - start_subdict = pool.variable_dictionary[self._NODE_ID] - assert "count" not in start_subdict diff --git a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py deleted file mode 100644 index 5e697f22f37..00000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py +++ /dev/null @@ -1,225 +0,0 @@ -""" -Unit tests for WorkflowNodeExecution domain model, focusing on process_data truncation functionality. -""" - -from dataclasses import dataclass -from datetime import datetime -from typing import Any - -import pytest - -from graphon.entities.workflow_node_execution import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes - - -class TestWorkflowNodeExecutionProcessDataTruncation: - """Test process_data truncation functionality in WorkflowNodeExecution domain model.""" - - def create_workflow_node_execution( - self, - process_data: dict[str, Any] | None = None, - ) -> WorkflowNodeExecution: - """Create a WorkflowNodeExecution instance for testing.""" - return WorkflowNodeExecution( - id="test-execution-id", - workflow_id="test-workflow-id", - index=1, - node_id="test-node-id", - node_type=BuiltinNodeTypes.LLM, - title="Test Node", - process_data=process_data, - created_at=datetime.now(), - ) - - def test_initial_process_data_truncated_state(self): - """Test that process_data_truncated returns False initially.""" - execution = self.create_workflow_node_execution() - - assert execution.process_data_truncated is False - assert execution.get_truncated_process_data() is None - - def test_set_and_get_truncated_process_data(self): - """Test setting and getting truncated process_data.""" - execution = self.create_workflow_node_execution() - test_truncated_data = {"truncated": True, "key": "value"} - - execution.set_truncated_process_data(test_truncated_data) - - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_truncated_data - - def test_set_truncated_process_data_to_none(self): - """Test setting truncated process_data to None.""" - execution = self.create_workflow_node_execution() - - # First set some data - execution.set_truncated_process_data({"key": "value"}) - assert execution.process_data_truncated is True - - # Then set to None - execution.set_truncated_process_data(None) - assert execution.process_data_truncated is False - assert execution.get_truncated_process_data() is None - - def test_get_response_process_data_with_no_truncation(self): - """Test get_response_process_data when no truncation is set.""" - original_data = {"original": True, "data": 
"value"} - execution = self.create_workflow_node_execution(process_data=original_data) - - response_data = execution.get_response_process_data() - - assert response_data == original_data - assert execution.process_data_truncated is False - - def test_get_response_process_data_with_truncation(self): - """Test get_response_process_data when truncation is set.""" - original_data = {"original": True, "large_data": "x" * 10000} - truncated_data = {"original": True, "large_data": "[TRUNCATED]"} - - execution = self.create_workflow_node_execution(process_data=original_data) - execution.set_truncated_process_data(truncated_data) - - response_data = execution.get_response_process_data() - - # Should return truncated data, not original - assert response_data == truncated_data - assert response_data != original_data - assert execution.process_data_truncated is True - - def test_get_response_process_data_with_none_process_data(self): - """Test get_response_process_data when process_data is None.""" - execution = self.create_workflow_node_execution(process_data=None) - - response_data = execution.get_response_process_data() - - assert response_data is None - assert execution.process_data_truncated is False - - def test_consistency_with_inputs_outputs_pattern(self): - """Test that process_data truncation follows the same pattern as inputs/outputs.""" - execution = self.create_workflow_node_execution() - - # Test that all truncation methods exist and behave consistently - test_data = {"test": "data"} - - # Test inputs truncation - execution.set_truncated_inputs(test_data) - assert execution.inputs_truncated is True - assert execution.get_truncated_inputs() == test_data - - # Test outputs truncation - execution.set_truncated_outputs(test_data) - assert execution.outputs_truncated is True - assert execution.get_truncated_outputs() == test_data - - # Test process_data truncation - execution.set_truncated_process_data(test_data) - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_data - - @pytest.mark.parametrize( - "test_data", - [ - {"simple": "value"}, - {"nested": {"key": "value"}}, - {"list": [1, 2, 3]}, - {"mixed": {"string": "value", "number": 42, "list": [1, 2]}}, - {}, # empty dict - ], - ) - def test_truncated_process_data_with_various_data_types(self, test_data): - """Test that truncated process_data works with various data types.""" - execution = self.create_workflow_node_execution() - - execution.set_truncated_process_data(test_data) - - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_data - assert execution.get_response_process_data() == test_data - - -@dataclass -class ProcessDataScenario: - """Test scenario data for process_data functionality.""" - - name: str - original_data: dict[str, Any] | None - truncated_data: dict[str, Any] | None - expected_truncated_flag: bool - expected_response_data: dict[str, Any] | None - - -class TestWorkflowNodeExecutionProcessDataScenarios: - """Test various scenarios for process_data handling.""" - - def get_process_data_scenarios(self) -> list[ProcessDataScenario]: - """Create test scenarios for process_data functionality.""" - return [ - ProcessDataScenario( - name="no_process_data", - original_data=None, - truncated_data=None, - expected_truncated_flag=False, - expected_response_data=None, - ), - ProcessDataScenario( - name="process_data_without_truncation", - original_data={"small": "data"}, - truncated_data=None, - expected_truncated_flag=False, 
-                expected_response_data={"small": "data"},
-            ),
-            ProcessDataScenario(
-                name="process_data_with_truncation",
-                original_data={"large": "x" * 10000, "metadata": "info"},
-                truncated_data={"large": "[TRUNCATED]", "metadata": "info"},
-                expected_truncated_flag=True,
-                expected_response_data={"large": "[TRUNCATED]", "metadata": "info"},
-            ),
-            ProcessDataScenario(
-                name="empty_process_data",
-                original_data={},
-                truncated_data=None,
-                expected_truncated_flag=False,
-                expected_response_data={},
-            ),
-            ProcessDataScenario(
-                name="complex_nested_data_with_truncation",
-                original_data={
-                    "config": {"setting": "value"},
-                    "logs": ["log1", "log2"] * 1000,  # Large list
-                    "status": "running",
-                },
-                truncated_data={"config": {"setting": "value"}, "logs": "[TRUNCATED: 2000 items]", "status": "running"},
-                expected_truncated_flag=True,
-                expected_response_data={
-                    "config": {"setting": "value"},
-                    "logs": "[TRUNCATED: 2000 items]",
-                    "status": "running",
-                },
-            ),
-        ]
-
-    @pytest.mark.parametrize(
-        "scenario",
-        get_process_data_scenarios(None),
-        ids=[scenario.name for scenario in get_process_data_scenarios(None)],
-    )
-    def test_process_data_scenarios(self, scenario: ProcessDataScenario):
-        """Test various process_data scenarios."""
-        execution = WorkflowNodeExecution(
-            id="test-execution-id",
-            workflow_id="test-workflow-id",
-            index=1,
-            node_id="test-node-id",
-            node_type=BuiltinNodeTypes.LLM,
-            title="Test Node",
-            process_data=scenario.original_data,
-            created_at=datetime.now(),
-        )
-
-        if scenario.truncated_data is not None:
-            execution.set_truncated_process_data(scenario.truncated_data)
-
-        assert execution.process_data_truncated == scenario.expected_truncated_flag
-        assert execution.get_response_process_data() == scenario.expected_response_data
diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph.py b/api/tests/unit_tests/core/workflow/graph/test_graph.py
deleted file mode 100644
index b138a7dfdc4..00000000000
--- a/api/tests/unit_tests/core/workflow/graph/test_graph.py
+++ /dev/null
@@ -1,281 +0,0 @@
-"""Unit tests for Graph class methods."""
-
-from unittest.mock import Mock
-
-from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState
-from graphon.graph.edge import Edge
-from graphon.graph.graph import Graph
-from graphon.nodes.base.node import Node
-
-
-def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node:
-    """Create a mock node for testing."""
-    node = Mock(spec=Node)
-    node.id = node_id
-    node.execution_type = execution_type
-    node.state = state
-    node.node_type = BuiltinNodeTypes.START
-    return node
-
-
-class TestMarkInactiveRootBranches:
-    """Test cases for _mark_inactive_root_branches method."""
-
-    def test_single_root_no_marking(self):
-        """Test that single root graph doesn't mark anything as skipped."""
-        nodes = {
-            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
-            "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
-        }
-
-        edges = {
-            "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
-        }
-
-        in_edges = {"child1": ["edge1"]}
-        out_edges = {"root1": ["edge1"]}
-
-        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
-
-        assert nodes["root1"].state == NodeState.UNKNOWN
-        assert nodes["child1"].state == NodeState.UNKNOWN
-        assert edges["edge1"].state == NodeState.UNKNOWN
-
-    def test_multiple_roots_mark_inactive(self):
-        """Test marking inactive root branches with multiple root nodes."""
-        nodes = {
-            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
"root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - } - - in_edges = {"child1": ["edge1"], "child2": ["edge2"]} - out_edges = {"root1": ["edge1"], "root2": ["edge2"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - - def test_shared_downstream_node(self): - """Test that shared downstream nodes are not skipped if at least one path is active.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - "shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - "edge3": Edge(id="edge3", tail="child1", head="shared", source_handle="source"), - "edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"), - } - - in_edges = { - "child1": ["edge1"], - "child2": ["edge2"], - "shared": ["edge3", "edge4"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "child1": ["edge3"], - "child2": ["edge4"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.SKIPPED - assert nodes["shared"].state == NodeState.UNKNOWN # Not skipped because edge3 is active - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.UNKNOWN - assert edges["edge4"].state == NodeState.SKIPPED - - def test_deep_branch_marking(self): - """Test marking deep branches with multiple levels.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE), - "level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE), - "level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE), - "level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE), - "level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"), - "edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"), - "edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"), - "edge5": 
Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"), - } - - in_edges = { - "level1_a": ["edge1"], - "level1_b": ["edge2"], - "level2_a": ["edge3"], - "level2_b": ["edge4"], - "level3": ["edge5"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "level1_a": ["edge3"], - "level1_b": ["edge4"], - "level2_b": ["edge5"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["level1_a"].state == NodeState.UNKNOWN - assert nodes["level1_b"].state == NodeState.SKIPPED - assert nodes["level2_a"].state == NodeState.UNKNOWN - assert nodes["level2_b"].state == NodeState.SKIPPED - assert nodes["level3"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.UNKNOWN - assert edges["edge4"].state == NodeState.SKIPPED - assert edges["edge5"].state == NodeState.SKIPPED - - def test_non_root_execution_type(self): - """Test that nodes with non-ROOT execution type are not treated as root nodes.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"), - } - - in_edges = {"child1": ["edge1"], "child2": ["edge2"]} - out_edges = {"root1": ["edge1"], "non_root": ["edge2"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["non_root"].state == NodeState.UNKNOWN # Not marked as skipped - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.UNKNOWN - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.UNKNOWN - - def test_empty_graph(self): - """Test handling of empty graph structures.""" - nodes = {} - edges = {} - in_edges = {} - out_edges = {} - - # Should not raise any errors - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent") - - def test_three_roots_mark_two_inactive(self): - """Test with three root nodes where two should be marked inactive.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "root3": create_mock_node("root3", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - "child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - "edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"), - } - - in_edges = { - "child1": ["edge1"], - "child2": ["edge2"], - "child3": ["edge3"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "root3": ["edge3"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2") - - assert 
nodes["root1"].state == NodeState.SKIPPED - assert nodes["root2"].state == NodeState.UNKNOWN # Active root - assert nodes["root3"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.SKIPPED - assert nodes["child2"].state == NodeState.UNKNOWN - assert nodes["child3"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.SKIPPED - assert edges["edge2"].state == NodeState.UNKNOWN - assert edges["edge3"].state == NodeState.SKIPPED - - def test_convergent_paths(self): - """Test convergent paths where multiple inactive branches lead to same node.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "root3": create_mock_node("root3", NodeExecutionType.ROOT), - "mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE), - "mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE), - "convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"), - "edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"), - "edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"), - "edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"), - } - - in_edges = { - "mid1": ["edge1"], - "mid2": ["edge2"], - "convergent": ["edge3", "edge4", "edge5"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "root3": ["edge3"], - "mid1": ["edge4"], - "mid2": ["edge5"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["root3"].state == NodeState.SKIPPED - assert nodes["mid1"].state == NodeState.UNKNOWN - assert nodes["mid2"].state == NodeState.SKIPPED - assert nodes["convergent"].state == NodeState.UNKNOWN # Not skipped due to active path from root1 - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.SKIPPED - assert edges["edge4"].state == NodeState.UNKNOWN - assert edges["edge5"].state == NodeState.SKIPPED diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py deleted file mode 100644 index f3eaa1d6869..00000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.graph import Graph -from graphon.nodes.base.node import Node - - -def _make_node(node_id: str, node_type: NodeType = BuiltinNodeTypes.START) -> Node: - node = MagicMock(spec=Node) - node.id = node_id - node.node_type = node_type - node.execution_type = None # attribute not used in builder path - return node - - -def test_graph_builder_creates_linear_graph(): - builder = Graph.new() - root = _make_node("root", BuiltinNodeTypes.START) - mid = _make_node("mid", BuiltinNodeTypes.LLM) - end = _make_node("end", BuiltinNodeTypes.END) - - graph = builder.add_root(root).add_node(mid).add_node(end).build() - - assert graph.root_node is root - assert graph.nodes == {"root": root, "mid": mid, "end": end} - assert len(graph.edges) == 2 - first_edge = 
-    assert first_edge.tail == "root"
-    assert first_edge.head == "mid"
-    assert graph.out_edges["mid"] == [edge_id for edge_id, edge in graph.edges.items() if edge.tail == "mid"]
-
-
-def test_graph_builder_supports_custom_predecessor():
-    builder = Graph.new()
-    root = _make_node("root")
-    branch = _make_node("branch")
-    other = _make_node("other")
-
-    graph = builder.add_root(root).add_node(branch).add_node(other, from_node_id="root").build()
-
-    outgoing_root = graph.out_edges["root"]
-    assert len(outgoing_root) == 2
-    edge_targets = {graph.edges[eid].head for eid in outgoing_root}
-    assert edge_targets == {"branch", "other"}
-
-
-def test_graph_builder_validates_usage():
-    builder = Graph.new()
-    node = _make_node("node")
-
-    with pytest.raises(ValueError, match="Root node"):
-        builder.add_node(node)
-
-    builder.add_root(node)
-    duplicate = _make_node("node")
-    with pytest.raises(ValueError, match="Duplicate"):
-        builder.add_node(duplicate)
diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py
deleted file mode 100644
index 3620a20e567..00000000000
--- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py
+++ /dev/null
@@ -1,118 +0,0 @@
-from __future__ import annotations
-
-from typing import Any
-
-import pytest
-
-from core.workflow.node_factory import DifyNodeFactory
-from core.workflow.system_variables import default_system_variables
-from graphon.graph import Graph
-from graphon.graph.validation import GraphValidationError
-from graphon.nodes import BuiltinNodeTypes
-from graphon.runtime import GraphRuntimeState, VariablePool
-from tests.workflow_test_utils import build_test_graph_init_params
-
-
-def _build_iteration_graph(node_id: str) -> dict[str, Any]:
-    return {
-        "nodes": [
-            {
-                "id": node_id,
-                "data": {
-                    "type": "iteration",
-                    "title": "Iteration",
-                    "iterator_selector": ["start", "items"],
-                    "output_selector": [node_id, "output"],
-                },
-            }
-        ],
-        "edges": [],
-    }
-
-
-def _build_loop_graph(node_id: str) -> dict[str, Any]:
-    return {
-        "nodes": [
-            {
-                "id": node_id,
-                "data": {
-                    "type": "loop",
-                    "title": "Loop",
-                    "loop_count": 1,
-                    "break_conditions": [],
-                    "logical_operator": "and",
-                    "loop_variables": [],
-                    "outputs": {},
-                },
-            }
-        ],
-        "edges": [],
-    }
-
-
-def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory:
-    graph_init_params = build_test_graph_init_params(
-        workflow_id="workflow",
-        graph_config=graph_config,
-        tenant_id="tenant",
-        app_id="app",
-        user_id="user",
-        user_from="account",
-        invoke_from="debugger",
-        call_depth=0,
-    )
-    graph_runtime_state = GraphRuntimeState(
-        variable_pool=VariablePool(
-            system_variables=default_system_variables(),
-            user_inputs={},
-            environment_variables=[],
-        ),
-        start_at=0.0,
-    )
-    return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
-
-
-def test_iteration_root_requires_skip_validation():
-    node_id = "iteration-node"
-    graph_config = _build_iteration_graph(node_id)
-    node_factory = _make_factory(graph_config)
-
-    with pytest.raises(GraphValidationError):
-        Graph.init(
-            graph_config=graph_config,
-            node_factory=node_factory,
-            root_node_id=node_id,
-        )
-
-    graph = Graph.init(
-        graph_config=graph_config,
-        node_factory=node_factory,
-        root_node_id=node_id,
-        skip_validation=True,
-    )
-
-    assert graph.root_node.id == node_id
-    assert graph.root_node.node_type == BuiltinNodeTypes.ITERATION
-
-
-def test_loop_root_requires_skip_validation():
-    node_id = "loop-node"
-    graph_config = _build_loop_graph(node_id)
-    node_factory = _make_factory(graph_config)
-
-    with pytest.raises(GraphValidationError):
-        Graph.init(
-            graph_config=graph_config,
-            node_factory=node_factory,
-            root_node_id=node_id,
-        )
-
-    graph = Graph.init(
-        graph_config=graph_config,
-        node_factory=node_factory,
-        root_node_id=node_id,
-        skip_validation=True,
-    )
-
-    assert graph.root_node.id == node_id
-    assert graph.root_node.node_type == BuiltinNodeTypes.LOOP
diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py
deleted file mode 100644
index bfd0b483925..00000000000
--- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py
+++ /dev/null
@@ -1,219 +0,0 @@
-from __future__ import annotations
-
-import time
-from collections.abc import Mapping
-from dataclasses import dataclass
-
-import pytest
-
-from core.workflow.system_variables import build_system_variables
-from graphon.entities import GraphInitParams
-from graphon.entities.base_node_data import BaseNodeData
-from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeExecutionType, NodeType
-from graphon.graph import Graph
-from graphon.graph.validation import GraphValidationError
-from graphon.nodes.base.node import Node
-from graphon.runtime import GraphRuntimeState, VariablePool
-from tests.workflow_test_utils import build_test_graph_init_params
-
-
-class _TestNodeData(BaseNodeData):
-    type: NodeType | None = None
-    execution_type: NodeExecutionType | str | None = None
-
-
-class _TestNode(Node[_TestNodeData]):
-    node_type = BuiltinNodeTypes.ANSWER
-    execution_type = NodeExecutionType.EXECUTABLE
-
-    @classmethod
-    def version(cls) -> str:
-        return "1"
-
-    def __init__(
-        self,
-        *,
-        id: str,
-        config: Mapping[str, object],
-        graph_init_params: GraphInitParams,
-        graph_runtime_state: GraphRuntimeState,
-    ) -> None:
-        super().__init__(
-            id=id,
-            config=config,
-            graph_init_params=graph_init_params,
-            graph_runtime_state=graph_runtime_state,
-        )
-
-        node_type_value = self.data.get("type")
-        if isinstance(node_type_value, str):
-            self.node_type = node_type_value
-
-    def _run(self):
-        raise NotImplementedError
-
-    def post_init(self) -> None:
-        super().post_init()
-        self._maybe_override_execution_type()
-        self.data = dict(self.node_data.model_dump())
-
-    def _maybe_override_execution_type(self) -> None:
-        execution_type_value = self.node_data.execution_type
-        if execution_type_value is None:
-            return
-        if isinstance(execution_type_value, NodeExecutionType):
-            self.execution_type = execution_type_value
-        else:
-            self.execution_type = NodeExecutionType(execution_type_value)
-
-
-@dataclass(slots=True)
-class _SimpleNodeFactory:
-    graph_init_params: GraphInitParams
-    graph_runtime_state: GraphRuntimeState
-
-    def create_node(self, node_config: Mapping[str, object]) -> _TestNode:
-        node_id = str(node_config["id"])
-        node = _TestNode(
-            id=node_id,
-            config=node_config,
-            graph_init_params=self.graph_init_params,
-            graph_runtime_state=self.graph_runtime_state,
-        )
-        return node
-
-
-@pytest.fixture
-def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]:
-    graph_config: dict[str, object] = {"edges": [], "nodes": []}
-    init_params = build_test_graph_init_params(
-        workflow_id="workflow",
-        graph_config=graph_config,
-        tenant_id="tenant",
-        app_id="app",
-        user_id="user",
-        user_from="account",
-        invoke_from="service-api",
-        call_depth=0,
-    )
-    variable_pool = VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={})
-    runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-    factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state)
-    return factory, graph_config
-
-
-def test_graph_initialization_runs_default_validators(
-    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
-):
-    node_factory, graph_config = graph_init_dependencies
-    graph_config["nodes"] = [
-        {
-            "id": "start",
-            "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT},
-        },
-        {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}},
-    ]
-    graph_config["edges"] = [
-        {"source": "start", "target": "answer", "sourceHandle": "success"},
-    ]
-
-    graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
-
-    assert graph.root_node.id == "start"
-    assert "answer" in graph.nodes
-
-
-def test_graph_validation_fails_for_unknown_edge_targets(
-    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
-) -> None:
-    node_factory, graph_config = graph_init_dependencies
-    graph_config["nodes"] = [
-        {
-            "id": "start",
-            "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT},
-        },
-    ]
-    graph_config["edges"] = [
-        {"source": "start", "target": "missing", "sourceHandle": "success"},
-    ]
-
-    with pytest.raises(GraphValidationError) as exc:
-        Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
-
-    assert any(issue.code == "MISSING_NODE" for issue in exc.value.issues)
-
-
-def test_graph_promotes_fail_branch_nodes_to_branch_execution_type(
-    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
-) -> None:
-    node_factory, graph_config = graph_init_dependencies
-    graph_config["nodes"] = [
-        {
-            "id": "start",
-            "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT},
-        },
-        {
-            "id": "branch",
-            "data": {
-                "type": BuiltinNodeTypes.IF_ELSE,
-                "title": "Branch",
-                "error_strategy": ErrorStrategy.FAIL_BRANCH,
-            },
-        },
-    ]
-    graph_config["edges"] = [
-        {"source": "start", "target": "branch", "sourceHandle": "success"},
-    ]
-
-    graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
-
-    assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH
-
-
-def test_graph_init_ignores_custom_note_nodes_before_node_data_validation(
-    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
-) -> None:
-    node_factory, graph_config = graph_init_dependencies
-    graph_config["nodes"] = [
-        {
-            "id": "start",
-            "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT},
-        },
-        {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}},
-        {
-            "id": "note",
-            "type": "custom-note",
-            "data": {
-                "type": "",
-                "title": "",
-                "desc": "",
-                "text": "{}",
-                "theme": "blue",
-            },
-        },
-    ]
-    graph_config["edges"] = [
-        {"source": "start", "target": "answer", "sourceHandle": "success"},
-    ]
-
-    graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
-
-    assert graph.root_node.id == "start"
-    assert "answer" in graph.nodes
-    assert "note" not in graph.nodes
-
-
-def test_graph_init_fails_for_unknown_root_node_id(
-    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
-) -> None:
-    node_factory, graph_config = graph_init_dependencies
-    graph_config["nodes"] = [
-        {
-            "id": "start",
-            "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT},
-        },
-    ]
-    graph_config["edges"] = []
-
-    with pytest.raises(ValueError, match="Root node id missing not found in the graph"):
-        Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="missing")
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/README.md b/api/tests/unit_tests/core/workflow/graph_engine/README.md
index 960fef7d438..dd419f0810f 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/README.md
+++ b/api/tests/unit_tests/core/workflow/graph_engine/README.md
@@ -1,441 +1,30 @@
-# Graph Engine Testing Framework
+# Workflow Graph Engine Smoke Tests
-## Overview
+This directory now keeps only a small Dify-owned smoke layer around the external
+`graphon` package.
-This directory contains a comprehensive testing framework for the Graph Engine, including:
+Retained coverage focuses on:
-1. **TableTestRunner** - Advanced table-driven test framework for workflow testing
-1. **Auto-Mock System** - Powerful mocking framework for testing without external dependencies
+1. Dify workflow layers:
+   - `layers/test_llm_quota.py`
+   - `layers/test_observability.py`
+2. Human-input resume integration:
+   - `test_parallel_human_input_join_resume.py`
+3. One mocked tool/chatflow smoke path:
+   - `test_tool_in_chatflow.py`
-## TableTestRunner Framework
+The helper modules below remain only because the retained smoke tests use them:
-The TableTestRunner (`test_table_runner.py`) provides a robust table-driven testing framework for GraphEngine workflows.
+1. `test_mock_config.py`
+2. `test_mock_factory.py`
+3. `test_mock_nodes.py`
+4. `test_table_runner.py`
-### Features
-
-- **Table-driven testing** - Define test cases as structured data
-- **Parallel test execution** - Run tests concurrently for faster execution
-- **Property-based testing** - Integration with Hypothesis for fuzzing
-- **Event sequence validation** - Verify correct event ordering
-- **Mock configuration** - Seamless integration with the auto-mock system
-- **Performance metrics** - Track execution times and bottlenecks
-- **Detailed error reporting** - Comprehensive failure diagnostics
-
-### Basic Usage
-
-```python
-from test_table_runner import TableTestRunner, WorkflowTestCase
-
-# Create test runner
-runner = TableTestRunner()
-
-# Define test case
-test_case = WorkflowTestCase(
-    fixture_path="simple_workflow",
-    inputs={"query": "Hello"},
-    expected_outputs={"result": "World"},
-    description="Basic workflow test",
-)
-
-# Run single test
-result = runner.run_test_case(test_case)
-assert result.success
-```
-
-### Advanced Features
-
-#### Parallel Execution
-
-```python
-runner = TableTestRunner(max_workers=8)
-
-test_cases = [
-    WorkflowTestCase(...),
-    WorkflowTestCase(...),
-    # ... more test cases
-]
-
-# Run tests in parallel
-suite_result = runner.run_table_tests(
-    test_cases,
-    parallel=True,
-    fail_fast=False
-)
-
-print(f"Success rate: {suite_result.success_rate:.1f}%")
-```
-
-#### Event Sequence Validation
-
-```python
-from graphon.graph_events import (
-    GraphRunStartedEvent,
-    NodeRunStartedEvent,
-    NodeRunSucceededEvent,
-    GraphRunSucceededEvent,
-)
-
-test_case = WorkflowTestCase(
-    fixture_path="workflow",
-    inputs={},
-    expected_outputs={},
-    expected_event_sequence=[
-        GraphRunStartedEvent,
-        NodeRunStartedEvent,
-        NodeRunSucceededEvent,
-        GraphRunSucceededEvent,
-    ]
-)
-```
-
-### Test Suite Reports
-
-```python
-# Run test suite
-suite_result = runner.run_table_tests(test_cases)
-
-# Generate detailed report
-report = runner.generate_report(suite_result)
-print(report)
-
-# Access specific results
-failed_results = suite_result.get_failed_results()
-for result in failed_results:
-    print(f"Failed: {result.test_case.description}")
-    print(f"  Error: {result.error}")
-```
-
-### Performance Testing
-
-```python
-# Enable logging for performance insights
-runner = TableTestRunner(
-    enable_logging=True,
-    log_level="DEBUG"
-)
-
-# Run tests and analyze performance
-suite_result = runner.run_table_tests(test_cases)
-
-# Get slowest tests
-sorted_results = sorted(
-    suite_result.results,
-    key=lambda r: r.execution_time,
-    reverse=True
-)
-
-print("Slowest tests:")
-for result in sorted_results[:5]:
-    print(f"  {result.test_case.description}: {result.execution_time:.2f}s")
-```
-
-## Integration: TableTestRunner + Auto-Mock System
-
-The TableTestRunner seamlessly integrates with the auto-mock system for comprehensive workflow testing:
-
-```python
-from test_table_runner import TableTestRunner, WorkflowTestCase
-from test_mock_config import MockConfigBuilder
-
-# Configure mocks
-mock_config = (MockConfigBuilder()
-    .with_llm_response("Mocked LLM response")
-    .with_tool_response({"result": "mocked"})
-    .with_delays(True)  # Simulate realistic delays
-    .build())
-
-# Create test case with mocking
-test_case = WorkflowTestCase(
-    fixture_path="complex_workflow",
-    inputs={"query": "test"},
-    expected_outputs={"answer": "Mocked LLM response"},
-    use_auto_mock=True,  # Enable auto-mocking
-    mock_config=mock_config,
-    description="Test with mocked services",
-)
-
-# Run test
-runner = TableTestRunner()
-result = runner.run_test_case(test_case)
-```
-
-## Auto-Mock System
-
-The auto-mock system provides a powerful framework for testing workflows that contain nodes requiring third-party services (LLM, APIs, tools, etc.) without making actual external calls. This enables:
-
-- **Fast test execution** - No network latency or API rate limits
-- **Deterministic results** - Consistent outputs for reliable testing
-- **Cost savings** - No API usage charges during testing
-- **Offline testing** - Tests can run without internet connectivity
-- **Error simulation** - Test error handling without triggering real failures
-
-## Architecture
-
-The auto-mock system consists of three main components:
-
-### 1. MockNodeFactory (`test_mock_factory.py`)
-
-- Extends `DifyNodeFactory` to intercept node creation
-- Automatically detects nodes requiring third-party services
-- Returns mock node implementations instead of real ones
-- Supports registration of custom mock implementations
-
-### 2. Mock Node Implementations (`test_mock_nodes.py`)
-
-- `MockLLMNode` - Mocks LLM API calls (OpenAI, Anthropic, etc.)
-- `MockAgentNode` - Mocks agent execution -- `MockToolNode` - Mocks tool invocations -- `MockKnowledgeRetrievalNode` - Mocks knowledge base queries -- `MockHttpRequestNode` - Mocks HTTP requests -- `MockParameterExtractorNode` - Mocks parameter extraction -- `MockDocumentExtractorNode` - Mocks document processing -- `MockQuestionClassifierNode` - Mocks question classification - -### 3. Mock Configuration (`test_mock_config.py`) - -- `MockConfig` - Global configuration for mock behavior -- `NodeMockConfig` - Node-specific mock configuration -- `MockConfigBuilder` - Fluent interface for building configurations - -## Usage - -### Basic Example - -```python -from test_graph_engine import TableTestRunner, WorkflowTestCase -from test_mock_config import MockConfigBuilder - -# Create test runner -runner = TableTestRunner() - -# Configure mock responses -mock_config = (MockConfigBuilder() - .with_llm_response("Mocked LLM response") - .build()) - -# Define test case -test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Hello"}, - expected_outputs={"answer": "Mocked LLM response"}, - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, -) - -# Run test -result = runner.run_test_case(test_case) -assert result.success -``` - -### Custom Node Outputs - -```python -# Configure specific outputs for individual nodes -mock_config = MockConfig() -mock_config.set_node_outputs("llm_node_123", { - "text": "Custom response for this specific node", - "usage": {"total_tokens": 50}, - "finish_reason": "stop", -}) -``` - -### Error Simulation - -```python -# Simulate node failures for error handling tests -mock_config = MockConfig() -mock_config.set_node_error("http_node", "Connection timeout") -``` - -### Simulated Delays - -```python -# Add realistic execution delays -from test_mock_config import NodeMockConfig - -node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response"}, - delay=1.5, # 1.5 second delay -) -mock_config.set_node_config("llm_node", node_config) -``` - -### Custom Handlers - -```python -# Define custom logic for mock outputs -def custom_handler(node): - # Access node state and return dynamic outputs - return { - "text": f"Processed: {node.graph_runtime_state.variable_pool.get('query')}", - } - -node_config = NodeMockConfig( - node_id="llm_node", - custom_handler=custom_handler, -) -``` - -## Node Types Automatically Mocked - -The following node types are automatically mocked when `use_auto_mock=True`: - -- `LLM` - Language model nodes -- `AGENT` - Agent execution nodes -- `TOOL` - Tool invocation nodes -- `KNOWLEDGE_RETRIEVAL` - Knowledge base query nodes -- `HTTP_REQUEST` - HTTP request nodes -- `PARAMETER_EXTRACTOR` - Parameter extraction nodes -- `DOCUMENT_EXTRACTOR` - Document processing nodes -- `QUESTION_CLASSIFIER` - Question classification nodes - -## Advanced Features - -### Registering Custom Mock Implementations - -```python -from test_mock_factory import MockNodeFactory - -# Create custom mock implementation -class CustomMockNode(BaseNode): - def _run(self): - # Custom mock logic - pass - -# Register for a specific node type -factory = MockNodeFactory(...) 
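-# NOTE: the `...` above stands for the factory's usual constructor
-# arguments, elided here; per the factory tests in this suite these are
-# graph_init_params, graph_runtime_state, and an optional mock_config.
-# `NodeType.CUSTOM` is a placeholder for the node type being overridden.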
-factory.register_mock_node_type(NodeType.CUSTOM, CustomMockNode) -``` - -### Default Configurations by Node Type - -```python -# Set defaults for all nodes of a specific type -mock_config.set_default_config(NodeType.LLM, { - "temperature": 0.7, - "max_tokens": 100, -}) -``` - -### MockConfigBuilder Fluent API - -```python -config = (MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"result": "data"}) - .with_retrieval_response("Retrieved content") - .with_http_response({"status_code": 200, "body": "{}"}) - .with_node_output("node_id", {"output": "value"}) - .with_node_error("error_node", "Error message") - .with_delays(True) - .build()) -``` - -## Testing Workflows - -### 1. Create Workflow Fixture - -Create a YAML fixture file in `api/tests/fixtures/workflow/` directory defining your workflow graph. - -### 2. Configure Mocks - -Set up mock configurations for nodes that need third-party services. - -### 3. Define Test Cases - -Create `WorkflowTestCase` instances with inputs, expected outputs, and mock config. - -### 4. Run Tests - -Use `TableTestRunner` to execute test cases and validate results. - -## Best Practices - -1. **Use descriptive mock responses** - Make it clear in outputs that they are mocked -1. **Test both success and failure paths** - Use error simulation to test error handling -1. **Keep mock configs close to tests** - Define mocks in the same test file for clarity -1. **Use custom handlers sparingly** - Only when dynamic behavior is needed -1. **Document mock behavior** - Comment why specific mock values are chosen -1. **Validate mock accuracy** - Ensure mocks reflect real service behavior - -## Examples - -See `test_mock_example.py` for comprehensive examples including: - -- Basic LLM workflow testing -- Custom node outputs -- HTTP and tool workflow testing -- Error simulation -- Performance testing with delays - -## Running Tests - -### TableTestRunner Tests +Examples: ```bash -# Run graph engine tests (includes property-based tests) -uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py - -# Run with specific test patterns -uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py -k "test_echo" - -# Run with verbose output -uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py -v +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py ``` - -### Mock System Tests - -```bash -# Run auto-mock system tests -uv run pytest api/tests/unit_tests/graphon/graph_engine/test_auto_mock_system.py - -# Run examples -uv run python api/tests/unit_tests/graphon/graph_engine/test_mock_example.py - -# Run simple validation -uv run python api/tests/unit_tests/graphon/graph_engine/test_mock_simple.py -``` - -### All Tests - -```bash -# Run all graph engine tests -uv run pytest api/tests/unit_tests/graphon/graph_engine/ - -# Run with coverage -uv run pytest api/tests/unit_tests/graphon/graph_engine/ --cov=graphon.graph_engine - -# Run in parallel -uv run pytest api/tests/unit_tests/graphon/graph_engine/ -n auto -``` - -## Troubleshooting - -### Issue: Mock not being 
applied - -- Ensure `use_auto_mock=True` in `WorkflowTestCase` -- Verify node ID matches in mock config -- Check that node type is in the auto-mock list - -### Issue: Unexpected outputs - -- Debug by printing `result.actual_outputs` -- Check if custom handler is overriding expected outputs -- Verify mock config is properly built - -### Issue: Import errors - -- Ensure all mock modules are in the correct path -- Check that required dependencies are installed - -## Future Enhancements - -Potential improvements to the auto-mock system: - -1. **Recording and playback** - Record real API responses for replay in tests -1. **Mock templates** - Pre-defined mock configurations for common scenarios -1. **Async support** - Better support for async node execution -1. **Mock validation** - Validate mock outputs against node schemas -1. **Performance profiling** - Built-in performance metrics for mocked workflows diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py deleted file mode 100644 index 795362b1580..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ /dev/null @@ -1,315 +0,0 @@ -"""Tests for Redis command channel implementation.""" - -import json -from unittest.mock import MagicMock - -from graphon.graph_engine.command_channels.redis_channel import RedisChannel -from graphon.graph_engine.entities.commands import ( - AbortCommand, - CommandType, - GraphEngineCommand, - UpdateVariablesCommand, - VariableUpdate, -) -from graphon.variables import IntegerVariable, StringVariable - - -class TestRedisChannel: - """Test suite for RedisChannel functionality.""" - - def test_init(self): - """Test RedisChannel initialization.""" - mock_redis = MagicMock() - channel_key = "test:channel:key" - ttl = 7200 - - channel = RedisChannel(mock_redis, channel_key, ttl) - - assert channel._redis == mock_redis - assert channel._key == channel_key - assert channel._command_ttl == ttl - - def test_init_default_ttl(self): - """Test RedisChannel initialization with default TTL.""" - mock_redis = MagicMock() - channel_key = "test:channel:key" - - channel = RedisChannel(mock_redis, channel_key) - - assert channel._command_ttl == 3600 # Default TTL - - def test_send_command(self): - """Test sending a command to Redis.""" - mock_redis = MagicMock() - mock_pipe = MagicMock() - context = MagicMock() - context.__enter__.return_value = mock_pipe - context.__exit__.return_value = None - mock_redis.pipeline.return_value = context - - channel = RedisChannel(mock_redis, "test:key", 3600) - - pending_key = "test:key:pending" - - # Create a test command - command = GraphEngineCommand(command_type=CommandType.ABORT) - - # Send the command - channel.send_command(command) - - # Verify pipeline was used - mock_redis.pipeline.assert_called_once() - - # Verify rpush was called with correct data - expected_json = json.dumps(command.model_dump()) - mock_pipe.rpush.assert_called_once_with("test:key", expected_json) - - # Verify expire was set - mock_pipe.expire.assert_called_once_with("test:key", 3600) - mock_pipe.set.assert_called_once_with(pending_key, "1", ex=3600) - - # Verify execute was called - mock_pipe.execute.assert_called_once() - - def test_fetch_commands_empty(self): - """Test fetching commands when Redis list is empty.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - 
fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context] - - # No pending marker - pending_pipe.execute.return_value = [None, 0] - mock_redis.llen.return_value = 0 - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert commands == [] - mock_redis.pipeline.assert_called_once() - fetch_pipe.lrange.assert_not_called() - fetch_pipe.delete.assert_not_called() - - def test_fetch_commands_with_abort_command(self): - """Test fetching abort commands from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Create abort command data - abort_command = AbortCommand() - command_json = json.dumps(abort_command.model_dump()) - - # Simulate Redis returning one command - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 1 - assert isinstance(commands[0], AbortCommand) - assert commands[0].command_type == CommandType.ABORT - - def test_fetch_commands_multiple(self): - """Test fetching multiple commands from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Create multiple commands - command1 = GraphEngineCommand(command_type=CommandType.ABORT) - command2 = AbortCommand() - - command1_json = json.dumps(command1.model_dump()) - command2_json = json.dumps(command2.model_dump()) - - # Simulate Redis returning multiple commands - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command1_json.encode(), command2_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 2 - assert commands[0].command_type == CommandType.ABORT - assert isinstance(commands[1], AbortCommand) - - def test_fetch_commands_with_update_variables_command(self): - """Test fetching update variables command from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - update_command = UpdateVariablesCommand( - updates=[ - VariableUpdate( - value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]), - ), - VariableUpdate( - value=IntegerVariable(name="baz", value=123, 
selector=["node2", "baz"]), - ), - ] - ) - command_json = json.dumps(update_command.model_dump()) - - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 1 - assert isinstance(commands[0], UpdateVariablesCommand) - assert isinstance(commands[0].updates[0].value, StringVariable) - assert list(commands[0].updates[0].value.selector) == ["node1", "foo"] - assert commands[0].updates[0].value.value == "bar" - - def test_fetch_commands_skips_invalid_json(self): - """Test that invalid JSON commands are skipped.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Mix valid and invalid JSON - valid_command = AbortCommand() - valid_json = json.dumps(valid_command.model_dump()) - invalid_json = b"invalid json {" - - # Simulate Redis returning mixed valid/invalid commands - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[invalid_json, valid_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - # Should only return the valid command - assert len(commands) == 1 - assert isinstance(commands[0], AbortCommand) - - def test_deserialize_command_abort(self): - """Test deserializing an abort command.""" - channel = RedisChannel(MagicMock(), "test:key") - - abort_data = {"command_type": CommandType.ABORT} - command = channel._deserialize_command(abort_data) - - assert isinstance(command, AbortCommand) - assert command.command_type == CommandType.ABORT - - def test_deserialize_command_generic(self): - """Test deserializing a generic command.""" - channel = RedisChannel(MagicMock(), "test:key") - - # For now, only ABORT is supported, but test generic handling - generic_data = {"command_type": CommandType.ABORT} - command = channel._deserialize_command(generic_data) - - assert command is not None - assert command.command_type == CommandType.ABORT - - def test_deserialize_command_invalid(self): - """Test deserializing invalid command data.""" - channel = RedisChannel(MagicMock(), "test:key") - - # Missing command_type - invalid_data = {"some_field": "value"} - command = channel._deserialize_command(invalid_data) - - assert command is None - - def test_deserialize_command_invalid_type(self): - """Test deserializing command with invalid type.""" - channel = RedisChannel(MagicMock(), "test:key") - - # Invalid command type - invalid_data = {"command_type": "INVALID_TYPE"} - command = channel._deserialize_command(invalid_data) - - assert command is None - - def test_atomic_fetch_and_clear(self): - """Test that fetch_commands atomically fetches and clears the list.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - command = 
AbortCommand() - command_json = json.dumps(command.model_dump()) - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - - # First fetch should return the command - commands = channel.fetch_commands() - assert len(commands) == 1 - - # Verify both lrange and delete were called in the pipeline - assert fetch_pipe.lrange.call_count == 1 - assert fetch_pipe.delete.call_count == 1 - fetch_pipe.lrange.assert_called_with("test:key", 0, -1) - fetch_pipe.delete.assert_called_with("test:key") - - def test_fetch_commands_without_pending_marker_returns_empty(self): - """Ensure we avoid unnecessary list reads when pending flag is missing.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Pending flag absent - pending_pipe.execute.return_value = [None, 0] - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert commands == [] - mock_redis.llen.assert_not_called() - assert mock_redis.pipeline.call_count == 1 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py deleted file mode 100644 index cacbe9ba4e9..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Tests for graph engine event handlers.""" - -from __future__ import annotations - -from graphon.entities.base_node_data import RetryConfig -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.graph_engine.domain.graph_execution import GraphExecution -from graphon.graph_engine.event_management.event_handlers import EventHandler -from graphon.graph_engine.event_management.event_manager import EventManager -from graphon.graph_engine.graph_state_manager import GraphStateManager -from graphon.graph_engine.ready_queue.in_memory import InMemoryReadyQueue -from graphon.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator -from graphon.graph_events import NodeRunRetryEvent, NodeRunStartedEvent -from graphon.node_events import NodeRunResult -from graphon.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now - - -class _StubEdgeProcessor: - """Minimal edge processor stub for tests.""" - - -class _StubErrorHandler: - """Minimal error handler stub for tests.""" - - -class _StubNode: - """Simple node stub exposing the attributes needed by the state manager.""" - - def __init__(self, node_id: str) -> None: - self.id = node_id - self.state = NodeState.UNKNOWN - self.title = "Stub Node" - self.execution_type = NodeExecutionType.EXECUTABLE - self.error_strategy = None - self.retry_config = RetryConfig() - self.retry = False - - -def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]: - """Construct an EventHandler with in-memory dependencies for testing.""" - - node = _StubNode(node_id) - graph = Graph(nodes={node_id: node}, edges={}, 
in_edges={}, out_edges={}, root_node=node) - - variable_pool = VariablePool() - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) - graph_execution = GraphExecution(workflow_id="test-workflow") - - event_manager = EventManager() - state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue()) - response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph) - - handler = EventHandler( - graph=graph, - graph_runtime_state=runtime_state, - graph_execution=graph_execution, - response_coordinator=response_coordinator, - event_collector=event_manager, - edge_processor=_StubEdgeProcessor(), - state_manager=state_manager, - error_handler=_StubErrorHandler(), - ) - - return handler, event_manager, graph_execution - - -def test_retry_does_not_emit_additional_start_event() -> None: - """Ensure retry attempts do not produce duplicate start events.""" - - node_id = "test-node" - handler, event_manager, graph_execution = _build_event_handler(node_id) - - execution_id = "exec-1" - node_type = BuiltinNodeTypes.CODE - start_time = naive_utc_now() - - start_event = NodeRunStartedEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - ) - handler.dispatch(start_event) - - retry_event = NodeRunRetryEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - error="boom", - retry_index=1, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error="boom", - error_type="TestError", - ), - ) - handler.dispatch(retry_event) - - # Simulate the node starting execution again after retry - second_start_event = NodeRunStartedEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - ) - handler.dispatch(second_start_event) - - collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined] - - assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent] - - node_execution = graph_execution.get_or_create_node_execution(node_id) - assert node_execution.retry_count == 1 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py deleted file mode 100644 index dc0998caf18..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Tests for the EventManager.""" - -from __future__ import annotations - -import logging - -from graphon.graph_engine.event_management.event_manager import EventManager -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent - - -class _FaultyLayer(GraphEngineLayer): - """Layer that raises from on_event to test error handling.""" - - def on_graph_start(self) -> None: # pragma: no cover - not used in tests - pass - - def on_event(self, event: GraphEngineEvent) -> None: - raise RuntimeError("boom") - - def on_graph_end(self, error: Exception | None) -> None: # pragma: no cover - not used in tests - pass - - -def test_event_manager_logs_layer_errors(caplog) -> None: - """Ensure errors raised by layers are logged when collecting events.""" - - event_manager = EventManager() - event_manager.set_layers([_FaultyLayer()]) - - with caplog.at_level(logging.ERROR): - event_manager.collect(GraphEngineEvent()) - - error_logs = [record 
for record in caplog.records if "Error in layer on_event" in record.getMessage()] - assert error_logs, "Expected layer errors to be logged" - - log_record = error_logs[0] - assert log_record.exc_info is not None - assert isinstance(log_record.exc_info[1], RuntimeError) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py deleted file mode 100644 index cf8811dc2b5..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for graph traversal components.""" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py deleted file mode 100644 index b030496eb16..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py +++ /dev/null @@ -1,307 +0,0 @@ -"""Unit tests for skip propagator.""" - -from unittest.mock import MagicMock, create_autospec - -from graphon.graph import Edge, Graph -from graphon.graph_engine.graph_state_manager import GraphStateManager -from graphon.graph_engine.graph_traversal.skip_propagator import SkipPropagator - - -class TestSkipPropagator: - """Test suite for SkipPropagator.""" - - def test_propagate_skip_from_edge_with_unknown_edges_stops_processing(self) -> None: - """When there are unknown incoming edges, propagation should stop.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - # Setup graph edges dict - mock_graph.edges = {"edge_1": mock_edge} - - # Setup incoming edges - incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return has_unknown=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": True, - "has_taken": False, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - mock_graph.get_incoming_edges.assert_called_once_with("node_2") - mock_state_manager.analyze_edge_states.assert_called_once_with(incoming_edges) - # Should not call any other state manager methods - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.start_execution.assert_not_called() - mock_state_manager.mark_node_skipped.assert_not_called() - - def test_propagate_skip_from_edge_with_taken_edge_enqueues_node(self) -> None: - """When there is at least one taken edge, node should be enqueued.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return has_taken=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": True, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - 
mock_state_manager.enqueue_node.assert_called_once_with("node_2") - mock_state_manager.start_execution.assert_called_once_with("node_2") - mock_state_manager.mark_node_skipped.assert_not_called() - - def test_propagate_skip_from_edge_with_all_skipped_propagates_to_node(self) -> None: - """When all incoming edges are skipped, should propagate skip to node.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return all_skipped=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.start_execution.assert_not_called() - - def test_propagate_skip_to_node_marks_node_and_outgoing_edges_skipped(self) -> None: - """_propagate_skip_to_node should mark node and all outgoing edges as skipped.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create outgoing edges - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_2" - edge1.head = "node_downstream_1" # Set head for propagate_skip_from_edge - - edge2 = MagicMock(spec=Edge) - edge2.id = "edge_3" - edge2.head = "node_downstream_2" - - # Setup graph edges dict for propagate_skip_from_edge - mock_graph.edges = {"edge_2": edge1, "edge_3": edge2} - mock_graph.get_outgoing_edges.return_value = [edge1, edge2] - - # Setup get_incoming_edges to return empty list to stop recursion - mock_graph.get_incoming_edges.return_value = [] - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Use mock to call private method - # Act - propagator._propagate_skip_to_node("node_1") - - # Assert - mock_state_manager.mark_node_skipped.assert_called_once_with("node_1") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") - assert mock_state_manager.mark_edge_skipped.call_count == 2 - # Should recursively propagate from each edge - # Since propagate_skip_from_edge is called, we need to verify it was called - # But we can't directly verify due to recursion. We'll trust the logic. 
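-        # (A hypothetical stricter variant could stub the recursion instead:
-        #     with mock.patch.object(propagator, "propagate_skip_from_edge") as spy:
-        #         propagator._propagate_skip_to_node("node_1")
-        #     assert spy.call_count == 2
-        #  at the cost of no longer exercising the real propagation path.)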
- - def test_skip_branch_paths_marks_unselected_edges_and_propagates(self) -> None: - """skip_branch_paths should mark all unselected edges as skipped and propagate.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create unselected edges - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_1" - edge1.head = "node_downstream_1" - - edge2 = MagicMock(spec=Edge) - edge2.id = "edge_2" - edge2.head = "node_downstream_2" - - unselected_edges = [edge1, edge2] - - # Setup graph edges dict - mock_graph.edges = {"edge_1": edge1, "edge_2": edge2} - # Setup get_incoming_edges to return empty list to stop recursion - mock_graph.get_incoming_edges.return_value = [] - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.skip_branch_paths(unselected_edges) - - # Assert - mock_state_manager.mark_edge_skipped.assert_any_call("edge_1") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") - assert mock_state_manager.mark_edge_skipped.call_count == 2 - # propagate_skip_from_edge should be called for each edge - # We can't directly verify due to the mock, but the logic is covered - - def test_propagate_skip_from_edge_recursively_propagates_through_graph(self) -> None: - """Skip propagation should recursively propagate through the graph.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create edge chain: edge_1 -> node_2 -> edge_3 -> node_4 - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_1" - edge1.head = "node_2" - - edge3 = MagicMock(spec=Edge) - edge3.id = "edge_3" - edge3.head = "node_4" - - mock_graph.edges = {"edge_1": edge1, "edge_3": edge3} - - # Setup get_incoming_edges to return different values based on node - def get_incoming_edges_side_effect(node_id): - if node_id == "node_2": - return [edge1] - elif node_id == "node_4": - return [edge3] - return [] - - mock_graph.get_incoming_edges.side_effect = get_incoming_edges_side_effect - - # Setup get_outgoing_edges to return different values based on node - def get_outgoing_edges_side_effect(node_id): - if node_id == "node_2": - return [edge3] - elif node_id == "node_4": - return [] # No outgoing edges, stops recursion - return [] - - mock_graph.get_outgoing_edges.side_effect = get_outgoing_edges_side_effect - - # Setup state manager to return all_skipped for both nodes - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - # Should mark node_2 as skipped - mock_state_manager.mark_node_skipped.assert_any_call("node_2") - # Should mark edge_3 as skipped - mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") - # Should propagate to node_4 - mock_state_manager.mark_node_skipped.assert_any_call("node_4") - assert mock_state_manager.mark_node_skipped.call_count == 2 - - def test_propagate_skip_from_edge_with_mixed_edge_states_handles_correctly(self) -> None: - """Test with mixed edge states (some unknown, some taken, some skipped).""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge), MagicMock(spec=Edge)] - 
mock_graph.get_incoming_edges.return_value = incoming_edges - - # Test 1: has_unknown=True, has_taken=False, all_skipped=False - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": True, - "has_taken": False, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should stop processing - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.mark_node_skipped.assert_not_called() - - # Reset mocks for next test - mock_state_manager.reset_mock() - mock_graph.reset_mock() - - # Test 2: has_unknown=False, has_taken=True, all_skipped=False - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": True, - "all_skipped": False, - } - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should enqueue node - mock_state_manager.enqueue_node.assert_called_once_with("node_2") - mock_state_manager.start_execution.assert_called_once_with("node_2") - - # Reset mocks for next test - mock_state_manager.reset_mock() - mock_graph.reset_mock() - - # Test 3: has_unknown=False, has_taken=False, all_skipped=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should propagate skip - mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py deleted file mode 100644 index 2fead1d7196..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Utilities for testing HumanInputNode without database dependencies.""" - -from __future__ import annotations - -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRecipientEntity, - HumanInputFormRepository, -) -from graphon.nodes.human_input.enums import HumanInputFormStatus -from libs.datetime_utils import naive_utc_now - - -class _InMemoryFormRecipient(HumanInputFormRecipientEntity): - """Minimal recipient entity required by the repository interface.""" - - def __init__(self, recipient_id: str, token: str) -> None: - self._id = recipient_id - self._token = token - - @property - def id(self) -> str: - return self._id - - @property - def token(self) -> str: - return self._token - - -@dataclass -class _InMemoryFormEntity(HumanInputFormEntity): - form_id: str - rendered: str - token: str | None = None - action_id: str | None = None - data: Mapping[str, Any] | None = None - is_submitted: bool = False - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() - - @property - def id(self) -> str: - return self.form_id - - @property - def submission_token(self) -> str | None: - return self.token - - @property - def recipients(self) -> list[HumanInputFormRecipientEntity]: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - 
@property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class InMemoryHumanInputFormRepository(HumanInputFormRepository): - """Pure in-memory repository used by workflow graph engine tests.""" - - def __init__(self) -> None: - self._form_counter = 0 - self.created_params: list[FormCreateParams] = [] - self.created_forms: list[_InMemoryFormEntity] = [] - self._forms_by_node_id: dict[str, _InMemoryFormEntity] = {} - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - self.created_params.append(params) - self._form_counter += 1 - form_id = f"form-{self._form_counter}" - token = f"token-{form_id}" - entity = _InMemoryFormEntity( - form_id=form_id, - rendered=params.rendered_content, - token=token, - ) - self.created_forms.append(entity) - self._forms_by_node_id[params.node_id] = entity - return entity - - def get_form(self, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_node_id.get(node_id) - - # Convenience helpers for tests ------------------------------------- - - def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: - """Simulate a human submission for the next repository lookup.""" - - if not self.created_forms: - raise AssertionError("no form has been created to attach submission data") - entity = self.created_forms[-1] - entity.action_id = action_id - entity.data = form_data or {} - entity.is_submitted = True - entity.status_value = HumanInputFormStatus.SUBMITTED - entity.expiration = naive_utc_now() + timedelta(days=1) - - def clear_submission(self) -> None: - if not self.created_forms: - return - for form in self.created_forms: - form.action_id = None - form.data = None - form.is_submitted = False - form.status_value = HumanInputFormStatus.WAITING diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index b642dc82fe5..41627f5e0be 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -5,13 +5,12 @@ Shared fixtures for ObservabilityLayer tests. 
from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider -from graphon.enums import BuiltinNodeTypes - @pytest.fixture def memory_span_exporter(): @@ -62,9 +61,10 @@ def mock_llm_node(): @pytest.fixture def mock_tool_node(): """Create a mock Tool Node with tool-specific attributes.""" - from core.tools.entities.tool_entities import ToolProviderType from graphon.nodes.tool.entities import ToolNodeData + from core.tools.entities.tool_entities import ToolProviderType + node = MagicMock() node.id = "test-tool-node-id" node.title = "Test Tool Node" @@ -117,8 +117,8 @@ def mock_result_event(): """Create a mock result event with NodeRunResult.""" from datetime import datetime - from graphon.graph_events.node import NodeRunSucceededEvent - from graphon.node_events.base import NodeRunResult + from graphon.graph_events import NodeRunSucceededEvent + from graphon.node_events import NodeRunResult node_run_result = NodeRunResult( inputs={"query": "test query"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py deleted file mode 100644 index 7ff77c19c18..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations - -import pytest - -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_engine.layers.base import ( - GraphEngineLayer, - GraphEngineLayerNotInitializedError, -) -from graphon.graph_events import GraphEngineEvent - -from ..test_table_runner import WorkflowRunner - - -class LayerForTest(GraphEngineLayer): - def on_graph_start(self) -> None: - pass - - def on_event(self, event: GraphEngineEvent) -> None: - pass - - def on_graph_end(self, error: Exception | None) -> None: - pass - - -def test_layer_runtime_state_raises_when_uninitialized() -> None: - layer = LayerForTest() - - with pytest.raises(GraphEngineLayerNotInitializedError): - _ = layer.graph_runtime_state - - -def test_layer_runtime_state_available_after_engine_layer() -> None: - runner = WorkflowRunner() - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture( - fixture_data, - inputs={"query": "test layer state"}, - ) - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - layer = LayerForTest() - engine.layer(layer) - - outputs = layer.graph_runtime_state.outputs - ready_queue_size = layer.graph_runtime_state.ready_queue_size - - assert outputs == {} - assert isinstance(ready_queue_size, int) - assert ready_queue_size >= 0 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 80874e768ae..99d131737e9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -3,15 +3,16 @@ from datetime 
import datetime from types import SimpleNamespace from unittest.mock import MagicMock, patch +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.entities.commands import CommandType +from graphon.graph_events import NodeRunSucceededEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeRunResult + from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.graph_engine.entities.commands import CommandType -from graphon.graph_events.node import NodeRunSucceededEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import NodeRunResult def _build_dify_context() -> DifyRunContext: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index 14ce55938d5..9cf72763ee2 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -13,10 +13,10 @@ Test coverage: from unittest.mock import patch import pytest +from graphon.enums import BuiltinNodeTypes from opentelemetry.trace import StatusCode from core.app.workflow.layers.observability import ObservabilityLayer -from graphon.enums import BuiltinNodeTypes class TestObservabilityLayerInitialization: @@ -144,7 +144,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event ): """Test that LLM parser is used for LLM nodes and extracts LLM-specific attributes.""" - from graphon.node_events.base import NodeRunResult + from graphon.node_events import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={}, @@ -182,7 +182,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event ): """Test that retrieval parser is used for retrieval nodes and extracts retrieval-specific attributes.""" - from graphon.node_events.base import NodeRunResult + from graphon.node_events import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"query": "test query"}, @@ -210,7 +210,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event ): """Test that result_event parameter allows parsers to extract inputs and outputs.""" - from graphon.node_events.base import NodeRunResult + from graphon.node_events import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"input_key": "input_value"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py deleted file mode 100644 index ab3a31f673a..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ /dev/null @@ -1,189 +0,0 @@ -"""Tests for dispatcher command checking behavior.""" - -from __future__ import annotations - -import queue -from unittest import mock - -from graphon.entities.pause_reason 
import SchedulingPause -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.graph_engine.event_management.event_handlers import EventHandler -from graphon.graph_engine.orchestration.dispatcher import Dispatcher -from graphon.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from graphon.node_events import NodeRunResult -from libs.datetime_utils import naive_utc_now - - -def test_dispatcher_should_consume_remains_events_after_pause(): - event_queue = queue.Queue() - event_queue.put( - GraphNodeEventBase( - id="test", - node_id="test", - node_type=BuiltinNodeTypes.START, - ) - ) - event_handler = mock.Mock(spec=EventHandler) - execution_coordinator = mock.Mock(spec=ExecutionCoordinator) - execution_coordinator.paused.return_value = True - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=event_handler, - execution_coordinator=execution_coordinator, - ) - dispatcher._dispatcher_loop() - assert event_queue.empty() - - -class _StubExecutionCoordinator: - """Stub execution coordinator that tracks command checks.""" - - def __init__(self) -> None: - self.command_checks = 0 - self.scaling_checks = 0 - self.execution_complete = False - self.failed = False - self._paused = False - - def process_commands(self) -> None: - self.command_checks += 1 - - def check_scaling(self) -> None: - self.scaling_checks += 1 - - @property - def paused(self) -> bool: - return self._paused - - @property - def aborted(self) -> bool: - return False - - def mark_complete(self) -> None: - self.execution_complete = True - - def mark_failed(self, error: Exception) -> None: # pragma: no cover - defensive, not triggered in tests - self.failed = True - - -class _StubEventHandler: - """Minimal event handler that marks execution complete after handling an event.""" - - def __init__(self, coordinator: _StubExecutionCoordinator) -> None: - self._coordinator = coordinator - self.events = [] - - def dispatch(self, event) -> None: - self.events.append(event) - self._coordinator.mark_complete() - - -def _run_dispatcher_for_event(event) -> int: - """Run the dispatcher loop for a single event and return command check count.""" - event_queue: queue.Queue = queue.Queue() - event_queue.put(event) - - coordinator = _StubExecutionCoordinator() - event_handler = _StubEventHandler(coordinator) - - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=event_handler, - execution_coordinator=coordinator, - ) - - dispatcher._dispatcher_loop() - - return coordinator.command_checks - - -def _make_started_event() -> NodeRunStartedEvent: - return NodeRunStartedEvent( - id="start-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - node_title="Test Node", - start_at=naive_utc_now(), - ) - - -def _make_succeeded_event() -> NodeRunSucceededEvent: - return NodeRunSucceededEvent( - id="success-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - node_title="Test Node", - start_at=naive_utc_now(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ) - - -def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None: - """Dispatcher polls commands when idle and after completion events.""" - started_checks = _run_dispatcher_for_event(_make_started_event()) - succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event()) - - assert started_checks == 2 - 
assert succeeded_checks == 3 - - -class _PauseStubEventHandler: - """Minimal event handler that marks execution complete after handling an event.""" - - def __init__(self, coordinator: _StubExecutionCoordinator) -> None: - self._coordinator = coordinator - self.events = [] - - def dispatch(self, event) -> None: - self.events.append(event) - if isinstance(event, NodeRunPauseRequestedEvent): - self._coordinator.mark_complete() - - -def test_dispatcher_drain_event_queue(): - events = [ - NodeRunStartedEvent( - id="start-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - node_title="Code", - start_at=naive_utc_now(), - ), - NodeRunPauseRequestedEvent( - id="pause-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - reason=SchedulingPause(message="test pause"), - ), - NodeRunSucceededEvent( - id="success-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - start_at=naive_utc_now(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ), - ] - - event_queue: queue.Queue = queue.Queue() - for e in events: - event_queue.put(e) - - coordinator = _StubExecutionCoordinator() - event_handler = _PauseStubEventHandler(coordinator) - - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=event_handler, - execution_coordinator=coordinator, - ) - - dispatcher._dispatcher_loop() - - # ensure all events are drained. - assert event_queue.empty() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py deleted file mode 100644 index 1510c8e595c..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py +++ /dev/null @@ -1,37 +0,0 @@ -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_answer_end_with_text(): - fixture_name = "answer_end_with_text" - case = WorkflowTestCase( - fixture_name, - query="Hello, AI!", - expected_outputs={"answer": "prefixHello, AI!suffix"}, - expected_event_sequence=[ - GraphRunStartedEvent, - # Start - NodeRunStartedEvent, - # The chunks are now emitted as the Answer node processes them - # since sys.query is a special selector that gets attributed to - # the active response node - NodeRunStreamChunkEvent, # prefix - NodeRunStreamChunkEvent, # sys.query - NodeRunStreamChunkEvent, # suffix - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py deleted file mode 100644 index 6569439b568..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py +++ /dev/null @@ -1,28 +0,0 @@ -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - -LLM_NODE_ID = "1759052580454" - - -def test_answer_nodes_emit_in_order() -> None: - mock_config = ( - MockConfigBuilder() - .with_llm_response("unused default") - .with_node_output(LLM_NODE_ID, {"text": "mocked llm text"}) - .build() - ) - - expected_answer = "--- answer 1 ---\n\nfoo\n--- answer 
2 ---\n\nmocked llm text\n" - - case = WorkflowTestCase( - fixture_path="test-answer-order", - query="", - expected_outputs={"answer": expected_answer}, - use_auto_mock=True, - mock_config=mock_config, - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - - assert result.success, result.error diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py deleted file mode 100644 index 05ec565def6..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py +++ /dev/null @@ -1,24 +0,0 @@ -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_array_iteration_formatting_workflow(): - """ - Validate Iteration node processes [1,2,3] into formatted strings. - - Fixture description expects: - {"output": ["output: 1", "output: 2", "output: 3"]} - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="array_iteration_formatting_workflow", - inputs={}, - expected_outputs={"output": ["output: 1", "output: 2", "output: 3"]}, - description="Iteration formats numbers into strings", - use_auto_mock=True, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Iteration workflow failed: {result.error}" - assert result.actual_outputs == test_case.expected_outputs diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py deleted file mode 100644 index 5d0b37acc5f..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py +++ /dev/null @@ -1,392 +0,0 @@ -""" -Tests for the auto-mock system. - -This module contains tests that validate the auto-mock functionality -for workflows containing nodes that require third-party services. 
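-
-The mock configuration helpers and table-driven runner exercised here are
-the sibling modules test_mock_config.py and test_table_runner.py.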
-""" - -import pytest - -from graphon.enums import BuiltinNodeTypes -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_simple_llm_workflow_with_auto_mock(): - """Test that a simple LLM workflow runs successfully with auto-mocking.""" - runner = TableTestRunner() - - # Create mock configuration - mock_config = MockConfigBuilder().with_llm_response("This is a test response from mocked LLM").build() - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Hello, how are you?"}, - expected_outputs={"answer": "This is a test response from mocked LLM"}, - description="Simple LLM workflow with auto-mock", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs is not None - assert "answer" in result.actual_outputs - assert result.actual_outputs["answer"] == "This is a test response from mocked LLM" - - -def test_llm_workflow_with_custom_node_output(): - """Test LLM workflow with custom output for specific node.""" - runner = TableTestRunner() - - # Create mock configuration with custom output for specific node - mock_config = MockConfig() - mock_config.set_node_outputs( - "llm_node", - { - "text": "Custom response for this specific node", - "usage": { - "prompt_tokens": 20, - "completion_tokens": 10, - "total_tokens": 30, - }, - "finish_reason": "stop", - }, - ) - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Test query"}, - expected_outputs={"answer": "Custom response for this specific node"}, - description="LLM workflow with custom node output", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs is not None - assert result.actual_outputs["answer"] == "Custom response for this specific node" - - -def test_http_tool_workflow_with_auto_mock(): - """Test workflow with HTTP request and tool nodes using auto-mock.""" - runner = TableTestRunner() - - # Create mock configuration - mock_config = MockConfig() - mock_config.set_node_outputs( - "http_node", - { - "status_code": 200, - "body": '{"key": "value", "number": 42}', - "headers": {"content-type": "application/json"}, - }, - ) - mock_config.set_node_outputs( - "tool_node", - { - "result": {"key": "value", "number": 42}, - }, - ) - - test_case = WorkflowTestCase( - fixture_path="http_request_with_json_tool_workflow", - inputs={"url": "https://api.example.com/data"}, - expected_outputs={ - "status_code": 200, - "parsed_data": {"key": "value", "number": 42}, - }, - description="HTTP and Tool workflow with auto-mock", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs is not None - assert result.actual_outputs["status_code"] == 200 - assert result.actual_outputs["parsed_data"] == {"key": "value", "number": 42} - - -def test_workflow_with_simulated_node_error(): - """Test that workflows handle simulated node errors correctly.""" - runner = TableTestRunner() - - # Create mock configuration with error - mock_config = MockConfig() - mock_config.set_node_error("llm_node", "Simulated LLM API 
error") - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "This should fail"}, - expected_outputs={}, # We expect failure, so no outputs - description="LLM workflow with simulated error", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - # The workflow should fail due to the simulated error - assert not result.success - assert result.error is not None - - -def test_workflow_with_mock_delays(): - """Test that mock delays work correctly.""" - runner = TableTestRunner() - - # Create mock configuration with delays - mock_config = MockConfig(simulate_delays=True) - node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response after delay"}, - delay=0.1, # 100ms delay - ) - mock_config.set_node_config("llm_node", node_config) - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Test with delay"}, - expected_outputs={"answer": "Response after delay"}, - description="LLM workflow with simulated delay", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - # Execution time should be at least the delay - assert result.execution_time >= 0.1 - - -def test_mock_config_builder(): - """Test the MockConfigBuilder fluent interface.""" - config = ( - MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"tool": "output"}) - .with_retrieval_response("Retrieval content") - .with_http_response({"status_code": 201, "body": "created"}) - .with_node_output("node1", {"output": "value"}) - .with_node_error("node2", "error message") - .with_delays(True) - .build() - ) - - assert config.default_llm_response == "LLM response" - assert config.default_agent_response == "Agent response" - assert config.default_tool_response == {"tool": "output"} - assert config.default_retrieval_response == "Retrieval content" - assert config.default_http_response == {"status_code": 201, "body": "created"} - assert config.simulate_delays is True - - node1_config = config.get_node_config("node1") - assert node1_config is not None - assert node1_config.outputs == {"output": "value"} - - node2_config = config.get_node_config("node2") - assert node2_config is not None - assert node2_config.error == "error message" - - -def test_mock_factory_node_type_detection(): - """Test that MockNodeFactory correctly identifies nodes to mock.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.runtime import GraphRuntimeState, VariablePool - - from .test_mock_factory import MockNodeFactory - - graph_init_params = build_test_graph_init_params( - workflow_id="test", - graph_config={}, - tenant_id="test", - app_id="test", - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Test that third-party service nodes are identified for mocking - assert factory.should_mock_node(BuiltinNodeTypes.LLM) - assert factory.should_mock_node(BuiltinNodeTypes.AGENT) - assert factory.should_mock_node(BuiltinNodeTypes.TOOL) - assert 
factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL) - assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST) - assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR) - assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR) - - # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Test that non-service nodes are not mocked - assert not factory.should_mock_node(BuiltinNodeTypes.START) - assert not factory.should_mock_node(BuiltinNodeTypes.END) - assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE) - assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR) - - -def test_custom_mock_handler(): - """Test using a custom handler function for mock outputs.""" - runner = TableTestRunner() - - # Custom handler that modifies output based on input - def custom_llm_handler(node) -> dict: - # In a real scenario, we could access node.graph_runtime_state.variable_pool - # to get the actual inputs - return { - "text": "Custom handler response", - "usage": { - "prompt_tokens": 5, - "completion_tokens": 3, - "total_tokens": 8, - }, - "finish_reason": "stop", - } - - mock_config = MockConfig() - node_config = NodeMockConfig( - node_id="llm_node", - custom_handler=custom_llm_handler, - ) - mock_config.set_node_config("llm_node", node_config) - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Test custom handler"}, - expected_outputs={"answer": "Custom handler response"}, - description="LLM workflow with custom handler", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs["answer"] == "Custom handler response" - - -def test_workflow_without_auto_mock(): - """Test that workflows work normally without auto-mock enabled.""" - runner = TableTestRunner() - - # This test uses the echo workflow which doesn't need external services - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "Test without mock"}, - expected_outputs={"query": "Test without mock"}, - description="Echo workflow without auto-mock", - use_auto_mock=False, # Auto-mock disabled - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs["query"] == "Test without mock" - - -def test_register_custom_mock_node(): - """Test registering a custom mock implementation for a node type.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.nodes.template_transform import TemplateTransformNode - from graphon.runtime import GraphRuntimeState, VariablePool - - from .test_mock_factory import MockNodeFactory - - # Create a custom mock for TemplateTransformNode - class MockTemplateTransformNode(TemplateTransformNode): - def _run(self): - # Custom mock implementation - pass - - graph_init_params = build_test_graph_init_params( - workflow_id="test", - graph_config={}, - tenant_id="test", - app_id="test", - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = 
MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Unregister mock - factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Re-register custom mock - factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, MockTemplateTransformNode) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - -def test_default_config_by_node_type(): - """Test setting default configurations by node type.""" - mock_config = MockConfig() - - # Set default config for all LLM nodes - mock_config.set_default_config( - BuiltinNodeTypes.LLM, - { - "default_response": "Default LLM response for all nodes", - "temperature": 0.7, - }, - ) - - # Set default config for all HTTP nodes - mock_config.set_default_config( - BuiltinNodeTypes.HTTP_REQUEST, - { - "default_status": 200, - "default_timeout": 30, - }, - ) - - llm_config = mock_config.get_default_config(BuiltinNodeTypes.LLM) - assert llm_config["default_response"] == "Default LLM response for all nodes" - assert llm_config["temperature"] == 0.7 - - http_config = mock_config.get_default_config(BuiltinNodeTypes.HTTP_REQUEST) - assert http_config["default_status"] == 200 - assert http_config["default_timeout"] == 30 - - # Non-configured node type should return empty dict - tool_config = mock_config.get_default_config(BuiltinNodeTypes.TOOL) - assert tool_config == {} - - -if __name__ == "__main__": - # Run all tests - pytest.main([__file__, "-v"]) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py deleted file mode 100644 index cefe3b8ac8a..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py +++ /dev/null @@ -1,41 +0,0 @@ -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_basic_chatflow(): - fixture_name = "basic_chatflow" - mock_config = MockConfigBuilder().with_llm_response("mocked llm response").build() - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - expected_outputs={"answer": "mocked llm response"}, - expected_event_sequence=[ - GraphRunStartedEvent, - # START - NodeRunStartedEvent, - NodeRunSucceededEvent, - # LLM - NodeRunStartedEvent, - ] - + [NodeRunStreamChunkEvent] * ("mocked llm response".count(" ") + 2) - + [ - NodeRunSucceededEvent, - # ANSWER - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py deleted file mode 100644 index 01ac2d7a968..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ /dev/null @@ -1,266 +0,0 @@ -"""Test the command system for GraphEngine control.""" - -import time -from unittest.mock import MagicMock - -from 
core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from graphon.entities.graph_init_params import GraphInitParams -from graphon.entities.pause_reason import SchedulingPause -from graphon.graph import Graph -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_engine.entities.commands import ( - AbortCommand, - CommandType, - PauseCommand, - UpdateVariablesCommand, - VariableUpdate, -) -from graphon.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import IntegerVariable, StringVariable - - -def test_abort_command(): - """Test that GraphEngine properly handles abort commands.""" - - # Create shared GraphRuntimeState - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - - # Create a minimal mock graph - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - # Create mock nodes with required attributes - using shared runtime state - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - # Mock graph methods - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - # Create command channel - command_channel = InMemoryChannel() - - # Create GraphEngine with same shared runtime state - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, # Use shared instance - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - # Queue an abort request before starting. 
- engine.request_abort("Test abort") - - # Run engine and collect events - events = list(engine.run()) - - # Verify we get start and abort events - assert any(isinstance(e, GraphRunStartedEvent) for e in events) - assert any(isinstance(e, GraphRunAbortedEvent) for e in events) - - # Find the abort event and check its reason - abort_events = [e for e in events if isinstance(e, GraphRunAbortedEvent)] - assert len(abort_events) == 1 - assert abort_events[0].reason is not None - assert "aborted: test abort" in abort_events[0].reason.lower() - - -def test_redis_channel_serialization(): - """Test that Redis channel properly serializes and deserializes commands.""" - import json - from unittest.mock import MagicMock - - # Mock redis client - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) - - from graphon.graph_engine.command_channels.redis_channel import RedisChannel - - # Create channel with a specific key - channel = RedisChannel(mock_redis, channel_key="workflow:123:commands") - - # Test sending a command - abort_command = AbortCommand(reason="Test abort") - channel.send_command(abort_command) - - # Verify redis methods were called - mock_pipeline.rpush.assert_called_once() - mock_pipeline.expire.assert_called_once() - - # Verify the serialized data - call_args = mock_pipeline.rpush.call_args - key = call_args[0][0] - command_json = call_args[0][1] - - assert key == "workflow:123:commands" - - # Verify JSON structure - command_data = json.loads(command_json) - assert command_data["command_type"] == "abort" - assert command_data["reason"] == "Test abort" - - # Test pause command serialization - pause_command = PauseCommand(reason="User requested pause") - channel.send_command(pause_command) - - assert len(mock_pipeline.rpush.call_args_list) == 2 - second_call_args = mock_pipeline.rpush.call_args_list[1] - pause_command_json = second_call_args[0][1] - pause_command_data = json.loads(pause_command_json) - assert pause_command_data["command_type"] == CommandType.PAUSE.value - assert pause_command_data["reason"] == "User requested pause" - - -def test_pause_command(): - """Test that GraphEngine properly handles pause commands.""" - - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - command_channel = InMemoryChannel() - - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - pause_command = PauseCommand(reason="User requested pause") - 
command_channel.send_command(pause_command) - - events = list(engine.run()) - - assert any(isinstance(e, GraphRunStartedEvent) for e in events) - pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)] - assert len(pause_events) == 1 - assert pause_events[0].reasons == [SchedulingPause(message="User requested pause")] - - graph_execution = engine.graph_runtime_state.graph_execution - assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")] - - -def test_update_variables_command_updates_pool(): - """Test that GraphEngine updates variable pool via update variables command.""" - - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - shared_runtime_state.variable_pool.add(("node1", "foo"), "old value") - - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - command_channel = InMemoryChannel() - - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - update_command = UpdateVariablesCommand( - updates=[ - VariableUpdate( - value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]), - ), - VariableUpdate( - value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]), - ), - ] - ) - command_channel.send_command(update_command) - - list(engine.run()) - - updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"]) - added_new = shared_runtime_state.variable_pool.get(["node2", "bar"]) - - assert updated_existing is not None - assert updated_existing.value == "new value" - assert added_new is not None - assert added_new.value == 123 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py deleted file mode 100644 index ba9c5024528..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Test suite for complex branch workflow with parallel execution and conditional routing. - -This test suite validates the behavior of a workflow that: -1. Executes nodes in parallel (IF/ELSE and LLM branches) -2. Routes based on conditional logic (query containing 'hello') -3. 
Handles multiple answer nodes with different outputs -""" - -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -class TestComplexBranchWorkflow: - """Test suite for complex branch workflow with parallel execution.""" - - def setup_method(self): - """Set up test environment before each test method.""" - self.runner = TableTestRunner() - self.fixture_path = "test_complex_branch" - - def test_hello_branch_with_llm(self): - """ - Test when query contains 'hello' - should trigger true branch. - Both IF/ELSE and LLM should execute in parallel. - """ - mock_text_1 = "This is a mocked LLM response for hello world" - test_cases = [ - WorkflowTestCase( - fixture_path=self.fixture_path, - query="hello world", - expected_outputs={ - "answer": f"contains 'hello'{mock_text_1}", - }, - description="Basic hello case with parallel LLM execution", - use_auto_mock=True, - mock_config=(MockConfigBuilder().with_node_output("1755502777322", {"text": mock_text_1}).build()), - ), - WorkflowTestCase( - fixture_path=self.fixture_path, - query="say hello to everyone", - expected_outputs={ - "answer": "contains 'hello'Mocked response for greeting", - }, - description="Hello in middle of sentence", - use_auto_mock=True, - mock_config=( - MockConfigBuilder() - .with_node_output("1755502777322", {"text": "Mocked response for greeting"}) - .build() - ), - ), - ] - - suite_result = self.runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test '{result.test_case.description}' failed: {result.error}" - assert result.actual_outputs - assert any(isinstance(event, GraphRunStartedEvent) for event in result.events) - assert any(isinstance(event, GraphRunSucceededEvent) for event in result.events) - - start_index = next( - idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunStartedEvent) - ) - success_index = max( - idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunSucceededEvent) - ) - assert start_index < success_index - - started_node_ids = {event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)} - assert {"1755502773326", "1755502777322"}.issubset(started_node_ids), ( - f"Branch or LLM nodes missing in events: {started_node_ids}" - ) - - assert any(isinstance(event, NodeRunStreamChunkEvent) for event in result.events), ( - "Expected streaming chunks from LLM execution" - ) - - llm_start_index = next( - idx - for idx, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "1755502777322" - ) - assert any( - idx > llm_start_index and isinstance(event, NodeRunStreamChunkEvent) - for idx, event in enumerate(result.events) - ), "Streaming chunks should follow LLM node start" - - def test_non_hello_branch_with_llm(self): - """ - Test when query doesn't contain 'hello' - should trigger false branch. - LLM output should be used as the final answer. 
- """ - test_cases = [ - WorkflowTestCase( - fixture_path=self.fixture_path, - query="goodbye world", - expected_outputs={ - "answer": "Mocked LLM response for goodbye", - }, - description="Goodbye case - false branch with LLM output", - use_auto_mock=True, - mock_config=( - MockConfigBuilder() - .with_node_output("1755502777322", {"text": "Mocked LLM response for goodbye"}) - .build() - ), - ), - WorkflowTestCase( - fixture_path=self.fixture_path, - query="test message", - expected_outputs={ - "answer": "Mocked response for test", - }, - description="Regular message - false branch", - use_auto_mock=True, - mock_config=( - MockConfigBuilder().with_node_output("1755502777322", {"text": "Mocked response for test"}).build() - ), - ), - ] - - suite_result = self.runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test '{result.test_case.description}' failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py deleted file mode 100644 index 38514807310..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ /dev/null @@ -1,220 +0,0 @@ -""" -Test for streaming output workflow behavior. - -This test validates that: -- When blocking == 1: No NodeRunStreamChunkEvent (flow through Template node) -- When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output) -""" - -from graphon.enums import BuiltinNodeTypes -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_events import ( - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner - - -def test_streaming_output_with_blocking_equals_one(): - """ - Test workflow when blocking == 1 (LLM โ†’ Template โ†’ End). - - Template node doesn't produce streaming output, so no NodeRunStreamChunkEvent should be present. - This test should FAIL according to requirements. - """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("conditional_streaming_vs_template_workflow") - - # Create graph from fixture with auto-mock enabled - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - inputs={"query": "Hello, how are you?", "blocking": 1}, - use_mock_factory=True, - ) - - # Create and run the engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Execute the workflow - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Check for streaming events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - stream_chunk_count = len(stream_chunk_events) - - # According to requirements, we expect exactly 3 streaming events from the End node - # 1. User query - # 2. Newline - # 3. 
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py deleted file mode 100644 index 38514807310..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ /dev/null @@ -1,220 +0,0 @@ -""" -Test for streaming output workflow behavior. - -This test validates that: -- When blocking == 1: No token-level streaming from the LLM (its output flows through the Template node) -- When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output) -""" - -from graphon.enums import BuiltinNodeTypes -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_events import ( - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - ) - -from .test_table_runner import TableTestRunner - - -def test_streaming_output_with_blocking_equals_one(): - """ - Test workflow when blocking == 1 (LLM → Template → End). - - The Template node forwards the LLM response as a single chunk rather than streaming it token by token, so only the three chunks assembled for the final answer are expected. - """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("conditional_streaming_vs_template_workflow") - - # Create graph from fixture with auto-mock enabled - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - inputs={"query": "Hello, how are you?", "blocking": 1}, - use_mock_factory=True, - ) - - # Create and run the engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Execute the workflow - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Check for streaming events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - stream_chunk_count = len(stream_chunk_events) - - # According to requirements, we expect exactly 3 streaming events in total - # 1. User query - # 2. Newline - # 3. Template output (which contains the LLM response) - assert stream_chunk_count == 3, f"Expected 3 streaming events when blocking=1, but got {stream_chunk_count}" - - first_chunk, second_chunk, third_chunk = stream_chunk_events[0], stream_chunk_events[1], stream_chunk_events[2] - assert first_chunk.chunk == "Hello, how are you?", ( - f"Expected first chunk to be user input, but got {first_chunk.chunk}" - ) - assert second_chunk.chunk == "\n", f"Expected second chunk to be newline, but got {second_chunk.chunk}" - # Third chunk will be the template output with the mock LLM response - assert isinstance(third_chunk.chunk, str), f"Expected third chunk to be string, but got {type(third_chunk.chunk)}" - - # Find indices of first LLM success event and first stream chunk event - llm2_start_index = next( - ( - i - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ), - -1, - ) - first_chunk_index = next( - (i for i, e in enumerate(events) if isinstance(e, NodeRunStreamChunkEvent)), - -1, - ) - - assert first_chunk_index < llm2_start_index, ( - f"Expected first chunk before LLM2 start, but got {first_chunk_index} and {llm2_start_index}" - ) - - # Check that the NodeRunStreamChunkEvent carrying the query has the same id as the Start node's NodeRunStartedEvent - start_node_id = graph.root_node.id - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id] - assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}" - start_event = start_events[0] - query_chunk_events = [e for e in stream_chunk_events if e.chunk == "Hello, how are you?"] - assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have the same id" - - # Check that every Template NodeRunStreamChunkEvent has the same id as the Template's NodeRunStartedEvent - start_events = [ - e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM - ] - template_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM] - assert len(template_chunk_events) == 1, f"Expected 1 template chunk event, but got {len(template_chunk_events)}" - assert all(e.id in [se.id for se in start_events] for e in template_chunk_events), ( - "Expected all Template chunk events to share the id of the Template's NodeRunStartedEvent" - ) - - # Check that the NodeRunStreamChunkEvent carrying '\n' comes from the End node - end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.END] - assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}" - newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"] - assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}" - # The newline chunk should be from the End node (check node_id, not execution id) - assert all(e.node_id == end_events[0].node_id for e in newline_chunk_events), ( - "Expected all newline chunk events to be from End node" - ) - - -def test_streaming_output_with_blocking_not_equals_one(): - """ - Test workflow when blocking != 1 (LLM → End directly). - - End node should produce streaming output with NodeRunStreamChunkEvent. - This test should PASS according to requirements.
- """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("conditional_streaming_vs_template_workflow") - - # Create graph from fixture with auto-mock enabled - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - inputs={"query": "Hello, how are you?", "blocking": 2}, - use_mock_factory=True, - ) - - # Create and run the engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Execute the workflow - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Check for streaming events - expecting streaming events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - stream_chunk_count = len(stream_chunk_events) - - # This assertion should PASS according to requirements - assert stream_chunk_count > 0, f"Expected streaming events when blocking!=1, but got {stream_chunk_count}" - - # We should have at least 2 chunks (query and newline) - assert stream_chunk_count >= 2, f"Expected at least 2 streaming events, but got {stream_chunk_count}" - - first_chunk, second_chunk = stream_chunk_events[0], stream_chunk_events[1] - assert first_chunk.chunk == "Hello, how are you?", ( - f"Expected first chunk to be user input, but got {first_chunk.chunk}" - ) - assert second_chunk.chunk == "\n", f"Expected second chunk to be newline, but got {second_chunk.chunk}" - - # Find indices of first LLM success event and first stream chunk event - llm2_start_index = next( - ( - i - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ), - -1, - ) - first_chunk_index = next( - (i for i, e in enumerate(events) if isinstance(e, NodeRunStreamChunkEvent)), - -1, - ) - - assert first_chunk_index < llm2_start_index, ( - f"Expected first chunk before LLM2 start, but got {first_chunk_index} and {llm2_start_index}" - ) - - # With auto-mock, the LLM will produce mock responses - just verify we have streaming chunks - # and they are strings - for chunk_event in stream_chunk_events[2:]: - assert isinstance(chunk_event.chunk, str), f"Expected chunk to be string, but got {type(chunk_event.chunk)}" - - # Check that NodeRunStreamChunkEvent contains 'query' should has same id with Start NodeRunStartedEvent - start_node_id = graph.root_node.id - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id] - assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}" - start_event = start_events[0] - query_chunk_events = [e for e in stream_chunk_events if e.chunk == "Hello, how are you?"] - assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have same id" - - # Check all LLM's NodeRunStreamChunkEvent should be from LLM nodes - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.LLM] - llm_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.LLM] - llm_node_ids = {se.node_id for se in start_events} - assert all(e.node_id in llm_node_ids for e in llm_chunk_events), ( - "Expected all LLM chunk 
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py deleted file mode 100644 index ae7dd48bb16..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Utilities for detecting if database service is available for workflow tests. -""" - -import psycopg2 -import pytest - -from configs import dify_config - - -def is_database_available() -> bool: - """ - Check if the database service is available by attempting to connect to it. - - Returns: - True if database is available, False otherwise. - """ - try: - # Try to establish a database connection using a context manager - with psycopg2.connect( - host=dify_config.DB_HOST, - port=dify_config.DB_PORT, - database=dify_config.DB_DATABASE, - user=dify_config.DB_USERNAME, - password=dify_config.DB_PASSWORD, - connect_timeout=2, # 2 second timeout - ) as conn: - pass # Connection established and will be closed automatically - return True - except (psycopg2.OperationalError, psycopg2.Error): - return False - - -def skip_if_database_unavailable(): - """ - Pytest skip decorator that skips tests when database service is unavailable. - - Usage: - @skip_if_database_unavailable() - def test_my_workflow(): - ...
- """ - return pytest.mark.skipif( - not is_database_available(), - reason="Database service is not available (connection refused or authentication failed)", - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py deleted file mode 100644 index 3264ad1168e..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py +++ /dev/null @@ -1,72 +0,0 @@ -import queue - -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.graph_engine.orchestration.dispatcher import Dispatcher -from graphon.graph_events import NodeRunSucceededEvent -from graphon.node_events import NodeRunResult -from libs.datetime_utils import naive_utc_now - - -class StubExecutionCoordinator: - def __init__(self, paused: bool) -> None: - self._paused = paused - self.mark_complete_called = False - self.failed_error: Exception | None = None - - @property - def aborted(self) -> bool: - return False - - @property - def paused(self) -> bool: - return self._paused - - @property - def execution_complete(self) -> bool: - return False - - def check_scaling(self) -> None: - return None - - def process_commands(self) -> None: - return None - - def mark_complete(self) -> None: - self.mark_complete_called = True - - def mark_failed(self, error: Exception) -> None: - self.failed_error = error - - -class StubEventHandler: - def __init__(self) -> None: - self.events: list[object] = [] - - def dispatch(self, event: object) -> None: - self.events.append(event) - - -def test_dispatcher_drains_events_when_paused() -> None: - event_queue: queue.Queue = queue.Queue() - event = NodeRunSucceededEvent( - id="exec-1", - node_id="node-1", - node_type=BuiltinNodeTypes.START, - start_at=naive_utc_now(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ) - event_queue.put(event) - - handler = StubEventHandler() - coordinator = StubExecutionCoordinator(paused=True) - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=handler, - execution_coordinator=coordinator, - event_emitter=None, - ) - - dispatcher._dispatcher_loop() - - assert handler.events == [event] - assert coordinator.mark_complete_called is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py deleted file mode 100644 index ada55f3dc5d..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Test case for end node without value_type field (backward compatibility). - -This test validates that end nodes work correctly even when the value_type -field is missing from the output configuration, ensuring backward compatibility -with older workflow definitions. -""" - -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_end_node_without_value_type_field(): - """ - Test that end node works without explicit value_type field. - - The fixture implements a simple workflow that: - 1. Takes a query input from start node - 2. Passes it directly to end node - 3. End node outputs the value without specifying value_type - 4. 
Should correctly infer the type and output the value - - This ensures backward compatibility with workflow definitions - created before value_type became a required field. - """ - fixture_name = "end_node_without_value_type_field_workflow" - - case = WorkflowTestCase( - fixture_path=fixture_name, - inputs={"query": "test query"}, - expected_outputs={"query": "test query"}, - expected_event_sequence=[ - # Graph start - GraphRunStartedEvent, - # Start node - NodeRunStartedEvent, - NodeRunStreamChunkEvent, # Start node streams the input value - NodeRunSucceededEvent, - # End node - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Graph end - GraphRunSucceededEvent, - ], - description="End node without value_type field should work correctly", - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs == {"query": "test query"}, ( - f"Expected output to be {{'query': 'test query'}}, got {result.actual_outputs}" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py deleted file mode 100644 index 95a94110d2f..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Unit tests for the execution coordinator orchestration logic.""" - -from unittest.mock import MagicMock - -import pytest - -from graphon.graph_engine.command_processing.command_processor import CommandProcessor -from graphon.graph_engine.domain.graph_execution import GraphExecution -from graphon.graph_engine.graph_state_manager import GraphStateManager -from graphon.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from graphon.graph_engine.worker_management.worker_pool import WorkerPool - - -def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]: - command_processor = MagicMock(spec=CommandProcessor) - state_manager = MagicMock(spec=GraphStateManager) - worker_pool = MagicMock(spec=WorkerPool) - - coordinator = ExecutionCoordinator( - graph_execution=graph_execution, - state_manager=state_manager, - command_processor=command_processor, - worker_pool=worker_pool, - ) - return coordinator, state_manager, worker_pool - - -def test_handle_pause_stops_workers_and_clears_state() -> None: - """Paused execution should stop workers and clear executing state.""" - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - graph_execution.pause("Awaiting human input") - - coordinator, state_manager, worker_pool = _build_coordinator(graph_execution) - - coordinator.handle_pause_if_needed() - - worker_pool.stop.assert_called_once_with() - state_manager.clear_executing.assert_called_once_with() - - -def test_handle_pause_noop_when_execution_running() -> None: - """Running execution should not trigger pause handling.""" - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - - coordinator, state_manager, worker_pool = _build_coordinator(graph_execution) - - coordinator.handle_pause_if_needed() - - worker_pool.stop.assert_not_called() - state_manager.clear_executing.assert_not_called() - - -def test_has_executing_nodes_requires_pause() -> None: - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - - coordinator, _, _ = _build_coordinator(graph_execution) - - with 
pytest.raises(AssertionError): - coordinator.has_executing_nodes() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py deleted file mode 100644 index 51ece26d494..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ /dev/null @@ -1,770 +0,0 @@ -""" -Table-driven test framework for GraphEngine workflows. - -This file contains property-based tests and specific workflow tests. -The core test framework is in test_table_runner.py. -""" - -import time - -from hypothesis import HealthCheck, given, settings -from hypothesis import strategies as st - -from graphon.entities.base_node_data import DefaultValue, DefaultValueType -from graphon.enums import ErrorStrategy -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_events import ( - GraphRunPartialSucceededEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) - -# Import the test framework from the new module -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase - - -# Property-based fuzzing tests for the start-end workflow -@given(query_input=st.text()) -@settings(max_examples=50, deadline=30000, suppress_health_check=[HealthCheck.too_slow]) -def test_echo_workflow_property_basic_strings(query_input): - """ - Property-based test: Echo workflow should return exactly what was input. - - This tests the fundamental property that for any string input, - the start-end workflow should echo it back unchanged. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Fuzzing test with input: {repr(query_input)[:50]}...", - ) - - result = runner.run_test_case(test_case) - - # Property: The workflow should complete successfully - assert result.success, f"Workflow failed with input {repr(query_input)}: {result.error}" - - # Property: Output should equal input (echo behavior) - assert result.actual_outputs - assert result.actual_outputs == {"query": query_input}, ( - f"Echo property violated. Input: {repr(query_input)}, " - f"Expected: {repr(query_input)}, Got: {repr(result.actual_outputs.get('query'))}" - ) - - -@given(query_input=st.text(min_size=0, max_size=1000)) -@settings(max_examples=30, deadline=20000) -def test_echo_workflow_property_bounded_strings(query_input): - """ - Property-based test with size bounds to test edge cases more efficiently. - - Tests strings up to 1000 characters to balance thoroughness with performance. 
- """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Bounded fuzzing test (len={len(query_input)})", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed with bounded input: {result.error}" - assert result.actual_outputs == {"query": query_input} - - -@given( - query_input=st.one_of( - st.text(alphabet=st.characters(whitelist_categories=["Lu", "Ll", "Nd", "Po"])), # Letters, digits, punctuation - st.text(alphabet="๐ŸŽ‰๐ŸŒŸ๐Ÿ’ซโญ๐Ÿ”ฅ๐Ÿ’ฏ๐Ÿš€๐ŸŽฏ"), # Emojis - st.text(alphabet="ฮฑฮฒฮณฮดฮตฮถฮทฮธฮนฮบฮปฮผฮฝฮพฮฟฯ€ฯฯƒฯ„ฯ…ฯ†ฯ‡ฯˆฯ‰"), # Greek letters - st.text(alphabet="ไธญๆ–‡ๆต‹่ฏ•ํ•œ๊ตญ์–ดๆ—ฅๆœฌ่ชžุงู„ุนุฑุจูŠุฉ"), # International characters - st.just(""), # Empty string - st.just(" " * 100), # Whitespace only - st.just("\n\t\r\f\v"), # Special whitespace chars - st.just('{"json": "like", "data": [1, 2, 3]}'), # JSON-like string - st.just("SELECT * FROM users; DROP TABLE users;--"), # SQL injection attempt - st.just(""), # XSS attempt - st.just("../../etc/passwd"), # Path traversal attempt - ) -) -@settings(max_examples=40, deadline=25000) -def test_echo_workflow_property_diverse_inputs(query_input): - """ - Property-based test with diverse input types including edge cases and security payloads. - - Tests various categories of potentially problematic inputs: - - Unicode characters from different languages - - Emojis and special symbols - - Whitespace variations - - Malicious payloads (SQL injection, XSS, path traversal) - - JSON-like structures - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Diverse input fuzzing: {type(query_input).__name__}", - ) - - result = runner.run_test_case(test_case) - - # Property: System should handle all inputs gracefully (no crashes) - assert result.success, f"Workflow failed with diverse input {repr(query_input)}: {result.error}" - - # Property: Echo behavior must be preserved regardless of input type - assert result.actual_outputs == {"query": query_input} - - -@given(query_input=st.text(min_size=1000, max_size=5000)) -@settings(max_examples=10, deadline=60000) -def test_echo_workflow_property_large_inputs(query_input): - """ - Property-based test for large inputs to test memory and performance boundaries. - - Tests the system's ability to handle larger payloads efficiently. 
- """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Large input test (size: {len(query_input)} chars)", - timeout=45.0, # Longer timeout for large inputs - ) - - start_time = time.perf_counter() - result = runner.run_test_case(test_case) - execution_time = time.perf_counter() - start_time - - # Property: Large inputs should still work - assert result.success, f"Large input workflow failed: {result.error}" - - # Property: Echo behavior preserved for large inputs - assert result.actual_outputs == {"query": query_input} - - # Property: Performance should be reasonable even for large inputs - assert execution_time < 30.0, f"Large input took too long: {execution_time:.2f}s" - - -def test_echo_workflow_robustness_smoke_test(): - """ - Smoke test to ensure the basic workflow functionality works before fuzzing. - - This test uses a simple, known-good input to verify the test infrastructure - is working correctly before running the fuzzing tests. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "smoke test"}, - expected_outputs={"query": "smoke test"}, - description="Smoke test for basic functionality", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Smoke test failed: {result.error}" - assert result.actual_outputs == {"query": "smoke test"} - assert result.execution_time > 0 - - -def test_if_else_workflow_true_branch(): - """ - Test if-else workflow when input contains 'hello' (true branch). - - Should output {"true": input_query} when query contains "hello". - """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello world"}, - expected_outputs={"true": "hello world"}, - description="Basic hello case", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "say hello to everyone"}, - expected_outputs={"true": "say hello to everyone"}, - description="Hello in middle of sentence", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello"}, - expected_outputs={"true": "hello"}, - description="Just hello", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hellohello"}, - expected_outputs={"true": "hellohello"}, - description="Multiple hello occurrences", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key (true branch) - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected only 'true' key in outputs for {result.test_case.description}. " - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -def test_if_else_workflow_false_branch(): - """ - Test if-else workflow when input does not contain 'hello' (false branch). - - Should output {"false": input_query} when query does not contain "hello". 
- """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "goodbye world"}, - expected_outputs={"false": "goodbye world"}, - description="Basic goodbye case", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hi there"}, - expected_outputs={"false": "hi there"}, - description="Simple greeting without hello", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": ""}, - expected_outputs={"false": ""}, - description="Empty string", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "test message"}, - expected_outputs={"false": "test message"}, - description="Regular message", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key (false branch) - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected only 'false' key in outputs for {result.test_case.description}. " - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -def test_if_else_workflow_edge_cases(): - """ - Test if-else workflow edge cases and case sensitivity. - - Tests various edge cases including case sensitivity, similar words, etc. - """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "Hello world"}, - expected_outputs={"false": "Hello world"}, - description="Capitalized Hello (case sensitive test)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "HELLO"}, - expected_outputs={"false": "HELLO"}, - description="All caps HELLO (case sensitive test)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "helllo"}, - expected_outputs={"false": "helllo"}, - description="Typo: helllo (with extra l)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "helo"}, - expected_outputs={"false": "helo"}, - description="Typo: helo (missing l)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello123"}, - expected_outputs={"true": "hello123"}, - description="Hello with numbers", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello!@#"}, - expected_outputs={"true": "hello!@#"}, - description="Hello with special characters", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": " hello "}, - expected_outputs={"true": " hello "}, - description="Hello with surrounding spaces", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected exact match for {result.test_case.description}. 
" - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -@given(query_input=st.text()) -@settings(max_examples=50, deadline=30000, suppress_health_check=[HealthCheck.too_slow]) -def test_if_else_workflow_property_basic_strings(query_input): - """ - Property-based test: If-else workflow should output correct branch based on 'hello' content. - - This tests the fundamental property that for any string input: - - If input contains "hello", output should be {"true": input} - - If input doesn't contain "hello", output should be {"false": input} - """ - runner = TableTestRunner() - - # Determine expected output based on whether input contains "hello" - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Property test with input: {repr(query_input)[:50]}...", - ) - - result = runner.run_test_case(test_case) - - # Property: The workflow should complete successfully - assert result.success, f"Workflow failed with input {repr(query_input)}: {result.error}" - - # Property: Output should contain ONLY the expected key with correct value - assert result.actual_outputs == expected_outputs, ( - f"If-else property violated. Input: {repr(query_input)}, " - f"Expected: {expected_outputs}, Got: {result.actual_outputs}" - ) - - -@given(query_input=st.text(min_size=0, max_size=1000)) -@settings(max_examples=30, deadline=20000) -def test_if_else_workflow_property_bounded_strings(query_input): - """ - Property-based test with size bounds for if-else workflow. - - Tests strings up to 1000 characters to balance thoroughness with performance. - """ - runner = TableTestRunner() - - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Bounded if-else test (len={len(query_input)}, contains_hello={contains_hello})", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed with bounded input: {result.error}" - assert result.actual_outputs == expected_outputs - - -@given( - query_input=st.one_of( - st.text(alphabet=st.characters(whitelist_categories=["Lu", "Ll", "Nd", "Po"])), # Letters, digits, punctuation - st.text(alphabet="hello"), # Strings that definitely contain hello - st.text(alphabet="xyz"), # Strings that definitely don't contain hello - st.just("hello world"), # Known true case - st.just("goodbye world"), # Known false case - st.just(""), # Empty string - st.just("Hello"), # Case sensitivity test - st.just("HELLO"), # Case sensitivity test - st.just("hello" * 10), # Multiple hello occurrences - st.just("say hello to everyone"), # Hello in middle - st.text(alphabet="๐ŸŽ‰๐ŸŒŸ๐Ÿ’ซโญ๐Ÿ”ฅ๐Ÿ’ฏ๐Ÿš€๐ŸŽฏ"), # Emojis - st.text(alphabet="ไธญๆ–‡ๆต‹่ฏ•ํ•œ๊ตญ์–ดๆ—ฅๆœฌ่ชžุงู„ุนุฑุจูŠุฉ"), # International characters - ) -) -@settings(max_examples=40, deadline=25000) -def test_if_else_workflow_property_diverse_inputs(query_input): - """ - Property-based test with diverse input types for if-else workflow. 
- - Tests various categories including: - - Known true/false cases - - Case sensitivity scenarios - - Unicode characters from different languages - - Emojis and special symbols - - Multiple hello occurrences - """ - runner = TableTestRunner() - - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Diverse if-else test: {type(query_input).__name__} (contains_hello={contains_hello})", - ) - - result = runner.run_test_case(test_case) - - # Property: System should handle all inputs gracefully (no crashes) - assert result.success, f"Workflow failed with diverse input {repr(query_input)}: {result.error}" - - # Property: Correct branch logic must be preserved regardless of input type - assert result.actual_outputs == expected_outputs, ( - f"Branch logic violated. Input: {repr(query_input)}, " - f"Contains 'hello': {contains_hello}, Expected: {expected_outputs}, Got: {result.actual_outputs}" - ) - - -# Tests for the Layer system -def test_layer_system_basic(): - """Test basic layer functionality with DebugLoggingLayer.""" - from graphon.graph_engine.layers import DebugLoggingLayer - - runner = WorkflowRunner() - - # Load a simple echo workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test layer system"}) - - # Create engine with layer - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Add debug logging layer - debug_layer = DebugLoggingLayer(level="DEBUG", include_inputs=True, include_outputs=True) - engine.layer(debug_layer) - - # Run workflow - events = list(engine.run()) - - # Verify events were generated - assert len(events) > 0 - assert isinstance(events[0], GraphRunStartedEvent) - assert isinstance(events[-1], GraphRunSucceededEvent) - - # Verify layer received context - assert debug_layer.graph_runtime_state is not None - assert debug_layer.command_channel is not None - - # Verify layer tracked execution stats - assert debug_layer.node_count > 0 - assert debug_layer.success_count > 0 - - -def test_layer_chaining(): - """Test chaining multiple layers.""" - from graphon.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer - - # Create a custom test layer - class TestLayer(GraphEngineLayer): - def __init__(self): - super().__init__() - self.events_received = [] - self.graph_started = False - self.graph_ended = False - - def on_graph_start(self): - self.graph_started = True - - def on_event(self, event): - self.events_received.append(event.__class__.__name__) - - def on_graph_end(self, error): - self.graph_ended = True - - runner = WorkflowRunner() - - # Load workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test chaining"}) - - # Create engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Chain multiple layers - test_layer = TestLayer() - debug_layer = DebugLoggingLayer(level="INFO") - - 
engine.layer(test_layer).layer(debug_layer) - - # Run workflow - events = list(engine.run()) - - # Verify both layers received events - assert test_layer.graph_started - assert test_layer.graph_ended - assert len(test_layer.events_received) > 0 - - # Verify debug layer also worked - assert debug_layer.node_count > 0 - - -def test_layer_error_handling(): - """Test that layer errors don't crash the engine.""" - from graphon.graph_engine.layers import GraphEngineLayer - - # Create a layer that throws errors - class FaultyLayer(GraphEngineLayer): - def on_graph_start(self): - raise RuntimeError("Intentional error in on_graph_start") - - def on_event(self, event): - raise RuntimeError("Intentional error in on_event") - - def on_graph_end(self, error): - raise RuntimeError("Intentional error in on_graph_end") - - runner = WorkflowRunner() - - # Load workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test error handling"}) - - # Create engine with faulty layer - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Add faulty layer - engine.layer(FaultyLayer()) - - # Run workflow - should not crash despite layer errors - events = list(engine.run()) - - # Verify workflow still completed successfully - assert len(events) > 0 - assert isinstance(events[-1], GraphRunSucceededEvent) - assert events[-1].outputs == {"query": "test error handling"} - - -def test_event_sequence_validation(): - """Test the new event sequence validation feature.""" - from graphon.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent - - runner = TableTestRunner() - - # Test 1: Successful event sequence validation - test_case_success = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test event sequence"}, - expected_outputs={"query": "test event sequence"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, # Start node begins - NodeRunStreamChunkEvent, # Start node streaming - NodeRunSucceededEvent, # Start node completes - NodeRunStartedEvent, # End node begins - NodeRunSucceededEvent, # End node completes - GraphRunSucceededEvent, # Graph completes - ], - description="Test with correct event sequence", - ) - - result = runner.run_test_case(test_case_success) - assert result.success, f"Test should pass with correct event sequence. 
Error: {result.event_mismatch_details}" - assert result.event_sequence_match is True - assert result.event_mismatch_details is None - - # Test 2: Failed event sequence validation - wrong order - test_case_wrong_order = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test wrong order"}, - expected_outputs={"query": "test wrong order"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunSucceededEvent, # Wrong: expecting success before start - NodeRunStreamChunkEvent, - NodeRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Test with incorrect event order", - ) - - result = runner.run_test_case(test_case_wrong_order) - assert not result.success, "Test should fail with incorrect event sequence" - assert result.event_sequence_match is False - assert result.event_mismatch_details is not None - assert "Event mismatch at position" in result.event_mismatch_details - - # Test 3: Failed event sequence validation - wrong count - test_case_wrong_count = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test wrong count"}, - expected_outputs={"query": "test wrong count"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Missing the second node's events - GraphRunSucceededEvent, - ], - description="Test with incorrect event count", - ) - - result = runner.run_test_case(test_case_wrong_count) - assert not result.success, "Test should fail with incorrect event count" - assert result.event_sequence_match is False - assert result.event_mismatch_details is not None - assert "Event count mismatch" in result.event_mismatch_details - - # Test 4: No event sequence validation (backward compatibility) - test_case_no_validation = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test no validation"}, - expected_outputs={"query": "test no validation"}, - # No expected_event_sequence provided - description="Test without event sequence validation", - ) - - result = runner.run_test_case(test_case_no_validation) - assert result.success, "Test should pass when no event sequence is provided" - assert result.event_sequence_match is None - assert result.event_mismatch_details is None - - -def test_event_sequence_validation_with_table_tests(): - """Test event sequence validation with table-driven tests.""" - from graphon.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent - - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test1"}, - expected_outputs={"query": "test1"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Table test 1: Valid sequence", - ), - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test2"}, - expected_outputs={"query": "test2"}, - # No event sequence validation for this test - description="Table test 2: No sequence validation", - ), - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test3"}, - expected_outputs={"query": "test3"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - 
GraphRunSucceededEvent, - ], - description="Table test 3: Valid sequence", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - # Check all tests passed - for i, result in enumerate(suite_result.results): - if i == 1: # Test 2 has no event sequence validation - assert result.event_sequence_match is None - else: - assert result.event_sequence_match is True - assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}" - - -def test_graph_run_emits_partial_success_when_node_failure_recovered(): - runner = TableTestRunner() - - fixture_data = runner.workflow_runner.load_fixture("basic_chatflow") - mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build() - - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - query="hello", - use_mock_factory=True, - mock_config=mock_config, - ) - - llm_node = graph.nodes["llm"] - base_node_data = llm_node.node_data - base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE - base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)] - - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - events = list(engine.run()) - - assert isinstance(events[-1], GraphRunPartialSucceededEvent) - - partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent)) - assert partial_event.exceptions_count == 1 - assert partial_event.outputs.get("answer") == "fallback response" - - assert not any(isinstance(event, GraphRunSucceededEvent) for event in events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py deleted file mode 100644 index 348ceb6788b..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Unit tests for GraphExecution serialization helpers.""" - -from __future__ import annotations - -import json -from collections import deque -from unittest.mock import MagicMock - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState -from graphon.graph_engine.domain import GraphExecution -from graphon.graph_engine.response_coordinator import ResponseStreamCoordinator -from graphon.graph_engine.response_coordinator.path import Path -from graphon.graph_engine.response_coordinator.session import ResponseSession -from graphon.graph_events import NodeRunStreamChunkEvent -from graphon.nodes.base.template import Template, TextSegment, VariableSegment - - -class CustomGraphExecutionError(Exception): - """Custom exception used to verify error serialization.""" - - -def test_graph_execution_serialization_round_trip() -> None: - """GraphExecution serialization restores full aggregate state.""" - # Arrange - execution = GraphExecution(workflow_id="wf-1") - execution.start() - node_a = execution.get_or_create_node_execution("node-a") - node_a.mark_started(execution_id="exec-1") - node_a.increment_retry() - node_a.mark_failed("boom") - node_b = execution.get_or_create_node_execution("node-b") - node_b.mark_skipped() - execution.fail(CustomGraphExecutionError("serialization failure")) - - # Act - serialized = execution.dumps() - payload = json.loads(serialized) - restored = 
GraphExecution(workflow_id="wf-1") - restored.loads(serialized) - - # Assert - assert payload["type"] == "GraphExecution" - assert payload["version"] == "1.0" - assert restored.workflow_id == "wf-1" - assert restored.started is True - assert restored.completed is True - assert restored.aborted is False - assert isinstance(restored.error, CustomGraphExecutionError) - assert str(restored.error) == "serialization failure" - assert set(restored.node_executions) == {"node-a", "node-b"} - restored_node_a = restored.node_executions["node-a"] - assert restored_node_a.state is NodeState.TAKEN - assert restored_node_a.retry_count == 1 - assert restored_node_a.execution_id == "exec-1" - assert restored_node_a.error == "boom" - restored_node_b = restored.node_executions["node-b"] - assert restored_node_b.state is NodeState.SKIPPED - assert restored_node_b.retry_count == 0 - assert restored_node_b.execution_id is None - assert restored_node_b.error is None - - -def test_graph_execution_loads_replaces_existing_state() -> None: - """loads replaces existing runtime data with serialized snapshot.""" - # Arrange - source = GraphExecution(workflow_id="wf-2") - source.start() - source_node = source.get_or_create_node_execution("node-source") - source_node.mark_taken() - serialized = source.dumps() - - target = GraphExecution(workflow_id="wf-2") - target.start() - target.abort("pre-existing abort") - temp_node = target.get_or_create_node_execution("node-temp") - temp_node.increment_retry() - temp_node.mark_failed("temp error") - - # Act - target.loads(serialized) - - # Assert - assert target.aborted is False - assert target.error is None - assert target.started is True - assert target.completed is False - assert set(target.node_executions) == {"node-source"} - restored_node = target.node_executions["node-source"] - assert restored_node.state is NodeState.TAKEN - assert restored_node.retry_count == 0 - assert restored_node.execution_id is None - assert restored_node.error is None - - -def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> None: - """ResponseStreamCoordinator serialization restores coordinator internals.""" - - template_main = Template(segments=[TextSegment(text="Hi "), VariableSegment(selector=["node-source", "text"])]) - template_secondary = Template(segments=[TextSegment(text="secondary")]) - - class DummyNode: - def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None: - self.id = node_id - self.node_type = ( - BuiltinNodeTypes.ANSWER if execution_type == NodeExecutionType.RESPONSE else BuiltinNodeTypes.LLM - ) - self.execution_type = execution_type - self.state = NodeState.UNKNOWN - self.title = node_id - self.template = template - - def blocks_variable_output(self, *_args) -> bool: - return False - - response_node1 = DummyNode("response-1", template_main, NodeExecutionType.RESPONSE) - response_node2 = DummyNode("response-2", template_main, NodeExecutionType.RESPONSE) - response_node3 = DummyNode("response-3", template_main, NodeExecutionType.RESPONSE) - source_node = DummyNode("node-source", template_secondary, NodeExecutionType.EXECUTABLE) - - class DummyGraph: - def __init__(self) -> None: - self.nodes = { - response_node1.id: response_node1, - response_node2.id: response_node2, - response_node3.id: response_node3, - source_node.id: source_node, - } - self.edges: dict[str, object] = {} - self.root_node = response_node1 - - def get_outgoing_edges(self, _node_id: str): # pragma: no cover - not exercised - return [] - - def 
get_incoming_edges(self, _node_id: str): # pragma: no cover - not exercised - return [] - - graph = DummyGraph() - - def fake_from_node(cls, node: DummyNode) -> ResponseSession: - return ResponseSession(node_id=node.id, template=node.template) - - monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) - - coordinator = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] - coordinator._response_nodes = {"response-1", "response-2", "response-3"} - coordinator._paths_maps = { - "response-1": [Path(edges=["edge-1"])], - "response-2": [Path(edges=[])], - "response-3": [Path(edges=["edge-2", "edge-3"])], - } - - active_session = ResponseSession(node_id="response-1", template=response_node1.template) - active_session.index = 1 - coordinator._active_session = active_session - waiting_session = ResponseSession(node_id="response-2", template=response_node2.template) - coordinator._waiting_sessions = deque([waiting_session]) - pending_session = ResponseSession(node_id="response-3", template=response_node3.template) - pending_session.index = 2 - coordinator._response_sessions = {"response-3": pending_session} - - coordinator._node_execution_ids = {"response-1": "exec-1"} - event = NodeRunStreamChunkEvent( - id="exec-1", - node_id="response-1", - node_type=BuiltinNodeTypes.ANSWER, - selector=["node-source", "text"], - chunk="chunk-1", - is_final=False, - ) - coordinator._stream_buffers = {("node-source", "text"): [event]} - coordinator._stream_positions = {("node-source", "text"): 1} - coordinator._closed_streams = {("node-source", "text")} - - serialized = coordinator.dumps() - - restored = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] - monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) - restored.loads(serialized) - - assert restored._response_nodes == {"response-1", "response-2", "response-3"} - assert restored._paths_maps["response-1"][0].edges == ["edge-1"] - assert restored._active_session is not None - assert restored._active_session.node_id == "response-1" - assert restored._active_session.index == 1 - waiting_restored = list(restored._waiting_sessions) - assert len(waiting_restored) == 1 - assert waiting_restored[0].node_id == "response-2" - assert waiting_restored[0].index == 0 - assert set(restored._response_sessions) == {"response-3"} - assert restored._response_sessions["response-3"].index == 2 - assert restored._node_execution_ids == {"response-1": "exec-1"} - assert ("node-source", "text") in restored._stream_buffers - restored_event = restored._stream_buffers[("node-source", "text")][0] - assert restored_event.chunk == "chunk-1" - assert restored._stream_positions[("node-source", "text")] == 1 - assert ("node-source", "text") in restored._closed_streams diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py deleted file mode 100644 index a6417822d26..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ /dev/null @@ -1,190 +0,0 @@ -import time -from collections.abc import Mapping - -from core.workflow.system_variables import build_system_variables -from graphon.entities import GraphInitParams -from graphon.enums import NodeState -from graphon.graph import Graph -from graphon.graph_engine.graph_state_manager import GraphStateManager -from graphon.graph_engine.ready_queue import InMemoryReadyQueue 
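-# A minimal sketch of the snapshot round trip these tests exercise, using only
-# the dumps() / from_snapshot() / attach_graph() calls that appear in the tests
-# below (no other API is assumed):
-#
-#     snapshot = runtime_state.dumps()                      # serialize pool plus node/edge states
-#     resumed_state = GraphRuntimeState.from_snapshot(snapshot)
-#     resumed_state.attach_graph(_build_graph(resumed_state))  # re-bind restored states to a rebuilt graph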
-from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=build_system_variables( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_llm_node( - *, - node_id: str, - runtime_state: GraphRuntimeState, - graph_init_params: GraphInitParams, - mock_config: MockConfig, -) -> MockLLMNode: - llm_data = LLMNodeData( - title=f"LLM {node_id}", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=f"Prompt {node_id}", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - return MockLLMNode( - id=llm_config["id"], - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - -def _build_graph(runtime_state: GraphRuntimeState) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - mock_config = MockConfig() - llm_a = _build_llm_node( - node_id="llm_a", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - llm_b = _build_llm_node( - node_id="llm_b", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - - end_data = EndNodeData(title="End", outputs=[], desc=None) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - builder = ( - Graph.new() - .add_root(start_node) - .add_node(llm_a, from_node_id="start") - .add_node(llm_b, from_node_id="start") - .add_node(end_node, from_node_id="llm_a") - ) - return builder.connect(tail="llm_b", head="end").build() - - -def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]: - return {(edge.tail, edge.head, edge.source_handle): edge.state for edge in graph.edges.values()} - - -def 
test_runtime_state_snapshot_restores_graph_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - graph.nodes["llm_a"].state = NodeState.TAKEN - graph.nodes["llm_b"].state = NodeState.SKIPPED - - for edge in graph.edges.values(): - if edge.tail == "start" and edge.head == "llm_a": - edge.state = NodeState.TAKEN - elif edge.tail == "start" and edge.head == "llm_b": - edge.state = NodeState.SKIPPED - elif edge.head == "end" and edge.tail == "llm_a": - edge.state = NodeState.TAKEN - elif edge.head == "end" and edge.tail == "llm_b": - edge.state = NodeState.SKIPPED - - snapshot = runtime_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN - assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED - assert _edge_state_map(resumed_graph) == _edge_state_map(graph) - - -def test_join_readiness_uses_restored_edge_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - ready_queue = InMemoryReadyQueue() - state_manager = GraphStateManager(graph, ready_queue) - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_a": - edge.state = NodeState.TAKEN - if edge.tail == "llm_b": - edge.state = NodeState.UNKNOWN - - assert state_manager.is_node_ready("end") is False - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_b": - edge.state = NodeState.TAKEN - - assert state_manager.is_node_ready("end") is True - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue()) - assert resumed_state_manager.is_node_ready("end") is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py deleted file mode 100644 index ca9a929591a..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ /dev/null @@ -1,389 +0,0 @@ -import datetime -import time -from collections.abc import Iterable -from unittest import mock -from unittest.mock import MagicMock - -from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.graph import Graph -from graphon.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.graph_events.node import NodeRunHumanInputFormFilledEvent -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.llm.entities import ( - ContextConfig, 
- LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_branching_graph( - mock_config: MockConfig, - form_repository: HumanInputFormRepository, - graph_runtime_state: GraphRuntimeState | None = None, -) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - if graph_runtime_state is None: - variable_pool = VariablePool( - system_variables=build_system_variables( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="test-execution-id", - ), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=prompt_text, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - model_factory=mock.Mock(), - ) - return llm_node - - llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[ - UserAction(id="primary", title="Primary"), - UserAction(id="secondary", title="Secondary"), - ], - ) - - human_config = {"id": "human", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - form_repository=form_repository, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), - ) - - llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") - llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") - - end_primary_data = EndNodeData( - title="End Primary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="primary_text", 
value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"] - ), - ], - desc=None, - ) - end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()} - end_primary = EndNode( - id=end_primary_config["id"], - config=end_primary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - end_secondary_data = EndNodeData( - title="End Secondary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="secondary_text", - value_type=OutputVariableType.STRING, - value_selector=["llm_secondary", "text"], - ), - ], - desc=None, - ) - end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()} - end_secondary = EndNode( - id=end_secondary_config["id"], - config=end_secondary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_initial) - .add_node(human_node) - .add_node(llm_primary, from_node_id="human", source_handle="primary") - .add_node(end_primary, from_node_id="llm_primary") - .add_node(llm_secondary, from_node_id="human", source_handle="secondary") - .add_node(end_secondary, from_node_id="llm_secondary") - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def _assert_stream_chunk_sequence( - chunk_events: Iterable[NodeRunStreamChunkEvent], - expected_nodes: list[str], - expected_chunks: list[str], -) -> None: - actual_nodes = [event.node_id for event in chunk_events] - actual_chunks = [event.chunk for event in chunk_events] - assert actual_nodes == expected_nodes - assert actual_chunks == expected_chunks - - -def test_human_input_llm_streaming_across_multiple_branches() -> None: - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"}) - mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"}) - mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"}) - - branch_scenarios = [ - { - "handle": "primary", - "resume_llm": "llm_primary", - "end_node": "end_primary", - "expected_pre_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes - ("end_primary", ["\n"]), # literal segment emitted when end_primary session activates - ], - "expected_post_chunks": [ - ("llm_primary", _expected_mock_llm_chunks("Primary stream output")), # live stream from chosen branch - ], - }, - { - "handle": "secondary", - "resume_llm": "llm_secondary", - "end_node": "end_secondary", - "expected_pre_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes - ("end_secondary", ["\n"]), # literal segment emitted when end_secondary session activates - ], - "expected_post_chunks": [ - ("llm_secondary", _expected_mock_llm_chunks("Secondary")), # live stream from chosen branch - ], - }, - ] - - for scenario in branch_scenarios: - runner = TableTestRunner() - - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - 
mock_form_entity.submission_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - - def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]: - return _build_branching_graph(mock_config, mock_create_repo) - - initial_case = WorkflowTestCase( - description="HumanInput pause before branching decision", - graph_factory=initial_graph_factory, - expected_event_sequence=[ - GraphRunStartedEvent, # initial run: graph execution starts - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts streaming - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # human node begins and issues pause - NodeRunPauseRequestedEvent, # human node requests pause awaiting input - GraphRunPausedEvent, # graph run pauses awaiting resume - ], - ) - - initial_result = runner.run_test_case(initial_case) - - assert initial_result.success, initial_result.event_mismatch_details - assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events) - - pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"]) - post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"]) - expected_pre_chunk_events_in_resumption = [ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunHumanInputFormFilledEvent, - ] - - expected_resume_sequence: list[type] = ( - expected_pre_chunk_events_in_resumption - + [NodeRunStreamChunkEvent] * pre_chunk_count - + [ - NodeRunSucceededEvent, - NodeRunStartedEvent, - ] - + [NodeRunStreamChunkEvent] * post_chunk_count - + [ - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ] - ) - - mock_get_repo = MagicMock(spec=HumanInputFormRepository) - submitted_form = MagicMock(spec=HumanInputFormEntity) - submitted_form.id = mock_form_entity.id - submitted_form.submission_token = mock_form_entity.submission_token - submitted_form.recipients = [] - submitted_form.rendered_content = mock_form_entity.rendered_content - submitted_form.submitted = True - submitted_form.selected_action_id = scenario["handle"] - submitted_form.submitted_data = {} - submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - mock_get_repo.get_form.return_value = submitted_form - - def resume_graph_factory( - initial_result=initial_result, mock_get_repo=mock_get_repo - ) -> tuple[Graph, GraphRuntimeState]: - assert initial_result.graph_runtime_state is not None - serialized_runtime_state = initial_result.graph_runtime_state.dumps() - resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) - return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state) - - resume_case = WorkflowTestCase( - description=f"HumanInput resumes via {scenario['handle']} branch", - graph_factory=resume_graph_factory, - expected_event_sequence=expected_resume_sequence, - ) - - resume_result = runner.run_test_case(resume_case) - - assert resume_result.success, resume_result.event_mismatch_details - - resume_events = resume_result.events - - chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)] - assert len(chunk_events) == pre_chunk_count + post_chunk_count - - pre_chunk_events = chunk_events[:pre_chunk_count] - 
post_chunk_events = chunk_events[pre_chunk_count:] - - expected_pre_nodes: list[str] = [] - expected_pre_chunks: list[str] = [] - for node_id, chunks in scenario["expected_pre_chunks"]: - expected_pre_nodes.extend([node_id] * len(chunks)) - expected_pre_chunks.extend(chunks) - _assert_stream_chunk_sequence(pre_chunk_events, expected_pre_nodes, expected_pre_chunks) - - expected_post_nodes: list[str] = [] - expected_post_chunks: list[str] = [] - for node_id, chunks in scenario["expected_post_chunks"]: - expected_post_nodes.extend([node_id] * len(chunks)) - expected_post_chunks.extend(chunks) - _assert_stream_chunk_sequence(post_chunk_events, expected_post_nodes, expected_post_chunks) - - human_success_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human" - ) - pre_indices = [ - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index - ] - expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption) - assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index)) - - resume_chunk_indices = [ - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"] - ] - assert resume_chunk_indices, "Expected streaming output from the selected branch" - resume_start_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"] - ) - resume_success_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"] - ) - assert resume_start_index < min(resume_chunk_indices) - assert max(resume_chunk_indices) < resume_success_index - - started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["human", scenario["resume_llm"], scenario["end_node"]] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py deleted file mode 100644 index c50aaafe2cf..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ /dev/null @@ -1,346 +0,0 @@ -import datetime -import time -from unittest import mock -from unittest.mock import MagicMock - -from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.graph import Graph -from graphon.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.graph_events.node import NodeRunHumanInputFormFilledEvent -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.human_input_node import 
HumanInputNode
-from graphon.nodes.llm.entities import (
- ContextConfig,
- LLMNodeChatModelMessage,
- LLMNodeData,
- ModelConfig,
- VisionConfig,
-)
-from graphon.nodes.start.entities import StartNodeData
-from graphon.nodes.start.start_node import StartNode
-from graphon.runtime import GraphRuntimeState, VariablePool
-from libs.datetime_utils import naive_utc_now
-from tests.workflow_test_utils import build_test_graph_init_params
-
-from .test_mock_config import MockConfig
-from .test_mock_nodes import MockLLMNode
-from .test_table_runner import TableTestRunner, WorkflowTestCase
-
-
-def _build_llm_human_llm_graph(
- mock_config: MockConfig,
- form_repository: HumanInputFormRepository,
- graph_runtime_state: GraphRuntimeState | None = None,
-) -> tuple[Graph, GraphRuntimeState]:
- graph_config: dict[str, object] = {"nodes": [], "edges": []}
- graph_init_params = build_test_graph_init_params(
- workflow_id="workflow",
- graph_config=graph_config,
- tenant_id="tenant",
- app_id="app",
- user_id="user",
- user_from="account",
- invoke_from="debugger",
- call_depth=0,
- )
-
- if graph_runtime_state is None:
- variable_pool = VariablePool(
- system_variables=build_system_variables(
- user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id",
- ),
- user_inputs={},
- conversation_variables=[],
- )
- graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-
- start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
- start_node = StartNode(
- id=start_config["id"],
- config=start_config,
- graph_init_params=graph_init_params,
- graph_runtime_state=graph_runtime_state,
- )
-
- def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
- llm_data = LLMNodeData(
- title=title,
- model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
- prompt_template=[
- LLMNodeChatModelMessage(
- text=prompt_text,
- role=PromptMessageRole.USER,
- edition_type="basic",
- )
- ],
- context=ContextConfig(enabled=False, variable_selector=None),
- vision=VisionConfig(enabled=False),
- reasoning_format="tagged",
- )
- llm_config = {"id": node_id, "data": llm_data.model_dump()}
- llm_node = MockLLMNode(
- id=node_id,
- config=llm_config,
- graph_init_params=graph_init_params,
- graph_runtime_state=graph_runtime_state,
- mock_config=mock_config,
- credentials_provider=mock.Mock(),
- model_factory=mock.Mock(),
- )
- return llm_node
-
- llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt")
-
- human_data = HumanInputNodeData(
- title="Human Input",
- form_content="Human input required",
- inputs=[],
- user_actions=[
- UserAction(id="accept", title="Accept"),
- UserAction(id="reject", title="Reject"),
- ],
- )
-
- human_config = {"id": "human", "data": human_data.model_dump()}
- human_node = HumanInputNode(
- id=human_config["id"],
- config=human_config,
- graph_init_params=graph_init_params,
- graph_runtime_state=graph_runtime_state,
- form_repository=form_repository,
- runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
- )
-
- llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
-
- end_data = EndNodeData(
- title="End",
- outputs=[
- OutputVariableEntity(
- variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"]
- ),
- OutputVariableEntity(
- variable="resume_text", value_type=OutputVariableType.STRING, value_selector=["llm_resume", 
"text"] - ), - ], - desc=None, - ) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_first) - .add_node(human_node) - .add_node(llm_second, source_handle="accept") - .add_node(end_node) - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def test_human_input_llm_streaming_order_across_pause() -> None: - runner = TableTestRunner() - - initial_text = "Hello, pause" - resume_text = "Welcome back!" - - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": initial_text}) - mock_config.set_node_outputs("llm_resume", {"text": resume_text}) - - expected_initial_sequence: list[type] = [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial begins streaming - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # human node begins and requests pause - NodeRunPauseRequestedEvent, # human node pause requested - GraphRunPausedEvent, # graph run pauses awaiting resume - ] - - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - mock_form_entity.submission_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - - def graph_factory() -> tuple[Graph, GraphRuntimeState]: - return _build_llm_human_llm_graph(mock_config, mock_create_repo) - - initial_case = WorkflowTestCase( - description="HumanInput pause preserves LLM streaming order", - graph_factory=graph_factory, - expected_event_sequence=expected_initial_sequence, - ) - - initial_result = runner.run_test_case(initial_case) - - assert initial_result.success, initial_result.event_mismatch_details - - initial_events = initial_result.events - initial_chunks = _expected_mock_llm_chunks(initial_text) - - initial_stream_chunk_events = [event for event in initial_events if isinstance(event, NodeRunStreamChunkEvent)] - assert initial_stream_chunk_events == [] - - pause_index = next(i for i, event in enumerate(initial_events) if isinstance(event, GraphRunPausedEvent)) - llm_succeeded_index = next( - i - for i, event in enumerate(initial_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_initial" - ) - assert llm_succeeded_index < pause_index - - graph_runtime_state = initial_result.graph_runtime_state - graph = initial_result.graph - assert graph_runtime_state is not None - assert graph is not None - - coordinator = graph_runtime_state.response_coordinator - stream_buffers = coordinator._stream_buffers # Tests may access internals for assertions - assert ("llm_initial", "text") in stream_buffers - initial_stream_chunks = [event.chunk for event in stream_buffers[("llm_initial", "text")]] - assert initial_stream_chunks == initial_chunks - assert ("llm_resume", "text") not in stream_buffers - - 
resume_chunks = _expected_mock_llm_chunks(resume_text)
- expected_resume_sequence: list[type] = [
- GraphRunStartedEvent, # resumed graph run begins
- NodeRunStartedEvent, # human node restarts
- # The form-filled event is emitted first; the cached stream chunks are then flushed before the human node completes.
- NodeRunHumanInputFormFilledEvent,
- NodeRunStreamChunkEvent, # cached llm_initial chunk 1
- NodeRunStreamChunkEvent, # cached llm_initial chunk 2
- NodeRunStreamChunkEvent, # cached llm_initial final chunk
- NodeRunStreamChunkEvent, # end node emits combined template separator
- NodeRunSucceededEvent, # human node finishes instantly after input
- NodeRunStartedEvent, # llm_resume begins streaming
- NodeRunStreamChunkEvent, # llm_resume chunk 1
- NodeRunStreamChunkEvent, # llm_resume chunk 2
- NodeRunStreamChunkEvent, # llm_resume final chunk
- NodeRunSucceededEvent, # llm_resume completes streaming
- NodeRunStartedEvent, # end node starts
- NodeRunSucceededEvent, # end node finishes
- GraphRunSucceededEvent, # graph run succeeds after resume
- ]
-
- mock_get_repo = MagicMock(spec=HumanInputFormRepository)
- submitted_form = MagicMock(spec=HumanInputFormEntity)
- submitted_form.id = mock_form_entity.id
- submitted_form.submission_token = mock_form_entity.submission_token
- submitted_form.recipients = []
- submitted_form.rendered_content = mock_form_entity.rendered_content
- submitted_form.submitted = True
- submitted_form.selected_action_id = "accept"
- submitted_form.submitted_data = {}
- submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
- mock_get_repo.get_form.return_value = submitted_form
-
- def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
- # reconstruct the graph runtime state from the serialized snapshot
- serialized_runtime_state = initial_result.graph_runtime_state.dumps()
- resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state)
- return _build_llm_human_llm_graph(
- mock_config,
- mock_get_repo,
- resume_runtime_state,
- )
-
- resume_case = WorkflowTestCase(
- description="HumanInput resume continues LLM streaming order",
- graph_factory=resume_graph_factory,
- expected_event_sequence=expected_resume_sequence,
- )
-
- resume_result = runner.run_test_case(resume_case)
-
- assert resume_result.success, resume_result.event_mismatch_details
-
- resume_events = resume_result.events
-
- success_index = next(i for i, event in enumerate(resume_events) if isinstance(event, GraphRunSucceededEvent))
- llm_resume_succeeded_index = next(
- i
- for i, event in enumerate(resume_events)
- if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume"
- )
- assert llm_resume_succeeded_index < success_index
-
- resume_chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)]
- assert [event.node_id for event in resume_chunk_events[:3]] == ["llm_initial"] * 3
- assert [event.chunk for event in resume_chunk_events[:3]] == initial_chunks
- assert resume_chunk_events[3].node_id == "end"
- assert resume_chunk_events[3].chunk == "\n"
- assert [event.node_id for event in resume_chunk_events[4:]] == ["llm_resume"] * 3
- assert [event.chunk for event in resume_chunk_events[4:]] == resume_chunks
-
- human_success_index = next(
- i
- for i, event in enumerate(resume_events)
- if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human"
- )
- cached_chunk_indices = [
- i
- for i, event in enumerate(resume_events)
- if isinstance(event, NodeRunStreamChunkEvent) and event.node_id in {"llm_initial", "end"}
- ]
- assert 
all(index < human_success_index for index in cached_chunk_indices) - - llm_resume_start_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "llm_resume" - ) - llm_resume_success_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume" - ) - llm_resume_chunk_indices = [ - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == "llm_resume" - ] - assert llm_resume_chunk_indices - first_resume_chunk_index = min(llm_resume_chunk_indices) - last_resume_chunk_index = max(llm_resume_chunk_indices) - assert llm_resume_start_index < first_resume_chunk_index - assert last_resume_chunk_index < llm_resume_success_index - - started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["human", "llm_resume", "end"] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py deleted file mode 100644 index 246df45d5f6..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ /dev/null @@ -1,324 +0,0 @@ -import time -from unittest import mock - -from core.workflow.system_variables import build_system_variables -from graphon.graph import Graph -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.if_else.entities import IfElseNodeData -from graphon.nodes.if_else.if_else_node import IfElseNode -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.utils.condition.entities import Condition -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - graph_config=graph_config, - user_from="account", - invoke_from="debugger", - ) - - variable_pool = VariablePool( - system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), - user_inputs={}, - conversation_variables=[], - ) - variable_pool.add(("branch", "value"), branch_value) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - 
graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=prompt_text, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - model_factory=mock.Mock(), - ) - return llm_node - - llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") - - if_else_data = IfElseNodeData( - title="IfElse", - cases=[ - IfElseNodeData.Case( - case_id="primary", - logical_operator="and", - conditions=[ - Condition(variable_selector=["branch", "value"], comparison_operator="is", value="primary") - ], - ), - IfElseNodeData.Case( - case_id="secondary", - logical_operator="and", - conditions=[ - Condition(variable_selector=["branch", "value"], comparison_operator="is", value="secondary") - ], - ), - ], - ) - if_else_config = {"id": "if_else", "data": if_else_data.model_dump()} - if_else_node = IfElseNode( - id=if_else_config["id"], - config=if_else_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") - llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") - - end_primary_data = EndNodeData( - title="End Primary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="primary_text", value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"] - ), - ], - desc=None, - ) - end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()} - end_primary = EndNode( - id=end_primary_config["id"], - config=end_primary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - end_secondary_data = EndNodeData( - title="End Secondary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="secondary_text", - value_type=OutputVariableType.STRING, - value_selector=["llm_secondary", "text"], - ), - ], - desc=None, - ) - end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()} - end_secondary = EndNode( - id=end_secondary_config["id"], - config=end_secondary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_initial) - .add_node(if_else_node) - .add_node(llm_primary, from_node_id="if_else", source_handle="primary") - .add_node(end_primary, from_node_id="llm_primary") - .add_node(llm_secondary, from_node_id="if_else", source_handle="secondary") - .add_node(end_secondary, from_node_id="llm_secondary") - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: 
list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def test_if_else_llm_streaming_order() -> None: - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"}) - mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"}) - mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"}) - - scenarios = [ - { - "branch": "primary", - "resume_llm": "llm_primary", - "end_node": "end_primary", - "expected_sequence": [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts and streams - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # if_else evaluates conditions - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed - NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed - NodeRunStreamChunkEvent, # template literal newline emitted - NodeRunSucceededEvent, # if_else completes branch selection - NodeRunStartedEvent, # llm_primary begins streaming - NodeRunStreamChunkEvent, # llm_primary chunk 1 - NodeRunStreamChunkEvent, # llm_primary chunk 2 - NodeRunStreamChunkEvent, # llm_primary chunk 3 - NodeRunStreamChunkEvent, # llm_primary final chunk - NodeRunSucceededEvent, # llm_primary completes streaming - NodeRunStartedEvent, # end_primary node starts - NodeRunSucceededEvent, # end_primary finishes aggregation - GraphRunSucceededEvent, # graph run succeeds - ], - "expected_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), - ("end_primary", ["\n"]), - ("llm_primary", _expected_mock_llm_chunks("Primary stream output")), - ], - }, - { - "branch": "secondary", - "resume_llm": "llm_secondary", - "end_node": "end_secondary", - "expected_sequence": [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts and streams - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # if_else evaluates conditions - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed - NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed - NodeRunStreamChunkEvent, # template literal newline emitted - NodeRunSucceededEvent, # if_else completes branch selection - NodeRunStartedEvent, # llm_secondary begins streaming - NodeRunStreamChunkEvent, # llm_secondary chunk 1 - NodeRunStreamChunkEvent, # llm_secondary final chunk - NodeRunSucceededEvent, # llm_secondary completes - NodeRunStartedEvent, # end_secondary node starts - NodeRunSucceededEvent, # end_secondary finishes aggregation - GraphRunSucceededEvent, # graph run succeeds - ], - "expected_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), - ("end_secondary", ["\n"]), - ("llm_secondary", _expected_mock_llm_chunks("Secondary")), - ], - }, - ] - - for scenario in scenarios: - runner = TableTestRunner() - - def graph_factory( - branch_value: str = scenario["branch"], - cfg: MockConfig = mock_config, - ) -> tuple[Graph, GraphRuntimeState]: - return _build_if_else_graph(branch_value, cfg) - - test_case = WorkflowTestCase( - description=f"IfElse streaming via 
{scenario['branch']} branch", - graph_factory=graph_factory, - expected_event_sequence=scenario["expected_sequence"], - ) - - result = runner.run_test_case(test_case) - - assert result.success, result.event_mismatch_details - - chunk_events = [event for event in result.events if isinstance(event, NodeRunStreamChunkEvent)] - expected_nodes: list[str] = [] - expected_chunks: list[str] = [] - for node_id, chunks in scenario["expected_chunks"]: - expected_nodes.extend([node_id] * len(chunks)) - expected_chunks.extend(chunks) - assert [event.node_id for event in chunk_events] == expected_nodes - assert [event.chunk for event in chunk_events] == expected_chunks - - branch_node_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "if_else" - ) - branch_success_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "if_else" - ) - pre_branch_chunk_indices = [ - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStreamChunkEvent) and index < branch_success_index - ] - assert len(pre_branch_chunk_indices) == len(_expected_mock_llm_chunks("Initial stream")) + 1 - assert min(pre_branch_chunk_indices) == branch_node_index + 1 - assert max(pre_branch_chunk_indices) < branch_success_index - - resume_chunk_indices = [ - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"] - ] - assert resume_chunk_indices - resume_start_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"] - ) - resume_success_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"] - ) - assert resume_start_index < min(resume_chunk_indices) - assert max(resume_chunk_indices) < resume_success_index - - started_nodes = [event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["start", "llm_initial", "if_else", scenario["resume_llm"], scenario["end_node"]] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py b/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py deleted file mode 100644 index b9bf4be13a0..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py +++ /dev/null @@ -1,126 +0,0 @@ -""" -Test cases for the Iteration node's flatten_output functionality. - -This module tests the iteration node's ability to: -1. Flatten array outputs when flatten_output=True (default) -2. Preserve nested array structure when flatten_output=False -""" - -from .test_database_utils import skip_if_database_unavailable -from .test_mock_config import MockConfigBuilder, NodeMockConfig -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _create_iteration_mock_config(): - """Helper to create a mock config for iteration tests.""" - - def code_inner_handler(node): - pool = node.graph_runtime_state.variable_pool - item_seg = pool.get(["iteration_node", "item"]) - if item_seg is not None: - item = item_seg.to_object() - return {"result": [item, item * 2]} - # This fallback is likely unreachable, but if it is, - # it doesn't simulate iteration with different values as the comment suggests. 
- return {"result": [1, 2]} - - return ( - MockConfigBuilder() - .with_node_output("code_node", {"result": [1, 2, 3]}) - .with_node_config(NodeMockConfig(node_id="code_inner_node", custom_handler=code_inner_handler)) - .build() - ) - - -@skip_if_database_unavailable() -def test_iteration_with_flatten_output_enabled(): - """ - Test iteration node with flatten_output=True (default behavior). - - The fixture implements an iteration that: - 1. Iterates over [1, 2, 3] - 2. For each item, outputs [item, item*2] - 3. With flatten_output=True, should output [1, 2, 2, 4, 3, 6] - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="iteration_flatten_output_enabled_workflow", - inputs={}, - expected_outputs={"output": [1, 2, 2, 4, 3, 6]}, - description="Iteration with flatten_output=True flattens nested arrays", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"output": [1, 2, 2, 4, 3, 6]}, ( - f"Expected flattened output [1, 2, 2, 4, 3, 6], got {result.actual_outputs}" - ) - - -@skip_if_database_unavailable() -def test_iteration_with_flatten_output_disabled(): - """ - Test iteration node with flatten_output=False. - - The fixture implements an iteration that: - 1. Iterates over [1, 2, 3] - 2. For each item, outputs [item, item*2] - 3. With flatten_output=False, should output [[1, 2], [2, 4], [3, 6]] - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="iteration_flatten_output_disabled_workflow", - inputs={}, - expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]}, - description="Iteration with flatten_output=False preserves nested structure", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"output": [[1, 2], [2, 4], [3, 6]]}, ( - f"Expected nested output [[1, 2], [2, 4], [3, 6]], got {result.actual_outputs}" - ) - - -@skip_if_database_unavailable() -def test_iteration_flatten_output_comparison(): - """ - Run both flatten_output configurations in parallel to verify the difference. 
- """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="iteration_flatten_output_enabled_workflow", - inputs={}, - expected_outputs={"output": [1, 2, 2, 4, 3, 6]}, - description="flatten_output=True: Flattened output", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ), - WorkflowTestCase( - fixture_path="iteration_flatten_output_disabled_workflow", - inputs={}, - expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]}, - description="flatten_output=False: Nested output", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ), - ] - - suite_result = runner.run_table_tests(test_cases, parallel=True) - - # Assert all tests passed - assert suite_result.passed_tests == 2, f"Expected 2 passed tests, got {suite_result.passed_tests}" - assert suite_result.failed_tests == 0, f"Expected 0 failed tests, got {suite_result.failed_tests}" - assert suite_result.success_rate == 100.0, f"Expected 100% success rate, got {suite_result.success_rate}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py deleted file mode 100644 index 821da46b760..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Test case for loop with inner answer output error scenario. - -This test validates the behavior of a loop containing an answer node -inside the loop that may produce output errors. -""" - -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_loop_contains_answer(): - """ - Test loop with inner answer node that may have output errors. - - The fixture implements a loop that: - 1. Iterates 4 times (index 0-3) - 2. Contains an inner answer node that outputs index and item values - 3. Has a break condition when index equals 4 - 4. 
Tests error handling for answer nodes within loops - """ - fixture_name = "loop_contains_answer" - mock_config = MockConfigBuilder().build() - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - query="1", - expected_outputs={"answer": "1\n2\n1 + 2"}, - expected_event_sequence=[ - # Graph start - GraphRunStartedEvent, - # Start - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop start - NodeRunStartedEvent, - NodeRunLoopStartedEvent, - # Variable assigner - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunStreamChunkEvent, # 1 - NodeRunStreamChunkEvent, # \n - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop next - NodeRunLoopNextEvent, - # Variable assigner - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunStreamChunkEvent, # 2 - NodeRunStreamChunkEvent, # \n - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop end - NodeRunLoopSucceededEvent, - NodeRunStreamChunkEvent, # 1 - NodeRunStreamChunkEvent, # + - NodeRunStreamChunkEvent, # 2 - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Graph end - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py deleted file mode 100644 index ad8d777ea69..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -Test cases for the Loop node functionality using TableTestRunner. - -This module tests the loop node's ability to: -1. Execute iterations with loop variables -2. Handle break conditions correctly -3. Update and propagate loop variables between iterations -4. Output the final loop variable value -""" - -from tests.unit_tests.core.workflow.graph_engine.test_table_runner import ( - TableTestRunner, - WorkflowTestCase, -) - - -def test_loop_with_break_condition(): - """ - Test loop node with break condition. - - The increment_loop_with_break_condition_workflow.yml fixture implements a loop that: - 1. Starts with num=1 - 2. Increments num by 1 each iteration - 3. Breaks when num >= 5 - 4. 
Should output {"num": 5} - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="increment_loop_with_break_condition_workflow", - inputs={}, # No inputs needed for this test - expected_outputs={"num": 5}, - description="Loop with break condition when num >= 5", - ) - - result = runner.run_test_case(test_case) - - # Assert the test passed - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"num": 5}, f"Expected {{'num': 5}}, got {result.actual_outputs}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py deleted file mode 100644 index 4a60c7769cd..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py +++ /dev/null @@ -1,72 +0,0 @@ -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_loop_with_tool(): - fixture_name = "search_dify_from_2023_to_2025" - mock_config = ( - MockConfigBuilder() - .with_tool_response( - { - "text": "mocked search result", - } - ) - .build() - ) - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - expected_outputs={ - "answer": """- mocked search result -- mocked search result""" - }, - expected_event_sequence=[ - GraphRunStartedEvent, - # START - NodeRunStartedEvent, - NodeRunSucceededEvent, - # LOOP START - NodeRunStartedEvent, - NodeRunLoopStartedEvent, - # 2023 - NodeRunStartedEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunVariableUpdatedEvent, - NodeRunSucceededEvent, - NodeRunLoopNextEvent, - # 2024 - NodeRunStartedEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunVariableUpdatedEvent, - NodeRunSucceededEvent, - # LOOP END - NodeRunLoopSucceededEvent, - NodeRunStreamChunkEvent, # loop.res - NodeRunSucceededEvent, - # ANSWER - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py deleted file mode 100644 index c511548749c..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py +++ /dev/null @@ -1,281 +0,0 @@ -""" -Example demonstrating the auto-mock system for testing workflows. - -This example shows how to test workflows with third-party service nodes -without making actual API calls. -""" - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def example_test_llm_workflow(): - """ - Example: Testing a workflow with an LLM node. - - This demonstrates how to test a workflow that uses an LLM service - without making actual API calls to OpenAI, Anthropic, etc. 
- """ - print("\n=== Example: Testing LLM Workflow ===\n") - - # Initialize the test runner - runner = TableTestRunner() - - # Configure mock responses - mock_config = MockConfigBuilder().with_llm_response("I'm a helpful AI assistant. How can I help you today?").build() - - # Define the test case - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Hello, AI!"}, - expected_outputs={"answer": "I'm a helpful AI assistant. How can I help you today?"}, - description="Testing LLM workflow with mocked response", - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, - ) - - # Run the test - result = runner.run_test_case(test_case) - - if result.success: - print("โœ… Test passed!") - print(f" Input: {test_case.inputs['query']}") - print(f" Output: {result.actual_outputs['answer']}") - print(f" Execution time: {result.execution_time:.2f}s") - else: - print(f"โŒ Test failed: {result.error}") - - return result.success - - -def example_test_with_custom_outputs(): - """ - Example: Testing with custom outputs for specific nodes. - - This shows how to provide different mock outputs for specific node IDs, - useful when testing complex workflows with multiple LLM/tool nodes. - """ - print("\n=== Example: Custom Node Outputs ===\n") - - runner = TableTestRunner() - - # Configure mock with specific outputs for different nodes - mock_config = MockConfigBuilder().build() - - # Set custom output for a specific LLM node - mock_config.set_node_outputs( - "llm_node", - { - "text": "This is a custom response for the specific LLM node", - "usage": { - "prompt_tokens": 50, - "completion_tokens": 20, - "total_tokens": 70, - }, - "finish_reason": "stop", - }, - ) - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Tell me about custom outputs"}, - expected_outputs={"answer": "This is a custom response for the specific LLM node"}, - description="Testing with custom node outputs", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("โœ… Test with custom outputs passed!") - print(f" Custom output: {result.actual_outputs['answer']}") - else: - print(f"โŒ Test failed: {result.error}") - - return result.success - - -def example_test_http_and_tool_workflow(): - """ - Example: Testing a workflow with HTTP request and tool nodes. - - This demonstrates mocking external HTTP calls and tool executions. 
- """ - print("\n=== Example: HTTP and Tool Workflow ===\n") - - runner = TableTestRunner() - - # Configure mocks for HTTP and Tool nodes - mock_config = MockConfigBuilder().build() - - # Mock HTTP response - mock_config.set_node_outputs( - "http_node", - { - "status_code": 200, - "body": '{"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}', - "headers": {"content-type": "application/json"}, - }, - ) - - # Mock tool response (e.g., JSON parser) - mock_config.set_node_outputs( - "tool_node", - { - "result": {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, - }, - ) - - test_case = WorkflowTestCase( - fixture_path="http-tool-workflow", - inputs={"url": "https://api.example.com/users"}, - expected_outputs={ - "status_code": 200, - "parsed_data": {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, - }, - description="Testing HTTP and Tool workflow", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("โœ… HTTP and Tool workflow test passed!") - print(f" HTTP Status: {result.actual_outputs['status_code']}") - print(f" Parsed Data: {result.actual_outputs['parsed_data']}") - else: - print(f"โŒ Test failed: {result.error}") - - return result.success - - -def example_test_error_simulation(): - """ - Example: Simulating errors in specific nodes. - - This shows how to test error handling in workflows by simulating - failures in specific nodes. - """ - print("\n=== Example: Error Simulation ===\n") - - runner = TableTestRunner() - - # Configure mock to simulate an error - mock_config = MockConfigBuilder().build() - mock_config.set_node_error("llm_node", "API rate limit exceeded") - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "This will fail"}, - expected_outputs={}, # We expect failure - description="Testing error handling", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if not result.success: - print("โœ… Error simulation worked as expected!") - print(f" Simulated error: {result.error}") - else: - print("โŒ Expected failure but test succeeded") - - return not result.success # Success means we got the expected error - - -def example_test_with_delays(): - """ - Example: Testing with simulated execution delays. - - This demonstrates how to simulate realistic execution times - for performance testing. 
- """ - print("\n=== Example: Simulated Delays ===\n") - - runner = TableTestRunner() - - # Configure mock with delays - mock_config = ( - MockConfigBuilder() - .with_delays(True) # Enable delay simulation - .with_llm_response("Response after delay") - .build() - ) - - # Add specific delay for the LLM node - from .test_mock_config import NodeMockConfig - - node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response after delay"}, - delay=0.5, # 500ms delay - ) - mock_config.set_node_config("llm_node", node_config) - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Test with delay"}, - expected_outputs={"answer": "Response after delay"}, - description="Testing with simulated delays", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("โœ… Delay simulation test passed!") - print(f" Execution time: {result.execution_time:.2f}s") - print(" (Should be >= 0.5s due to simulated delay)") - else: - print(f"โŒ Test failed: {result.error}") - - return result.success and result.execution_time >= 0.5 - - -def run_all_examples(): - """Run all example tests.""" - print("\n" + "=" * 50) - print("AUTO-MOCK SYSTEM EXAMPLES") - print("=" * 50) - - examples = [ - example_test_llm_workflow, - example_test_with_custom_outputs, - example_test_http_and_tool_workflow, - example_test_error_simulation, - example_test_with_delays, - ] - - results = [] - for example in examples: - try: - results.append(example()) - except Exception as e: - print(f"\nโŒ Example failed with exception: {e}") - results.append(False) - - print("\n" + "=" * 50) - print("SUMMARY") - print("=" * 50) - - passed = sum(results) - total = len(results) - print(f"\nโœ… Passed: {passed}/{total}") - - if passed == total: - print("\n๐ŸŽ‰ All examples passed successfully!") - else: - print(f"\nโš ๏ธ {total - passed} example(s) failed") - - return passed == total - - -if __name__ == "__main__": - import sys - - success = run_all_examples() - sys.exit(0 if success else 1) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 76b2984a4b8..88989db8565 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -7,11 +7,12 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request from typing import TYPE_CHECKING, Any -from core.workflow.node_factory import DifyNodeFactory from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node +from core.workflow.node_factory import DifyNodeFactory + from .test_mock_nodes import ( MockAgentNode, MockCodeNode, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py deleted file mode 100644 index aff479104f9..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ /dev/null @@ -1,199 +0,0 @@ -""" -Simple test to verify MockNodeFactory works with iteration nodes. 
-""" - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from graphon.enums import BuiltinNodeTypes -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory - - -def test_mock_factory_registers_iteration_node(): - """Test that MockNodeFactory has iteration node registered.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create a MockNodeFactory instance - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Check that iteration node is registered - assert BuiltinNodeTypes.ITERATION in factory._mock_node_types - print("โœ“ Iteration node is registered in MockNodeFactory") - - # Check that loop node is registered - assert BuiltinNodeTypes.LOOP in factory._mock_node_types - print("โœ“ Loop node is registered in MockNodeFactory") - - # Check the class types - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode, MockLoopNode - - assert factory._mock_node_types[BuiltinNodeTypes.ITERATION] == MockIterationNode - print("โœ“ Iteration node maps to MockIterationNode class") - - assert factory._mock_node_types[BuiltinNodeTypes.LOOP] == MockLoopNode - print("โœ“ Loop node maps to MockLoopNode class") - - -def test_mock_iteration_node_preserves_config(): - """Test that MockIterationNode preserves mock configuration.""" - - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode - - # Create mock config - mock_config = MockConfigBuilder().with_llm_response("Test response").build() - - # Create minimal graph init params - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - - # Create minimal runtime state - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - - # Create mock iteration node - node_config = { - "id": "iter1", - "data": { - "type": "iteration", - "title": "Test", - "iterator_selector": ["start", "items"], - "output_selector": ["node", "text"], - "start_node_id": "node1", - }, - } - - mock_node = MockIterationNode( - id="iter1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # 
Verify the mock config is preserved - assert mock_node.mock_config == mock_config - print("โœ“ MockIterationNode preserves mock configuration") - - # Check that _create_graph_engine method exists and is overridden - assert hasattr(mock_node, "_create_graph_engine") - assert MockIterationNode._create_graph_engine != MockIterationNode.__bases__[1]._create_graph_engine - print("โœ“ MockIterationNode overrides _create_graph_engine method") - - -def test_mock_loop_node_preserves_config(): - """Test that MockLoopNode preserves mock configuration.""" - - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode - - # Create mock config - mock_config = MockConfigBuilder().with_http_response({"status": 200}).build() - - # Create minimal graph init params - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - - # Create minimal runtime state - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - - # Create mock loop node - node_config = { - "id": "loop1", - "data": { - "type": "loop", - "title": "Test", - "loop_count": 3, - "start_node_id": "node1", - "loop_variables": [], - "outputs": {}, - "break_conditions": [], - "logical_operator": "and", - }, - } - - mock_node = MockLoopNode( - id="loop1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Verify the mock config is preserved - assert mock_node.mock_config == mock_config - print("โœ“ MockLoopNode preserves mock configuration") - - # Check that _create_graph_engine method exists and is overridden - assert hasattr(mock_node, "_create_graph_engine") - assert MockLoopNode._create_graph_engine != MockLoopNode.__bases__[1]._create_graph_engine - print("โœ“ MockLoopNode overrides _create_graph_engine method") - - -if __name__ == "__main__": - test_mock_factory_registers_iteration_node() - test_mock_iteration_node_preserves_config() - test_mock_loop_node_preserves_config() - print("\nโœ… All tests passed! 
MockNodeFactory now supports iteration and loop nodes.") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 971b9b2bbff..8b7fbd1b303 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -10,10 +10,6 @@ from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock -from core.model_manager import ModelInstance -from core.workflow.node_runtime import DifyToolNodeRuntime -from core.workflow.nodes.agent import AgentNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent @@ -31,6 +27,11 @@ from graphon.nodes.template_transform import TemplateTransformNode from graphon.nodes.tool import ToolNode from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError +from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.nodes.agent import AgentNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode + if TYPE_CHECKING: from graphon.entities import GraphInitParams from graphon.runtime import GraphRuntimeState diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py deleted file mode 100644 index 15f6f513983..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ /dev/null @@ -1,670 +0,0 @@ -""" -Test cases for Mock Template Transform and Code nodes. - -This module tests the functionality of MockTemplateTransformNode and MockCodeNode -to ensure they work correctly with the TableTestRunner. 
-""" - -from configs import dify_config -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.nodes.code.limits import CodeNodeLimits -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory -from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode - -DEFAULT_CODE_LIMITS = CodeNodeLimits( - max_string_length=dify_config.CODE_MAX_STRING_LENGTH, - max_number=dify_config.CODE_MAX_NUMBER, - min_number=dify_config.CODE_MIN_NUMBER, - max_precision=dify_config.CODE_MAX_PRECISION, - max_depth=dify_config.CODE_MAX_DEPTH, - max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, - max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, - max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, -) - - -class _NoopCodeExecutor: - def execute(self, *, language: object, code: str, inputs: dict[str, object]) -> dict[str, object]: - _ = (language, code, inputs) - return {} - - def is_execution_error(self, error: Exception) -> bool: - _ = error - return False - - -class TestMockTemplateTransformNode: - """Test cases for MockTemplateTransformNode.""" - - def test_mock_template_transform_node_default_output(self): - """Test that MockTemplateTransformNode processes templates with Jinja2.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - # The template "Hello {{ name }}" with no name variable renders as "Hello " - assert result.outputs["output"] == "Hello " - - def test_mock_template_transform_node_custom_output(self): - """Test that MockTemplateTransformNode returns custom configured output.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - 
user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with custom output - mock_config = ( - MockConfigBuilder().with_node_output("template_node_1", {"output": "Custom template output"}).build() - ) - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - assert result.outputs["output"] == "Custom template output" - - def test_mock_template_transform_node_error_simulation(self): - """Test that MockTemplateTransformNode can simulate errors.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with error - mock_config = MockConfigBuilder().with_node_error("template_node_1", "Simulated template error").build() - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == "Simulated template error" - - def test_mock_template_transform_node_with_variables(self): - """Test that MockTemplateTransformNode processes templates with variables.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - from graphon.variables import StringVariable - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - # Add a variable to the pool - variable_pool.add(["test", "name"], StringVariable(name="name", value="World", selector=["test", "name"])) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config with a variable - node_config = { - "id": "template_node_1", - "data": { - "type": 
"template-transform", - "title": "Test Template Transform", - "variables": [{"variable": "name", "value_selector": ["test", "name"]}], - "template": "Hello {{ name }}!", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - assert result.outputs["output"] == "Hello World!" - - -class TestMockCodeNode: - """Test cases for MockCodeNode.""" - - def test_mock_code_node_default_output(self): - """Test that MockCodeNode returns default output.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 'test'", - "outputs": {}, # Empty outputs for default case - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "result" in result.outputs - assert result.outputs["result"] == "mocked code execution result" - - def test_mock_code_node_with_output_schema(self): - """Test that MockCodeNode generates outputs based on schema.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config with output schema - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "name = 'test'\ncount = 42\nitems = ['a', 'b']", - "outputs": { - "name": {"type": "string"}, - "count": {"type": "number"}, - "items": {"type": "array[string]"}, - }, - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, 
- mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "name" in result.outputs - assert result.outputs["name"] == "mocked_name" - assert "count" in result.outputs - assert result.outputs["count"] == 42 - assert "items" in result.outputs - assert result.outputs["items"] == ["item1", "item2"] - - def test_mock_code_node_custom_output(self): - """Test that MockCodeNode returns custom configured output.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with custom output - mock_config = ( - MockConfigBuilder() - .with_node_output("code_node_1", {"result": "Custom code result", "status": "success"}) - .build() - ) - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 'test'", - "outputs": {}, # Empty outputs for default case - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "result" in result.outputs - assert result.outputs["result"] == "Custom code result" - assert "status" in result.outputs - assert result.outputs["status"] == "success" - - -class TestMockNodeFactory: - """Test cases for MockNodeFactory with new node types.""" - - def test_code_and_template_nodes_mocked_by_default(self): - """Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy).""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create factory - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - # Verify that CODE and TEMPLATE_TRANSFORM ARE mocked by default (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Verify that other third-party service nodes ARE also mocked by default - assert 
factory.should_mock_node(BuiltinNodeTypes.LLM) - assert factory.should_mock_node(BuiltinNodeTypes.AGENT) - - def test_factory_creates_mock_template_transform_node(self): - """Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create factory - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create node through factory - node = factory.create_node(node_config) - - # Verify the correct mock type was created - assert isinstance(node, MockTemplateTransformNode) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - def test_factory_creates_mock_code_node(self): - """Test that MockNodeFactory creates MockCodeNode for code type.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create factory - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 42", - "outputs": {}, # Required field for CodeNodeData - }, - } - - # Create node through factory - node = factory.create_node(node_config) - - # Verify the correct mock type was created - assert isinstance(node, MockCodeNode) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py deleted file mode 100644 index cb5200f8dc0..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -Simple test to validate the auto-mock system without external dependencies. 
-""" - -import sys - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from graphon.enums import BuiltinNodeTypes -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory - - -def test_mock_config_builder(): - """Test the MockConfigBuilder fluent interface.""" - print("Testing MockConfigBuilder...") - - config = ( - MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"tool": "output"}) - .with_retrieval_response("Retrieval content") - .with_http_response({"status_code": 201, "body": "created"}) - .with_node_output("node1", {"output": "value"}) - .with_node_error("node2", "error message") - .with_delays(True) - .build() - ) - - assert config.default_llm_response == "LLM response" - assert config.default_agent_response == "Agent response" - assert config.default_tool_response == {"tool": "output"} - assert config.default_retrieval_response == "Retrieval content" - assert config.default_http_response == {"status_code": 201, "body": "created"} - assert config.simulate_delays is True - - node1_config = config.get_node_config("node1") - assert node1_config is not None - assert node1_config.outputs == {"output": "value"} - - node2_config = config.get_node_config("node2") - assert node2_config is not None - assert node2_config.error == "error message" - - print("โœ“ MockConfigBuilder test passed") - - -def test_mock_config_operations(): - """Test MockConfig operations.""" - print("Testing MockConfig operations...") - - config = MockConfig() - - # Test setting node outputs - config.set_node_outputs("test_node", {"result": "test_value"}) - node_config = config.get_node_config("test_node") - assert node_config is not None - assert node_config.outputs == {"result": "test_value"} - - # Test setting node error - config.set_node_error("error_node", "Test error") - error_config = config.get_node_config("error_node") - assert error_config is not None - assert error_config.error == "Test error" - - # Test default configs by node type - config.set_default_config(BuiltinNodeTypes.LLM, {"temperature": 0.7}) - llm_config = config.get_default_config(BuiltinNodeTypes.LLM) - assert llm_config == {"temperature": 0.7} - - print("โœ“ MockConfig operations test passed") - - -def test_node_mock_config(): - """Test NodeMockConfig.""" - print("Testing NodeMockConfig...") - - # Test with custom handler - def custom_handler(node): - return {"custom": "output"} - - node_config = NodeMockConfig( - node_id="test_node", outputs={"text": "test"}, error=None, delay=0.5, custom_handler=custom_handler - ) - - assert node_config.node_id == "test_node" - assert node_config.outputs == {"text": "test"} - assert node_config.delay == 0.5 - assert node_config.custom_handler is not None - - # Test custom handler - result = node_config.custom_handler(None) - assert result == {"custom": "output"} - - print("โœ“ NodeMockConfig test passed") - - -def test_mock_factory_detection(): - """Test MockNodeFactory node type detection.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - print("Testing MockNodeFactory detection...") - - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", 
- "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Test that third-party service nodes are identified for mocking - assert factory.should_mock_node(BuiltinNodeTypes.LLM) - assert factory.should_mock_node(BuiltinNodeTypes.AGENT) - assert factory.should_mock_node(BuiltinNodeTypes.TOOL) - assert factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL) - assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST) - assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR) - assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR) - - # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Test that non-service nodes are not mocked - assert not factory.should_mock_node(BuiltinNodeTypes.START) - assert not factory.should_mock_node(BuiltinNodeTypes.END) - assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE) - assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR) - - print("โœ“ MockNodeFactory detection test passed") - - -def test_mock_factory_registration(): - """Test registering and unregistering mock node types.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - print("Testing MockNodeFactory registration...") - - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Unregister mock - factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Register custom mock (using a dummy class for testing) - class DummyMockNode: - pass - - factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, DummyMockNode) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - print("โœ“ MockNodeFactory registration test passed") - - -def run_all_tests(): - """Run all tests.""" - print("\n=== Running Auto-Mock System Tests ===\n") - - try: - test_mock_config_builder() - test_mock_config_operations() - test_node_mock_config() - test_mock_factory_detection() - test_mock_factory_registration() - - print("\n=== All tests passed! 
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
index 37b43bd3749..8311a1e847a 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
@@ -4,18 +4,10 @@ from dataclasses import dataclass
 from datetime import datetime, timedelta
 from typing import Any, Protocol
 
-from core.repositories.human_input_repository import (
-    FormCreateParams,
-    HumanInputFormEntity,
-    HumanInputFormRepository,
-)
-from core.workflow.node_runtime import DifyHumanInputNodeRuntime
-from core.workflow.system_variables import build_system_variables
-from graphon.entities.workflow_start_reason import WorkflowStartReason
+from graphon.entities import WorkflowStartReason
 from graphon.graph import Graph
-from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel
-from graphon.graph_engine.config import GraphEngineConfig
-from graphon.graph_engine.graph_engine import GraphEngine
+from graphon.graph_engine import GraphEngine, GraphEngineConfig
+from graphon.graph_engine.command_channels import InMemoryChannel
 from graphon.graph_events import (
     GraphRunPausedEvent,
     GraphRunStartedEvent,
@@ -31,6 +23,14 @@ from graphon.nodes.human_input.human_input_node import HumanInputNode
 from graphon.nodes.start.entities import StartNodeData
 from graphon.nodes.start.start_node import StartNode
 from graphon.runtime import GraphRuntimeState, VariablePool
+
+from core.repositories.human_input_repository import (
+    FormCreateParams,
+    HumanInputFormEntity,
+    HumanInputFormRepository,
+)
+from core.workflow.node_runtime import DifyHumanInputNodeRuntime
+from core.workflow.system_variables import build_system_variables
 from libs.datetime_utils import naive_utc_now
 from tests.workflow_test_utils import build_test_graph_init_params
 
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py
deleted file mode 100644
index 59e54bd39a6..00000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py
+++ /dev/null
@@ -1,336 +0,0 @@
-import time
-from collections.abc import Mapping
-from dataclasses import dataclass
-from datetime import datetime, timedelta
-from typing import Any
-
-from core.repositories.human_input_repository import (
-    FormCreateParams,
-    HumanInputFormEntity,
-    HumanInputFormRepository,
-)
-from core.workflow.node_runtime import DifyHumanInputNodeRuntime
-from core.workflow.system_variables import build_system_variables
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.graph import Graph
-from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel
-from graphon.graph_engine.config import GraphEngineConfig
-from graphon.graph_engine.graph_engine import GraphEngine
-from graphon.graph_events import (
-    GraphRunPausedEvent,
-    GraphRunStartedEvent,
-    NodeRunPauseRequestedEvent,
-    NodeRunStartedEvent,
-    NodeRunSucceededEvent,
-)
-from graphon.model_runtime.entities.llm_entities import LLMMode
-from graphon.model_runtime.entities.message_entities import PromptMessageRole
-from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction
-from graphon.nodes.human_input.enums import HumanInputFormStatus
-from graphon.nodes.human_input.human_input_node import HumanInputNode
-from graphon.nodes.llm.entities import (
-    ContextConfig,
-    LLMNodeChatModelMessage,
-    LLMNodeData,
-    ModelConfig,
-    VisionConfig,
-)
-from graphon.nodes.start.entities import StartNodeData
-from graphon.nodes.start.start_node import StartNode
-from graphon.runtime import GraphRuntimeState, VariablePool
-from libs.datetime_utils import naive_utc_now
-from tests.workflow_test_utils import build_test_graph_init_params
-
-from .test_mock_config import MockConfig, NodeMockConfig
-from .test_mock_nodes import MockLLMNode
-
-
-@dataclass
-class StaticForm(HumanInputFormEntity):
-    form_id: str
-    rendered: str
-    is_submitted: bool
-    action_id: str | None = None
-    data: Mapping[str, Any] | None = None
-    status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
-    expiration: datetime = naive_utc_now() + timedelta(days=1)
-
-    @property
-    def id(self) -> str:
-        return self.form_id
-
-    @property
-    def submission_token(self) -> str | None:
-        return "token"
-
-    @property
-    def recipients(self) -> list:
-        return []
-
-    @property
-    def rendered_content(self) -> str:
-        return self.rendered
-
-    @property
-    def selected_action_id(self) -> str | None:
-        return self.action_id
-
-    @property
-    def submitted_data(self) -> Mapping[str, Any] | None:
-        return self.data
-
-    @property
-    def submitted(self) -> bool:
-        return self.is_submitted
-
-    @property
-    def status(self) -> HumanInputFormStatus:
-        return self.status_value
-
-    @property
-    def expiration_time(self) -> datetime:
-        return self.expiration
-
-
-class StaticRepo(HumanInputFormRepository):
-    def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None:
-        self._forms_by_node_id = dict(forms_by_node_id)
-
-    def get_form(self, node_id: str) -> HumanInputFormEntity | None:
-        return self._forms_by_node_id.get(node_id)
-
-    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
-        raise AssertionError("create_form should not be called in resume scenario")
-
-
-class DelayedHumanInputNode(HumanInputNode):
-    def __init__(self, delay_seconds: float, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
-        self._delay_seconds = delay_seconds
-
-    def _run(self):
-        if self._delay_seconds > 0:
-            time.sleep(self._delay_seconds)
-        yield from super()._run()
-
-
-def _build_runtime_state() -> GraphRuntimeState:
-    variable_pool = VariablePool(
-        system_variables=build_system_variables(
-            user_id="user",
-            app_id="app",
-            workflow_id="workflow",
-            workflow_execution_id="exec-1",
-        ),
-        user_inputs={},
-        conversation_variables=[],
-    )
-    return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-
-
-def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
-    graph_config: dict[str, object] = {"nodes": [], "edges": []}
-    graph_init_params = build_test_graph_init_params(
-        workflow_id="workflow",
-        graph_config=graph_config,
-        tenant_id="tenant",
-        app_id="app",
-        user_id="user",
-        user_from="account",
-        invoke_from="debugger",
-        call_depth=0,
-    )
-
-    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
-    start_node = StartNode(
-        id=start_config["id"],
-        config=start_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-    )
-
-    human_data = HumanInputNodeData(
-        title="Human Input",
-        form_content="Human input required",
-        inputs=[],
-        user_actions=[UserAction(id="approve", title="Approve")],
-    )
-
-    human_a_config = {"id": "human_a", "data": human_data.model_dump()}
-    human_a = HumanInputNode(
-        id=human_a_config["id"],
-        config=human_a_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-        form_repository=repo,
-        runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
-    )
-
-    human_b_config = {"id": "human_b", "data": human_data.model_dump()}
-    human_b = DelayedHumanInputNode(
-        id=human_b_config["id"],
-        config=human_b_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-        form_repository=repo,
-        runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
-        delay_seconds=0.2,
-    )
-
-    llm_data = LLMNodeData(
-        title="LLM A",
-        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
-        prompt_template=[
-            LLMNodeChatModelMessage(
-                text="Prompt A",
-                role=PromptMessageRole.USER,
-                edition_type="basic",
-            )
-        ],
-        context=ContextConfig(enabled=False, variable_selector=None),
-        vision=VisionConfig(enabled=False),
-        reasoning_format="tagged",
-        structured_output_enabled=False,
-    )
-    llm_config = {"id": "llm_a", "data": llm_data.model_dump()}
-    llm_a = MockLLMNode(
-        id=llm_config["id"],
-        config=llm_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-        mock_config=mock_config,
-    )
-
-    return (
-        Graph.new()
-        .add_root(start_node)
-        .add_node(human_a, from_node_id="start")
-        .add_node(human_b, from_node_id="start")
-        .add_node(llm_a, from_node_id="human_a", source_handle="approve")
-        .build()
-    )
-
-
-def test_parallel_human_input_pause_preserves_node_finished() -> None:
-    runtime_state = _build_runtime_state()
-
-    runtime_state.graph_execution.start()
-    runtime_state.register_paused_node("human_a")
-    runtime_state.register_paused_node("human_b")
-
-    submitted = StaticForm(
-        form_id="form-a",
-        rendered="rendered",
-        is_submitted=True,
-        action_id="approve",
-        data={},
-        status_value=HumanInputFormStatus.SUBMITTED,
-    )
-    pending = StaticForm(
-        form_id="form-b",
-        rendered="rendered",
-        is_submitted=False,
-        action_id=None,
-        data=None,
-        status_value=HumanInputFormStatus.WAITING,
-    )
-    repo = StaticRepo({"human_a": submitted, "human_b": pending})
-
-    mock_config = MockConfig()
-    mock_config.simulate_delays = True
-    mock_config.set_node_config(
-        "llm_a",
-        NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
-    )
-
-    graph = _build_graph(runtime_state, repo, mock_config)
-    engine = GraphEngine(
-        workflow_id="workflow",
-        graph=graph,
-        graph_runtime_state=runtime_state,
-        command_channel=InMemoryChannel(),
-        config=GraphEngineConfig(
-            min_workers=2,
-            max_workers=2,
-            scale_up_threshold=1,
-            scale_down_idle_time=30.0,
-        ),
-    )
-
-    events = list(engine.run())
-
-    llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
-    llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
-    human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
-    graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
-    graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events)
-
-    assert graph_started
-    assert graph_paused
-    assert human_b_pause
-    assert llm_started
-    assert llm_succeeded
-
-
-def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None:
-    base_state = _build_runtime_state()
-    base_state.graph_execution.start()
-    base_state.register_paused_node("human_a")
-    base_state.register_paused_node("human_b")
-    snapshot = base_state.dumps()
-
-    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
-
-    submitted = StaticForm(
-        form_id="form-a",
-        rendered="rendered",
-        is_submitted=True,
-        action_id="approve",
-        data={},
-        status_value=HumanInputFormStatus.SUBMITTED,
-    )
-    pending = StaticForm(
-        form_id="form-b",
-        rendered="rendered",
-        is_submitted=False,
-        action_id=None,
-        data=None,
-        status_value=HumanInputFormStatus.WAITING,
-    )
-    repo = StaticRepo({"human_a": submitted, "human_b": pending})
-
-    mock_config = MockConfig()
-    mock_config.simulate_delays = True
-    mock_config.set_node_config(
-        "llm_a",
-        NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
-    )
-
-    graph = _build_graph(resumed_state, repo, mock_config)
-    engine = GraphEngine(
-        workflow_id="workflow",
-        graph=graph,
-        graph_runtime_state=resumed_state,
-        command_channel=InMemoryChannel(),
-        config=GraphEngineConfig(
-            min_workers=2,
-            max_workers=2,
-            scale_up_threshold=1,
-            scale_down_idle_time=30.0,
-        ),
-    )
-
-    events = list(engine.run())
-
-    start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent))
-    assert start_event.reason is WorkflowStartReason.RESUMPTION
-
-    llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
-    llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
-    human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
-    graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
-
-    assert graph_paused
-    assert human_b_pause
-    assert llm_started
-    assert llm_succeeded
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py
deleted file mode 100644
index 1a437344628..00000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py
+++ /dev/null
@@ -1,286 +0,0 @@
-"""
-Test for parallel streaming workflow behavior.
-
-This test validates that:
-- LLM 1 always speaks English
-- LLM 2 always speaks Chinese
-- 2 LLMs run parallel, but LLM 2 will output before LLM 1
-- All chunks should be sent before Answer Node started
-"""
-
-import time
-from unittest.mock import MagicMock, patch
-from uuid import uuid4
-
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from core.model_manager import ModelInstance
-from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
-from core.workflow.system_variables import build_system_variables
-from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
-from graphon.graph import Graph
-from graphon.graph_engine import GraphEngine, GraphEngineConfig
-from graphon.graph_engine.command_channels import InMemoryChannel
-from graphon.graph_events import (
-    GraphRunSucceededEvent,
-    NodeRunStartedEvent,
-    NodeRunStreamChunkEvent,
-    NodeRunSucceededEvent,
-)
-from graphon.node_events import NodeRunResult, StreamCompletedEvent
-from graphon.nodes.llm.node import LLMNode
-from graphon.runtime import GraphRuntimeState, VariablePool
-from tests.workflow_test_utils import build_test_graph_init_params
-
-from .test_table_runner import TableTestRunner
-
-
-def create_llm_generator_with_delay(chunks: list[str], delay: float = 0.1):
-    """Create a generator that simulates LLM streaming output with delay"""
-
-    def llm_generator(self):
-        for i, chunk in enumerate(chunks):
-            time.sleep(delay)  # Simulate network delay
-            yield NodeRunStreamChunkEvent(
-                id=str(uuid4()),
-                node_id=self.id,
-                node_type=self.node_type,
-                selector=[self.id, "text"],
-                chunk=chunk,
-                is_final=i == len(chunks) - 1,
-            )
-
-        # Complete response
-        full_text = "".join(chunks)
-        yield StreamCompletedEvent(
-            node_run_result=NodeRunResult(
-                status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                outputs={"text": full_text},
-            )
-        )
-
-    return llm_generator
-
-
-def test_parallel_streaming_workflow():
-    """
-    Test parallel streaming workflow to verify:
-    1. All chunks from LLM 2 are output before LLM 1
-    2. At least one chunk from LLM 2 is output before LLM 1 completes (Success)
-    3. At least one chunk from LLM 1 is output before LLM 2 completes (EXPECTED TO FAIL)
-    4. All chunks are output before End begins
-    5. The final output content matches the order defined in the Answer
-
-    Test setup:
-    - LLM 1 outputs English (slower)
-    - LLM 2 outputs Chinese (faster)
-    - Both run in parallel
-
-    This test is expected to FAIL because chunks are currently buffered
-    until after node completion instead of streaming during execution.
- """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("multilingual_parallel_llm_streaming_workflow") - workflow_config = fixture_data.get("workflow", {}) - graph_config = workflow_config.get("graph", {}) - - # Create graph initialization parameters - init_params = build_test_graph_init_params( - workflow_id="test_workflow", - graph_config=graph_config, - tenant_id="test_tenant", - app_id="test_app", - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, - ) - - # Create variable pool with system variables - system_variables = build_system_variables( - user_id="test_user", - app_id="test_app", - workflow_id=init_params.workflow_id, - files=[], - query="Tell me about yourself", # User query - ) - variable_pool = VariablePool( - system_variables=system_variables, - user_inputs={}, - ) - - # Create graph runtime state - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - # Create node factory and graph - node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state) - with patch.object( - DifyNodeFactory, "_build_model_instance_for_llm_node", return_value=MagicMock(spec=ModelInstance), autospec=True - ): - graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=get_default_root_node_id(graph_config), - ) - - # Create the graph engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Define LLM outputs - llm1_chunks = ["Hello", ", ", "I", " ", "am", " ", "an", " ", "AI", " ", "assistant", "."] # English (slower) - llm2_chunks = ["ไฝ ๅฅฝ", "๏ผŒ", "ๆˆ‘", "ๆ˜ฏ", "AI", "ๅŠฉๆ‰‹", "ใ€‚"] # Chinese (faster) - - # Create generators with different delays (LLM 2 is faster) - llm1_generator = create_llm_generator_with_delay(llm1_chunks, delay=0.05) # Slower - llm2_generator = create_llm_generator_with_delay(llm2_chunks, delay=0.01) # Faster - - # Track which LLM node is being called - llm_call_order = [] - generators = { - "1754339718571": llm1_generator, # LLM 1 node ID - "1754339725656": llm2_generator, # LLM 2 node ID - } - - def mock_llm_run(self): - llm_call_order.append(self.id) - generator = generators.get(self.id) - if generator: - yield from generator(self) - else: - raise Exception(f"Unexpected LLM node ID: {self.id}") - - # Execute with mocked LLMs - with patch.object(LLMNode, "_run", new=mock_llm_run): - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Get all streaming chunk events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - - # Get Answer node start event - answer_start_events = [ - e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.ANSWER - ] - assert len(answer_start_events) == 1, f"Expected 1 Answer node start event, got {len(answer_start_events)}" - answer_start_event = answer_start_events[0] - - # Find the index of Answer node start - answer_start_index = events.index(answer_start_event) - - # Collect chunk events by node - llm1_chunks_events = [e for e in stream_chunk_events if e.node_id == "1754339718571"] - llm2_chunks_events = [e for e in 
-
-    # Verify both LLMs produced chunks
-    assert len(llm1_chunks_events) == len(llm1_chunks), (
-        f"Expected {len(llm1_chunks)} chunks from LLM 1, got {len(llm1_chunks_events)}"
-    )
-    assert len(llm2_chunks_events) == len(llm2_chunks), (
-        f"Expected {len(llm2_chunks)} chunks from LLM 2, got {len(llm2_chunks_events)}"
-    )
-
-    # 1. Verify chunk ordering based on actual implementation
-    llm1_chunk_indices = [events.index(e) for e in llm1_chunks_events]
-    llm2_chunk_indices = [events.index(e) for e in llm2_chunks_events]
-
-    # In the current implementation, chunks may be interleaved or in a specific order
-    # Update this based on actual behavior observed
-    if llm1_chunk_indices and llm2_chunk_indices:
-        # Check the actual ordering - if LLM 2 chunks come first (as seen in debug)
-        assert max(llm2_chunk_indices) < min(llm1_chunk_indices), (
-            f"All LLM 2 chunks should be output before LLM 1 chunks. "
-            f"LLM 2 chunk indices: {llm2_chunk_indices}, LLM 1 chunk indices: {llm1_chunk_indices}"
-        )
-
-    # Get indices of all chunk events
-    chunk_indices = [events.index(e) for e in stream_chunk_events if e in llm1_chunks_events + llm2_chunks_events]
-
-    # 4. Verify all chunks were sent before Answer node started
-    assert all(idx < answer_start_index for idx in chunk_indices), (
-        "All LLM chunks should be sent before Answer node starts"
-    )
-
-    # The test has successfully verified:
-    # 1. Both LLMs run in parallel (they start at the same time)
-    # 2. LLM 2 (Chinese) outputs all its chunks before LLM 1 (English) due to faster processing
-    # 3. All LLM chunks are sent before the Answer node starts
-
-    # Get LLM completion events
-    llm_completed_events = [
-        (i, e)
-        for i, e in enumerate(events)
-        if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM
-    ]
-
-    # Check LLM completion order - in the current implementation, LLMs run sequentially
-    # LLM 1 completes first, then LLM 2 runs and completes
-    assert len(llm_completed_events) == 2, f"Expected 2 LLM completion events, got {len(llm_completed_events)}"
-    llm2_complete_idx = next((i for i, e in llm_completed_events if e.node_id == "1754339725656"), None)
-    llm1_complete_idx = next((i for i, e in llm_completed_events if e.node_id == "1754339718571"), None)
-    assert llm2_complete_idx is not None, "LLM 2 completion event not found"
-    assert llm1_complete_idx is not None, "LLM 1 completion event not found"
-    # In the actual implementation, LLM 1 completes before LLM 2 (sequential execution)
-    assert llm1_complete_idx < llm2_complete_idx, (
-        f"LLM 1 should complete before LLM 2 in sequential execution, but LLM 1 completed at {llm1_complete_idx} "
-        f"and LLM 2 completed at {llm2_complete_idx}"
-    )
-
-    # 2. In sequential execution, LLM 2 chunks appear AFTER LLM 1 completes
-    if llm2_chunk_indices:
-        # LLM 1 completes first, then LLM 2 starts streaming
-        assert min(llm2_chunk_indices) > llm1_complete_idx, (
-            f"LLM 2 chunks should appear after LLM 1 completes in sequential execution. "
-            f"First LLM 2 chunk at index {min(llm2_chunk_indices)}, LLM 1 completed at index {llm1_complete_idx}"
-        )
-    # 3. In the current implementation, LLM 1 chunks appear after LLM 2 completes
-    # This is because chunks are buffered and output after both nodes complete
-    if llm1_chunk_indices and llm2_complete_idx:
-        # Check if LLM 1 chunks exist and where they appear relative to LLM 2 completion
-        # In current behavior, LLM 1 chunks typically appear after LLM 2 completes
-        pass  # Skipping this check as the chunk ordering is implementation-dependent
-
-    # CURRENT BEHAVIOR: Chunks are buffered and appear after node completion
-    # In the sequential execution, LLM 1 completes first without streaming,
-    # then LLM 2 streams its chunks
-    assert stream_chunk_events, "Expected streaming events, but got none"
-
-    first_chunk_index = events.index(stream_chunk_events[0])
-    llm_success_indices = [i for i, e in llm_completed_events]
-
-    # Current implementation: LLM 1 completes first, then chunks start appearing
-    # This is the actual behavior we're testing
-    if llm_success_indices:
-        # At least one LLM (LLM 1) completes before any chunks appear
-        assert min(llm_success_indices) < first_chunk_index, (
-            f"In current implementation, LLM 1 completes before chunks start streaming. "
-            f"First chunk at index {first_chunk_index}, LLM 1 completed at index {min(llm_success_indices)}"
-        )
-
-    # 5. Verify final output content matches the order defined in Answer node
-    # According to Answer node configuration: '{{#1754339725656.text#}}{{#1754339718571.text#}}'
-    # This means LLM 2 output should come first, then LLM 1 output
-    answer_complete_events = [
-        e for e in events if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.ANSWER
-    ]
-    assert len(answer_complete_events) == 1, f"Expected 1 Answer completion event, got {len(answer_complete_events)}"
-
-    answer_outputs = answer_complete_events[0].node_run_result.outputs
-    expected_answer_text = "你好，我是AI助手。Hello, I am an AI assistant."
-
-    if "answer" in answer_outputs:
-        actual_answer_text = answer_outputs["answer"]
-        assert actual_answer_text == expected_answer_text, (
-            f"Answer content should match the order defined in Answer node. "
-            f"Expected: '{expected_answer_text}', Got: '{actual_answer_text}'"
-        )
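The ordering checks in the file removed above all follow one pattern: collect the positions of each node's chunk events in the trace, then compare minima and maxima. A small helper in that spirit (illustrative only; the node IDs are the fixture's, and the helper name is not part of the codebase):

# Sketch: index-based ordering assertions over an event trace, assuming each
# chunk event exposes a node_id attribute as the graphon events above do.
def event_indices(events: list, node_id: str, event_type: type) -> list[int]:
    """Return positions in `events` of `event_type` events emitted by `node_id`."""
    return [i for i, e in enumerate(events) if isinstance(e, event_type) and getattr(e, "node_id", None) == node_id]


# Usage mirroring the removed assertions, e.g. "every LLM 2 chunk precedes
# every LLM 1 chunk":
#   llm1 = event_indices(events, "1754339718571", NodeRunStreamChunkEvent)
#   llm2 = event_indices(events, "1754339725656", NodeRunStreamChunkEvent)
#   assert max(llm2) < min(llm1)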
" - f"Expected: '{expected_answer_text}', Got: '{actual_answer_text}'" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py deleted file mode 100644 index bcf123ee804..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ /dev/null @@ -1,311 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.graph import Graph -from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from graphon.graph_engine.config import GraphEngineConfig -from graphon.graph_engine.graph_engine import GraphEngine -from graphon.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.enums import HumanInputFormStatus -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig, NodeMockConfig -from .test_mock_nodes import MockLLMNode - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def submission_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def __init__(self, form: HumanInputFormEntity) -> None: - self._form = form - - def get_form(self, node_id: str) -> HumanInputFormEntity | None: - if node_id != "human_pause": - return None - return self._form - - def create_form(self, params: 
-
-
-def _build_runtime_state() -> GraphRuntimeState:
-    variable_pool = VariablePool(
-        system_variables=build_system_variables(
-            user_id="user",
-            app_id="app",
-            workflow_id="workflow",
-            workflow_execution_id="exec-1",
-        ),
-        user_inputs={},
-        conversation_variables=[],
-    )
-    return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-
-
-def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
-    graph_config: dict[str, object] = {"nodes": [], "edges": []}
-    graph_init_params = build_test_graph_init_params(
-        workflow_id="workflow",
-        graph_config=graph_config,
-        tenant_id="tenant",
-        app_id="app",
-        user_id="user",
-        user_from="account",
-        invoke_from="debugger",
-        call_depth=0,
-    )
-
-    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
-    start_node = StartNode(
-        id=start_config["id"],
-        config=start_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-    )
-
-    llm_a_data = LLMNodeData(
-        title="LLM A",
-        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
-        prompt_template=[
-            LLMNodeChatModelMessage(
-                text="Prompt A",
-                role=PromptMessageRole.USER,
-                edition_type="basic",
-            )
-        ],
-        context=ContextConfig(enabled=False, variable_selector=None),
-        vision=VisionConfig(enabled=False),
-        reasoning_format="tagged",
-        structured_output_enabled=False,
-    )
-    llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()}
-    llm_a = MockLLMNode(
-        id=llm_a_config["id"],
-        config=llm_a_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-        mock_config=mock_config,
-    )
-
-    llm_b_data = LLMNodeData(
-        title="LLM B",
-        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
-        prompt_template=[
-            LLMNodeChatModelMessage(
-                text="Prompt B",
-                role=PromptMessageRole.USER,
-                edition_type="basic",
-            )
-        ],
-        context=ContextConfig(enabled=False, variable_selector=None),
-        vision=VisionConfig(enabled=False),
-        reasoning_format="tagged",
-        structured_output_enabled=False,
-    )
-    llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()}
-    llm_b = MockLLMNode(
-        id=llm_b_config["id"],
-        config=llm_b_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-        mock_config=mock_config,
-    )
-
-    human_data = HumanInputNodeData(
-        title="Human Input",
-        form_content="Pause here",
-        inputs=[],
-        user_actions=[UserAction(id="approve", title="Approve")],
-    )
-    human_config = {"id": "human_pause", "data": human_data.model_dump()}
-    human_node = HumanInputNode(
-        id=human_config["id"],
-        config=human_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-        form_repository=repo,
-        runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
-    )
-
-    end_human_data = EndNodeData(title="End Human", outputs=[], desc=None)
-    end_human_config = {"id": "end_human", "data": end_human_data.model_dump()}
-    end_human = EndNode(
-        id=end_human_config["id"],
-        config=end_human_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-    )
-
-    return (
-        Graph.new()
-        .add_root(start_node)
-        .add_node(llm_a, from_node_id="start")
-        .add_node(human_node, from_node_id="start")
-        .add_node(llm_b, from_node_id="llm_a")
-        .add_node(end_human, from_node_id="human_pause", source_handle="approve")
-        .build()
-    )
from_node_id="human_pause", source_handle="approve") - .build() - ) - - -def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None: - for event in events: - if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: - return event - return None - - -def test_pause_defers_ready_nodes_until_resume() -> None: - runtime_state = _build_runtime_state() - - paused_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=False, - status_value=HumanInputFormStatus.WAITING, - ) - pause_repo = StaticRepo(paused_form) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - mock_config.set_node_config( - "llm_b", - NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0), - ) - - graph = _build_graph(runtime_state, pause_repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - paused_events = list(engine.run()) - - assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events) - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events) - assert _get_node_started_event(paused_events, "llm_b") is None - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - submitted_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - resume_repo = StaticRepo(submitted_form) - - resumed_graph = _build_graph(resumed_state, resume_repo, mock_config) - resumed_engine = GraphEngine( - workflow_id="workflow", - graph=resumed_graph, - graph_runtime_state=resumed_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - resumed_events = list(resumed_engine.run()) - - start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent)) - assert start_event.reason is WorkflowStartReason.RESUMPTION - - llm_b_started = _get_node_started_event(resumed_events, "llm_b") - assert llm_b_started is not None - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py deleted file mode 100644 index 79d3d5bcfe8..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ /dev/null @@ -1,219 +0,0 @@ -import datetime -import time -from typing import Any -from unittest.mock import MagicMock - -from core.repositories.human_input_repository import ( - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.graph import Graph -from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from graphon.graph_engine.graph_engine import GraphEngine -from graphon.graph_events import ( - 
-    GraphEngineEvent,
-    GraphRunPausedEvent,
-    GraphRunSucceededEvent,
-    NodeRunStartedEvent,
-    NodeRunSucceededEvent,
-)
-from graphon.graph_events.graph import GraphRunStartedEvent
-from graphon.nodes.base.entities import OutputVariableEntity
-from graphon.nodes.end.end_node import EndNode
-from graphon.nodes.end.entities import EndNodeData
-from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction
-from graphon.nodes.human_input.human_input_node import HumanInputNode
-from graphon.nodes.start.entities import StartNodeData
-from graphon.nodes.start.start_node import StartNode
-from graphon.runtime import GraphRuntimeState, VariablePool
-from libs.datetime_utils import naive_utc_now
-from tests.workflow_test_utils import build_test_graph_init_params
-
-
-def _build_runtime_state() -> GraphRuntimeState:
-    variable_pool = VariablePool(
-        system_variables=build_system_variables(
-            user_id="user",
-            app_id="app",
-            workflow_id="workflow",
-            workflow_execution_id="test-execution-id",
-        ),
-        user_inputs={},
-        conversation_variables=[],
-    )
-    return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-
-
-def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
-    repo = MagicMock(spec=HumanInputFormRepository)
-    form_entity = MagicMock(spec=HumanInputFormEntity)
-    form_entity.id = "test-form-id"
-    form_entity.submission_token = "test-form-token"
-    form_entity.recipients = []
-    form_entity.rendered_content = "rendered"
-    form_entity.submitted = True
-    form_entity.selected_action_id = action_id
-    form_entity.submitted_data = {}
-    form_entity.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
-    repo.get_form.return_value = form_entity
-    return repo
-
-
-def _mock_form_repository_without_submission() -> HumanInputFormRepository:
-    repo = MagicMock(spec=HumanInputFormRepository)
-    form_entity = MagicMock(spec=HumanInputFormEntity)
-    form_entity.id = "test-form-id"
-    form_entity.submission_token = "test-form-token"
-    form_entity.recipients = []
-    form_entity.rendered_content = "rendered"
-    form_entity.submitted = False
-    repo.create_form.return_value = form_entity
-    repo.get_form.return_value = None
-    return repo
-
-
-def _build_human_input_graph(
-    runtime_state: GraphRuntimeState,
-    form_repository: HumanInputFormRepository,
-) -> Graph:
-    graph_config: dict[str, object] = {"nodes": [], "edges": []}
-    params = build_test_graph_init_params(
-        workflow_id="workflow",
-        graph_config=graph_config,
-        tenant_id="tenant",
-        app_id="app",
-        user_id="user",
-        user_from="account",
-        invoke_from="service-api",
-        call_depth=0,
-    )
-
-    start_data = StartNodeData(title="start", variables=[])
-    start_node = StartNode(
-        id="start",
-        config={"id": "start", "data": start_data.model_dump()},
-        graph_init_params=params,
-        graph_runtime_state=runtime_state,
-    )
-
-    human_data = HumanInputNodeData(
-        title="human",
-        form_content="Awaiting human input",
-        inputs=[],
-        user_actions=[
-            UserAction(id="continue", title="Continue"),
-        ],
-    )
-    human_node = HumanInputNode(
-        id="human",
-        config={"id": "human", "data": human_data.model_dump()},
-        graph_init_params=params,
-        graph_runtime_state=runtime_state,
-        form_repository=form_repository,
-        runtime=DifyHumanInputNodeRuntime(params.run_context),
-    )
-
-    end_data = EndNodeData(
-        title="end",
-        outputs=[
-            OutputVariableEntity(variable="result", value_selector=["human", "action_id"]),
-        ],
-        desc=None,
-    )
-    end_node = EndNode(
-        id="end",
-        config={"id": "end", "data": end_data.model_dump()},
-        graph_init_params=params,
-        graph_runtime_state=runtime_state,
-    )
-
-    return (
-        Graph.new()
-        .add_root(start_node)
-        .add_node(human_node)
-        .add_node(end_node, from_node_id="human", source_handle="continue")
-        .build()
-    )
-
-
-def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]:
-    engine = GraphEngine(
-        workflow_id="workflow",
-        graph=graph,
-        graph_runtime_state=runtime_state,
-        command_channel=InMemoryChannel(),
-    )
-    return list(engine.run())
-
-
-def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
-    return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)]
-
-
-def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None:
-    for event in events:
-        if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id:
-            return event
-    return None
-
-
-def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any:
-    segment = variable_pool.get(selector)
-    assert segment is not None
-    return getattr(segment, "value", segment)
-
-
-def test_engine_resume_restores_state_and_completion():
-    # Baseline run without pausing
-    baseline_state = _build_runtime_state()
-    baseline_repo = _mock_form_repository_with_submission(action_id="continue")
-    baseline_graph = _build_human_input_graph(baseline_state, baseline_repo)
-    baseline_events = _run_graph(baseline_graph, baseline_state)
-    assert baseline_events
-    first_paused_event = baseline_events[0]
-    assert isinstance(first_paused_event, GraphRunStartedEvent)
-    assert first_paused_event.reason is WorkflowStartReason.INITIAL
-    assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
-    baseline_success_nodes = _node_successes(baseline_events)
-
-    # Run with pause
-    paused_state = _build_runtime_state()
-    pause_repo = _mock_form_repository_without_submission()
-    paused_graph = _build_human_input_graph(paused_state, pause_repo)
-    paused_events = _run_graph(paused_graph, paused_state)
-    assert paused_events
-    first_paused_event = paused_events[0]
-    assert isinstance(first_paused_event, GraphRunStartedEvent)
-    assert first_paused_event.reason is WorkflowStartReason.INITIAL
-    assert isinstance(paused_events[-1], GraphRunPausedEvent)
-    snapshot = paused_state.dumps()
-
-    # Resume from snapshot
-    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
-    resume_repo = _mock_form_repository_with_submission(action_id="continue")
-    resumed_graph = _build_human_input_graph(resumed_state, resume_repo)
-    resumed_events = _run_graph(resumed_graph, resumed_state)
-    assert resumed_events
-    first_resumed_event = resumed_events[0]
-    assert isinstance(first_resumed_event, GraphRunStartedEvent)
-    assert first_resumed_event.reason is WorkflowStartReason.RESUMPTION
-    assert isinstance(resumed_events[-1], GraphRunSucceededEvent)
-
-    combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events)
-    assert combined_success_nodes == baseline_success_nodes
-
-    paused_human_started = _node_start_event(paused_events, "human")
-    resumed_human_started = _node_start_event(resumed_events, "human")
-    assert paused_human_started is not None
-    assert resumed_human_started is not None
-    assert paused_human_started.id == resumed_human_started.id
-
-    assert baseline_state.outputs == resumed_state.outputs
-    assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(
-        resumed_state.variable_pool, ("human", "__action_id")
-    )
-    assert baseline_state.graph_execution.completed
-    assert resumed_state.graph_execution.completed
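The pause/resume tests above and below lean on one contract: GraphRuntimeState.dumps() must capture everything needed to continue, and from_snapshot() must rebuild an equivalent state. A toy illustration of that round-trip shape (plain Python, not the real implementation):

import json


class ToyState:
    """Minimal stand-in showing the dumps()/from_snapshot() round-trip shape."""

    def __init__(self, outputs: dict | None = None, paused_nodes: list[str] | None = None) -> None:
        self.outputs = outputs or {}
        self.paused_nodes = paused_nodes or []

    def dumps(self) -> str:
        # Everything needed to resume must survive serialization.
        return json.dumps({"outputs": self.outputs, "paused_nodes": self.paused_nodes})

    @classmethod
    def from_snapshot(cls, snapshot: str) -> "ToyState":
        data = json.loads(snapshot)
        return cls(outputs=data["outputs"], paused_nodes=data["paused_nodes"])


restored = ToyState.from_snapshot(ToyState(paused_nodes=["human"]).dumps())
assert restored.paused_nodes == ["human"]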
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py
deleted file mode 100644
index 146b728dc24..00000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py
+++ /dev/null
@@ -1,268 +0,0 @@
-"""
-Unit tests for Redis-based stop functionality in GraphEngine.
-
-Tests the integration of Redis command channel for stopping workflows
-without user permission checks.
-"""
-
-import json
-from unittest.mock import MagicMock, Mock, patch
-
-import pytest
-import redis
-
-from core.app.apps.base_app_queue_manager import AppQueueManager
-from graphon.graph_engine.command_channels.redis_channel import RedisChannel
-from graphon.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
-from graphon.graph_engine.manager import GraphEngineManager
-
-
-class TestRedisStopIntegration:
-    """Test suite for Redis-based workflow stop functionality."""
-
-    def test_graph_engine_manager_sends_abort_command(self):
-        """Test that GraphEngineManager correctly sends abort command through Redis."""
-        # Setup
-        task_id = "test-task-123"
-        expected_channel_key = f"workflow:{task_id}:commands"
-
-        # Mock redis client
-        mock_redis = MagicMock()
-        mock_pipeline = MagicMock()
-        mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
-        mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
-
-        manager = GraphEngineManager(mock_redis)
-
-        # Execute
-        manager.send_stop_command(task_id, reason="Test stop")
-
-        # Verify
-        mock_redis.pipeline.assert_called_once()
-
-        # Check that rpush was called with correct arguments
-        calls = mock_pipeline.rpush.call_args_list
-        assert len(calls) == 1
-
-        # Verify the channel key
-        assert calls[0][0][0] == expected_channel_key
-
-        # Verify the command data
-        command_json = calls[0][0][1]
-        command_data = json.loads(command_json)
-        assert command_data["command_type"] == CommandType.ABORT
-        assert command_data["reason"] == "Test stop"
-
-    def test_graph_engine_manager_sends_pause_command(self):
-        """Test that GraphEngineManager correctly sends pause command through Redis."""
-        task_id = "test-task-pause-123"
-        expected_channel_key = f"workflow:{task_id}:commands"
-
-        mock_redis = MagicMock()
-        mock_pipeline = MagicMock()
-        mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
-        mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
-
-        manager = GraphEngineManager(mock_redis)
-        manager.send_pause_command(task_id, reason="Awaiting resources")
-
-        mock_redis.pipeline.assert_called_once()
-        calls = mock_pipeline.rpush.call_args_list
-        assert len(calls) == 1
-        assert calls[0][0][0] == expected_channel_key
-
-        command_json = calls[0][0][1]
-        command_data = json.loads(command_json)
-        assert command_data["command_type"] == CommandType.PAUSE.value
-        assert command_data["reason"] == "Awaiting resources"
-
-    def test_graph_engine_manager_handles_redis_failure_gracefully(self):
-        """Test that GraphEngineManager handles Redis failures without raising exceptions."""
-        task_id = "test-task-456"
-
-        # Mock redis client to raise exception
-        mock_redis = MagicMock()
-        mock_redis.pipeline.side_effect = redis.ConnectionError("Redis connection failed")
-        manager = GraphEngineManager(mock_redis)
-
-        # Should not raise exception
-        try:
-            manager.send_stop_command(task_id)
-        except Exception as e:
-            pytest.fail(f"GraphEngineManager.send_stop_command raised {e} unexpectedly")
-
-    def test_app_queue_manager_no_user_check(self):
-        """Test that AppQueueManager.set_stop_flag_no_user_check works without user validation."""
-        task_id = "test-task-789"
-        expected_cache_key = f"generate_task_stopped:{task_id}"
-
-        # Mock redis client
-        mock_redis = MagicMock()
-
-        with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis):
-            # Execute
-            AppQueueManager.set_stop_flag_no_user_check(task_id)
-
-            # Verify
-            mock_redis.setex.assert_called_once_with(expected_cache_key, 600, 1)
-
-    def test_app_queue_manager_no_user_check_with_empty_task_id(self):
-        """Test that AppQueueManager.set_stop_flag_no_user_check handles empty task_id."""
-        # Mock redis client
-        mock_redis = MagicMock()
-
-        with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis):
-            # Execute with empty task_id
-            AppQueueManager.set_stop_flag_no_user_check("")
-
-            # Verify redis was not called
-            mock_redis.setex.assert_not_called()
-
-    def test_redis_channel_send_abort_command(self):
-        """Test RedisChannel correctly serializes and sends AbortCommand."""
-        # Setup
-        mock_redis = MagicMock()
-        mock_pipeline = MagicMock()
-        mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
-        mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
-
-        channel_key = "workflow:test:commands"
-        channel = RedisChannel(mock_redis, channel_key)
-
-        # Create commands
-        abort_command = AbortCommand(reason="User requested stop")
-        pause_command = PauseCommand(reason="User requested pause")
-
-        # Execute
-        channel.send_command(abort_command)
-        channel.send_command(pause_command)
-
-        # Verify
-        mock_redis.pipeline.assert_called()
-
-        # Check rpush was called
-        calls = mock_pipeline.rpush.call_args_list
-        assert len(calls) == 2
-        assert calls[0][0][0] == channel_key
-        assert calls[1][0][0] == channel_key
-
-        # Verify serialized commands
-        abort_command_json = calls[0][0][1]
-        abort_command_data = json.loads(abort_command_json)
-        assert abort_command_data["command_type"] == CommandType.ABORT.value
-        assert abort_command_data["reason"] == "User requested stop"
-
-        pause_command_json = calls[1][0][1]
-        pause_command_data = json.loads(pause_command_json)
-        assert pause_command_data["command_type"] == CommandType.PAUSE.value
-        assert pause_command_data["reason"] == "User requested pause"
-
-        # Check expire was set for each
-        assert mock_pipeline.expire.call_count == 2
-        mock_pipeline.expire.assert_any_call(channel_key, 3600)
-
-    def test_redis_channel_fetch_commands(self):
-        """Test RedisChannel correctly fetches and deserializes commands."""
-        # Setup
-        mock_redis = MagicMock()
-        pending_pipe = MagicMock()
-        fetch_pipe = MagicMock()
-        pending_context = MagicMock()
-        fetch_context = MagicMock()
-        pending_context.__enter__.return_value = pending_pipe
-        pending_context.__exit__.return_value = None
-        fetch_context.__enter__.return_value = fetch_pipe
-        fetch_context.__exit__.return_value = None
-        mock_redis.pipeline.side_effect = [pending_context, fetch_context]
-
-        # Mock command data
-        abort_command_json = json.dumps(
-            {"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None}
-        )
-        pause_command_json = json.dumps(
-            {"command_type": CommandType.PAUSE.value, "reason": "Pause requested", "payload": None}
-        )
-
-        # Mock pipeline execute to return commands
-        pending_pipe.execute.return_value = [b"1", 1]
-        fetch_pipe.execute.return_value = [
-            [abort_command_json.encode(), pause_command_json.encode()],  # lrange result
-            True,  # delete result
-        ]
-
-        channel_key = "workflow:test:commands"
-        channel = RedisChannel(mock_redis, channel_key)
-
-        # Execute
-        commands = channel.fetch_commands()
-
-        # Verify
-        assert len(commands) == 2
-        assert isinstance(commands[0], AbortCommand)
-        assert commands[0].command_type == CommandType.ABORT
-        assert commands[0].reason == "Test abort"
-        assert isinstance(commands[1], PauseCommand)
-        assert commands[1].command_type == CommandType.PAUSE
-        assert commands[1].reason == "Pause requested"
-
-        # Verify Redis operations
-        pending_pipe.get.assert_called_once_with(f"{channel_key}:pending")
-        pending_pipe.delete.assert_called_once_with(f"{channel_key}:pending")
-        fetch_pipe.lrange.assert_called_once_with(channel_key, 0, -1)
-        fetch_pipe.delete.assert_called_once_with(channel_key)
-        assert mock_redis.pipeline.call_count == 2
-
-    def test_redis_channel_fetch_commands_handles_invalid_json(self):
-        """Test RedisChannel gracefully handles invalid JSON in commands."""
-        # Setup
-        mock_redis = MagicMock()
-        pending_pipe = MagicMock()
-        fetch_pipe = MagicMock()
-        pending_context = MagicMock()
-        fetch_context = MagicMock()
-        pending_context.__enter__.return_value = pending_pipe
-        pending_context.__exit__.return_value = None
-        fetch_context.__enter__.return_value = fetch_pipe
-        fetch_context.__exit__.return_value = None
-        mock_redis.pipeline.side_effect = [pending_context, fetch_context]
-
-        # Mock invalid command data
-        pending_pipe.execute.return_value = [b"1", 1]
-        fetch_pipe.execute.return_value = [
-            [b"invalid json", b'{"command_type": "invalid_type"}'],  # lrange result
-            True,  # delete result
-        ]
-
-        channel_key = "workflow:test:commands"
-        channel = RedisChannel(mock_redis, channel_key)
-
-        # Execute
-        commands = channel.fetch_commands()
-
-        # Should return empty list due to invalid commands
-        assert len(commands) == 0
-
-    def test_dual_stop_mechanism_compatibility(self):
-        """Test that both stop mechanisms can work together."""
-        task_id = "test-task-dual"
-
-        # Mock redis client
-        mock_redis = MagicMock()
-        mock_pipeline = MagicMock()
-        mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
-        mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
-
-        with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis):
-            # Execute both stop mechanisms
-            AppQueueManager.set_stop_flag_no_user_check(task_id)
-            GraphEngineManager(mock_redis).send_stop_command(task_id)
-
-            # Verify legacy stop flag was set
-            expected_stop_flag_key = f"generate_task_stopped:{task_id}"
-            mock_redis.setex.assert_called_once_with(expected_stop_flag_key, 600, 1)
-
-            # Verify command was sent through Redis channel
-            mock_redis.pipeline.assert_called()
-            calls = mock_pipeline.rpush.call_args_list
-            assert len(calls) == 1
-            assert calls[0][0][0] == f"workflow:{task_id}:commands"
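The deleted suite pins down the wire protocol of the Redis command channel: sending RPUSHes a JSON command onto workflow:{task_id}:commands and sets a TTL, while fetching LRANGEs the whole list, deletes it, and drops undecodable entries (the real channel also maintains a {key}:pending marker, omitted here). A schematic reimplementation under those assumptions, not the RedisChannel source:

# Sketch of the send/fetch protocol the tests above assert, written against a
# generic redis.Redis client; key naming follows workflow:{task_id}:commands.
import json

import redis


def send_command(client: redis.Redis, task_id: str, command: dict) -> None:
    key = f"workflow:{task_id}:commands"
    with client.pipeline() as pipe:
        pipe.rpush(key, json.dumps(command))  # queue the command
        pipe.expire(key, 3600)  # don't let orphaned queues linger
        pipe.execute()


def fetch_commands(client: redis.Redis, task_id: str) -> list[dict]:
    key = f"workflow:{task_id}:commands"
    with client.pipeline() as pipe:
        pipe.lrange(key, 0, -1)  # read the whole queue
        pipe.delete(key)  # consume it in the same pipeline as the read
        raw, _ = pipe.execute()
    commands = []
    for item in raw:
        try:
            commands.append(json.loads(item))
        except json.JSONDecodeError:
            continue  # invalid payloads are dropped, as the tests assert
    return commands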
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py
deleted file mode 100644
index 62ca7a630e3..00000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py
+++ /dev/null
@@ -1,55 +0,0 @@
-"""Unit tests for response session creation."""
-
-from __future__ import annotations
-
-import pytest
-
-from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType
-from graphon.graph_engine.response_coordinator.session import ResponseSession
-from graphon.nodes.base.template import Template, TextSegment
-
-
-class DummyResponseNode:
-    """Minimal response-capable node for session tests."""
-
-    def __init__(self, *, node_id: str, node_type: NodeType, template: Template) -> None:
-        self.id = node_id
-        self.node_type = node_type
-        self.execution_type = NodeExecutionType.RESPONSE
-        self.state = NodeState.UNKNOWN
-        self._template = template
-
-    def get_streaming_template(self) -> Template:
-        return self._template
-
-
-class DummyNodeWithoutStreamingTemplate:
-    """Minimal node that violates the response-session contract."""
-
-    def __init__(self, *, node_id: str, node_type: NodeType) -> None:
-        self.id = node_id
-        self.node_type = node_type
-        self.execution_type = NodeExecutionType.RESPONSE
-        self.state = NodeState.UNKNOWN
-
-
-def test_response_session_from_node_accepts_nodes_outside_previous_allowlist() -> None:
-    """Session creation depends on the streaming-template contract rather than node type."""
-    node = DummyResponseNode(
-        node_id="llm-node",
-        node_type=BuiltinNodeTypes.LLM,
-        template=Template(segments=[TextSegment(text="hello")]),
-    )
-
-    session = ResponseSession.from_node(node)
-
-    assert session.node_id == "llm-node"
-    assert session.template.segments == [TextSegment(text="hello")]
-
-
-def test_response_session_from_node_requires_streaming_template_method() -> None:
-    """Allowed node types still need to implement the streaming-template contract."""
-    node = DummyNodeWithoutStreamingTemplate(node_id="answer-node", node_type=BuiltinNodeTypes.ANSWER)
-
-    with pytest.raises(TypeError, match="get_streaming_template"):
-        ResponseSession.from_node(node)
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py
deleted file mode 100644
index a359a5fef98..00000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from graphon.graph_events import (
-    GraphRunStartedEvent,
-    GraphRunSucceededEvent,
-    NodeRunStartedEvent,
-    NodeRunStreamChunkEvent,
-    NodeRunSucceededEvent,
-    NodeRunVariableUpdatedEvent,
-)
-
-from .test_mock_config import MockConfigBuilder
-from .test_table_runner import TableTestRunner, WorkflowTestCase
-
-
-def test_streaming_conversation_variables():
-    fixture_name = "test_streaming_conversation_variables"
-
-    # The test expects the workflow to output the input query
-    # Since the workflow assigns sys.query to conversation variable "str" and then answers with it
-    input_query = "Hello, this is my test query"
-
-    mock_config = MockConfigBuilder().build()
-
-    case = WorkflowTestCase(
-        fixture_path=fixture_name,
-        use_auto_mock=False,  # Don't use auto mock since we want to test actual variable assignment
-        mock_config=mock_config,
-        query=input_query,  # Pass query as the sys.query value
-        inputs={},  # No additional inputs needed
-        expected_outputs={"answer": input_query},  # Expecting the input query to be output
-        expected_event_sequence=[
-            GraphRunStartedEvent,
-            # START node
-            NodeRunStartedEvent,
-            NodeRunSucceededEvent,
-            # Variable Assigner node
-            NodeRunStartedEvent,
-            NodeRunVariableUpdatedEvent,
-            NodeRunStreamChunkEvent,
-            NodeRunSucceededEvent,
-            # ANSWER node
-            NodeRunStartedEvent,
-            NodeRunSucceededEvent,
-            GraphRunSucceededEvent,
-        ],
-    )
-
-    runner = TableTestRunner()
-    result = runner.run_test_case(case)
-    assert result.success, f"Test failed: {result.error}"
-
-
-def test_streaming_conversation_variables_v1_overwrite_waits_for_assignment():
-    fixture_name = "test_streaming_conversation_variables_v1_overwrite"
-    input_query = "overwrite-value"
-
-    case = WorkflowTestCase(
-        fixture_path=fixture_name,
-        use_auto_mock=False,
-        mock_config=MockConfigBuilder().build(),
-        query=input_query,
-        inputs={},
-        expected_outputs={"answer": f"Current Value Of `conv_var` is:{input_query}"},
-    )
-
-    runner = TableTestRunner()
-    result = runner.run_test_case(case)
-    assert result.success, f"Test failed: {result.error}"
-
-    events = result.events
-    conv_var_chunk_events = [
-        event
-        for event in events
-        if isinstance(event, NodeRunStreamChunkEvent) and tuple(event.selector) == ("conversation", "conv_var")
-    ]
-
-    assert conv_var_chunk_events, "Expected conversation variable chunk events to be emitted"
-    assert all(event.chunk == input_query for event in conv_var_chunk_events), (
-        "Expected streamed conversation variable value to match the input query"
-    )
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py
index 81d68ba2aac..b11f9576777 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py
@@ -19,12 +19,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any
 
-from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
-from core.tools.utils.yaml_utils import _load_yaml_file
-from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
-from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
-from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
-from graphon.entities.graph_init_params import GraphInitParams
+from graphon.entities import GraphInitParams
 from graphon.graph import Graph
 from graphon.graph_engine import GraphEngine, GraphEngineConfig
 from graphon.graph_engine.command_channels import InMemoryChannel
@@ -44,6 +39,12 @@ from graphon.variables import (
     StringVariable,
 )
 
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
+from core.tools.utils.yaml_utils import _load_yaml_file
+from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
+from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
+from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
+
 from .test_mock_config import MockConfig
 from .test_mock_factory import MockNodeFactory
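This hunk applies the same import reordering as the test_parallel_human_input_join_resume.py hunk earlier in the patch: third-party graphon imports now precede first-party core/libs modules, with local test helpers last. Schematically, the layout the patch converges on (illustrative comments only; module names are taken from the hunk above):

# 1) standard library
#    from pathlib import Path
# 2) third-party packages, including graphon
#    from graphon.entities import GraphInitParams
# 3) first-party application code
#    from core.workflow.node_factory import DifyNodeFactory
# 4) local test helpers
#    from .test_mock_config import MockConfig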
- -Code nodes in the fixture are mocked because their concrete outputs are not -relevant to verifying variable propagation semantics. -""" - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_update_conversation_variable_in_iteration(): - fixture_name = "update-conversation-variable-in-iteration" - user_query = "ensure conversation variable syncs" - - mock_config = ( - MockConfigBuilder() - .with_node_output("1759032363865", {"result": [1]}) - .with_node_output("1759032476318", {"result": ""}) - .build() - ) - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - query=user_query, - expected_outputs={"answer": user_query}, - description="Conversation variable updated within iteration should flow to answer output.", - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - - assert result.success, f"Workflow execution failed: {result.error}" - assert result.actual_outputs is not None - assert result.actual_outputs.get("answer") == user_query diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py deleted file mode 100644 index 2ad41037a99..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py +++ /dev/null @@ -1,58 +0,0 @@ -from unittest.mock import patch - -import pytest - -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode - -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -class TestVariableAggregator: - """Test cases for the variable aggregator workflow.""" - - @pytest.mark.parametrize( - ("switch1", "switch2", "expected_group1", "expected_group2", "description"), - [ - (0, 0, "switch 1 off", "switch 2 off", "Both switches off"), - (0, 1, "switch 1 off", "switch 2 on", "Switch1 off, Switch2 on"), - (1, 0, "switch 1 on", "switch 2 off", "Switch1 on, Switch2 off"), - (1, 1, "switch 1 on", "switch 2 on", "Both switches on"), - ], - ) - def test_variable_aggregator_combinations( - self, - switch1: int, - switch2: int, - expected_group1: str, - expected_group2: str, - description: str, - ) -> None: - """Test all four combinations of switch1 and switch2.""" - - def mock_template_transform_run(self): - """Mock the TemplateTransformNode._run() method to return results based on node title.""" - title = self._node_data.title - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title}) - - with patch.object( - TemplateTransformNode, - "_run", - mock_template_transform_run, - ): - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="dual_switch_variable_aggregator_workflow", - inputs={"switch1": switch1, "switch2": switch2}, - expected_outputs={"group1": expected_group1, "group2": expected_group2}, - description=description, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs == test_case.expected_outputs, ( - f"Output mismatch: expected {test_case.expected_outputs}, got {result.actual_outputs}" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py deleted file mode 
100644 index 60cab77c0a2..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py +++ /dev/null @@ -1,129 +0,0 @@ -import time -import uuid -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool -from graphon.entities import GraphInitParams -from graphon.graph import Graph -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events import NodeRunVariableUpdatedEvent -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import StringVariable - -DEFAULT_NODE_ID = "node_id" - - -class CaptureVariableUpdateLayer(GraphEngineLayer): - def __init__(self) -> None: - super().__init__() - self.events: list[NodeRunVariableUpdatedEvent] = [] - self.observed_values: list[object | None] = [] - - def on_graph_start(self) -> None: - pass - - def on_event(self, event) -> None: - if not isinstance(event, NodeRunVariableUpdatedEvent): - return - - current_value = self.graph_runtime_state.variable_pool.get(event.variable.selector) - self.events.append(event) - self.observed_values.append(None if current_value is None else current_value.value) - - def on_graph_end(self, error: Exception | None) -> None: - pass - - -def test_graph_engine_applies_variable_updates_before_notifying_layers(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "over-write", - "input_variable_selector": ["node_id", "test_string_variable"], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - variable_pool = VariablePool() - add_variables_to_pool( - variable_pool, - build_bootstrap_variables( - system_variables=build_system_variables(conversation_id=str(uuid.uuid4())), - conversation_variables=[ - StringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value="the first value", - ) - ], - ), - ) - variable_pool.add( - [DEFAULT_NODE_ID, "test_string_variable"], - StringVariable( - id=str(uuid4()), - name="test_string_variable", - value="the second value", - ), - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - engine = GraphEngine( - workflow_id="workflow-id", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - capture_layer = CaptureVariableUpdateLayer() - engine.layer(capture_layer) - - 
events = list(engine.run()) - - update_events = [event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)] - assert len(update_events) == 1 - assert update_events[0].variable.value == "the second value" - - current_value = graph_runtime_state.variable_pool.get(["conversation", "test_conversation_variable"]) - assert current_value is not None - assert current_value.value == "the second value" - - assert len(capture_layer.events) == 1 - assert capture_layer.observed_values == ["the second value"] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py deleted file mode 100644 index 85132674b87..00000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py +++ /dev/null @@ -1,148 +0,0 @@ -import queue -from collections.abc import Generator -from datetime import UTC, datetime, timedelta -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.graph_engine.ready_queue import InMemoryReadyQueue -from graphon.graph_engine.worker import Worker -from graphon.graph_events import NodeRunFailedEvent, NodeRunStartedEvent - - -def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None: - fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - mock_datetime = mocker.patch("graphon.graph_engine.worker.datetime") - mock_datetime.now.return_value = fixed_time.replace(tzinfo=UTC) - - worker = Worker( - ready_queue=InMemoryReadyQueue(), - event_queue=queue.Queue(), - graph=MagicMock(), - layers=[], - ) - node = SimpleNamespace( - execution_id="exec-1", - id="node-1", - node_type=BuiltinNodeTypes.LLM, - ) - - event = worker._build_fallback_failure_event(node, RuntimeError("boom")) - - assert event.start_at == fixed_time - assert event.finished_at == fixed_time - assert event.error == "boom" - assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED - assert event.node_run_result.error == "boom" - assert event.node_run_result.error_type == "RuntimeError" - - -def test_worker_fallback_failure_event_reuses_observed_start_time() -> None: - start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - failure_time = start_at + timedelta(seconds=5) - captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] - - class FakeNode: - execution_id = "exec-1" - id = "node-1" - node_type = BuiltinNodeTypes.LLM - - def ensure_execution_id(self) -> str: - return self.execution_id - - def run(self) -> Generator[NodeRunStartedEvent, None, None]: - yield NodeRunStartedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - node_title="LLM", - start_at=start_at, - ) - - worker = Worker( - ready_queue=MagicMock(), - event_queue=MagicMock(), - graph=MagicMock(nodes={"node-1": FakeNode()}), - layers=[], - ) - - worker._ready_queue.get.side_effect = ["node-1"] - - def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: - captured_events.append(event) - if len(captured_events) == 1: - raise RuntimeError("queue boom") - worker.stop() - - worker._event_queue.put.side_effect = put_side_effect - - with patch("graphon.graph_engine.worker.datetime") as mock_datetime: - mock_datetime.now.return_value = failure_time.replace(tzinfo=UTC) - worker.run() - - fallback_event = captured_events[-1] - - assert isinstance(fallback_event, NodeRunFailedEvent) - assert 
fallback_event.start_at == start_at - assert fallback_event.finished_at == failure_time - assert fallback_event.error == "queue boom" - assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED - - -def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None: - parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - child_start = parent_start + timedelta(seconds=3) - failure_time = parent_start + timedelta(seconds=5) - captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] - - class FakeIterationNode: - execution_id = "iteration-exec" - id = "iteration-node" - node_type = BuiltinNodeTypes.ITERATION - - def ensure_execution_id(self) -> str: - return self.execution_id - - def run(self) -> Generator[NodeRunStartedEvent, None, None]: - yield NodeRunStartedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - node_title="Iteration", - start_at=parent_start, - ) - yield NodeRunStartedEvent( - id="child-exec", - node_id="child-node", - node_type=BuiltinNodeTypes.LLM, - node_title="LLM", - start_at=child_start, - in_iteration_id=self.id, - ) - - worker = Worker( - ready_queue=MagicMock(), - event_queue=MagicMock(), - graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}), - layers=[], - ) - - worker._ready_queue.get.side_effect = ["iteration-node"] - - def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: - captured_events.append(event) - if len(captured_events) == 2: - raise RuntimeError("queue boom") - worker.stop() - - worker._event_queue.put.side_effect = put_side_effect - - with patch("graphon.graph_engine.worker.datetime") as mock_datetime: - mock_datetime.now.return_value = failure_time.replace(tzinfo=UTC) - worker.run() - - fallback_event = captured_events[-1] - - assert isinstance(fallback_event, NodeRunFailedEvent) - assert fallback_event.start_at == parent_start - assert fallback_event.finished_at == failure_time diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py index 1f4509af9a1..cbc920705ca 100644 --- a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py @@ -1,8 +1,9 @@ from unittest.mock import patch +from graphon.enums import BuiltinNodeTypes + from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer -from graphon.enums import BuiltinNodeTypes def test_transform_passes_conversation_id_to_tool_file_message_transformer() -> None: diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py index c86de7f6e63..59dd763b59d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py @@ -1,9 +1,10 @@ from types import SimpleNamespace from unittest.mock import Mock, patch -from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport from graphon.model_runtime.entities.model_entities import ModelType +from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport + def test_fetch_model_reuses_single_model_assembly(): provider_configuration = SimpleNamespace( diff --git 
a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 9c0ad25b58e..7195471eb6b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,14 +2,15 @@ import time import uuid from unittest.mock import MagicMock +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.answer.answer_node import AnswerNode +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables from extensions.ext_database import db -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.nodes.answer.answer_node import AnswerNode -from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index ec4cef1955a..343bcd39193 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,10 +1,10 @@ import pytest - -from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node +from core.workflow.node_factory import get_node_type_classes_mapping + # Ensures that all production node classes are imported and registered. 
_ = get_node_type_classes_mapping() diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index ef0df55995e..b9371a34f44 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,7 +1,6 @@ import types from collections.abc import Mapping -from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node @@ -14,6 +13,8 @@ from graphon.nodes.variable_assigner.v2.node import ( VariableAssignerNode as VariableAssignerV2, ) +from core.workflow.node_factory import get_node_type_classes_mapping + def test_variable_assigner_latest_prefers_highest_numeric_version(): # Act diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index ce0c9b79c68..d155124c501 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,4 +1,3 @@ -from configs import dify_config from graphon.nodes.code.code_node import CodeNode from graphon.nodes.code.entities import CodeLanguage, CodeNodeData from graphon.nodes.code.exc import ( @@ -9,6 +8,8 @@ from graphon.nodes.code.exc import ( from graphon.nodes.code.limits import CodeNodeLimits from graphon.variables.types import SegmentType +from configs import dify_config + CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, max_number=dify_config.CODE_MAX_NUMBER, diff --git a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py deleted file mode 100644 index 20fe2c1a747..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py +++ /dev/null @@ -1,352 +0,0 @@ -import pytest -from pydantic import ValidationError - -from graphon.nodes.code.entities import CodeLanguage, CodeNodeData -from graphon.variables.types import SegmentType - - -class TestCodeNodeDataOutput: - """Test suite for CodeNodeData.Output model.""" - - def test_output_with_string_type(self): - """Test Output with STRING type.""" - output = CodeNodeData.Output(type=SegmentType.STRING) - - assert output.type == SegmentType.STRING - assert output.children is None - - def test_output_with_number_type(self): - """Test Output with NUMBER type.""" - output = CodeNodeData.Output(type=SegmentType.NUMBER) - - assert output.type == SegmentType.NUMBER - assert output.children is None - - def test_output_with_boolean_type(self): - """Test Output with BOOLEAN type.""" - output = CodeNodeData.Output(type=SegmentType.BOOLEAN) - - assert output.type == SegmentType.BOOLEAN - - def test_output_with_object_type(self): - """Test Output with OBJECT type.""" - output = CodeNodeData.Output(type=SegmentType.OBJECT) - - assert output.type == SegmentType.OBJECT - - def test_output_with_array_string_type(self): - """Test Output with ARRAY_STRING type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_STRING) - - assert output.type == SegmentType.ARRAY_STRING - - def test_output_with_array_number_type(self): - """Test Output with ARRAY_NUMBER type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER) - - 
assert output.type == SegmentType.ARRAY_NUMBER - - def test_output_with_array_object_type(self): - """Test Output with ARRAY_OBJECT type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_OBJECT) - - assert output.type == SegmentType.ARRAY_OBJECT - - def test_output_with_array_boolean_type(self): - """Test Output with ARRAY_BOOLEAN type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN) - - assert output.type == SegmentType.ARRAY_BOOLEAN - - def test_output_with_nested_children(self): - """Test Output with nested children for OBJECT type.""" - child_output = CodeNodeData.Output(type=SegmentType.STRING) - parent_output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"name": child_output}, - ) - - assert parent_output.type == SegmentType.OBJECT - assert parent_output.children is not None - assert "name" in parent_output.children - assert parent_output.children["name"].type == SegmentType.STRING - - def test_output_with_deeply_nested_children(self): - """Test Output with deeply nested children.""" - inner_child = CodeNodeData.Output(type=SegmentType.NUMBER) - middle_child = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"value": inner_child}, - ) - outer_output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"nested": middle_child}, - ) - - assert outer_output.children is not None - assert outer_output.children["nested"].children is not None - assert outer_output.children["nested"].children["value"].type == SegmentType.NUMBER - - def test_output_with_multiple_children(self): - """Test Output with multiple children.""" - output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "age": CodeNodeData.Output(type=SegmentType.NUMBER), - "active": CodeNodeData.Output(type=SegmentType.BOOLEAN), - }, - ) - - assert output.children is not None - assert len(output.children) == 3 - assert output.children["name"].type == SegmentType.STRING - assert output.children["age"].type == SegmentType.NUMBER - assert output.children["active"].type == SegmentType.BOOLEAN - - def test_output_rejects_invalid_type(self): - """Test Output rejects invalid segment types.""" - with pytest.raises(ValidationError): - CodeNodeData.Output(type=SegmentType.FILE) - - def test_output_rejects_array_file_type(self): - """Test Output rejects ARRAY_FILE type.""" - with pytest.raises(ValidationError): - CodeNodeData.Output(type=SegmentType.ARRAY_FILE) - - -class TestCodeNodeDataDependency: - """Test suite for CodeNodeData.Dependency model.""" - - def test_dependency_basic(self): - """Test Dependency with name and version.""" - dependency = CodeNodeData.Dependency(name="numpy", version="1.24.0") - - assert dependency.name == "numpy" - assert dependency.version == "1.24.0" - - def test_dependency_with_complex_version(self): - """Test Dependency with complex version string.""" - dependency = CodeNodeData.Dependency(name="pandas", version=">=2.0.0,<3.0.0") - - assert dependency.name == "pandas" - assert dependency.version == ">=2.0.0,<3.0.0" - - def test_dependency_with_empty_version(self): - """Test Dependency with empty version.""" - dependency = CodeNodeData.Dependency(name="requests", version="") - - assert dependency.name == "requests" - assert dependency.version == "" - - -class TestCodeNodeData: - """Test suite for CodeNodeData model.""" - - def test_code_node_data_python3(self): - """Test CodeNodeData with Python3 language.""" - data = CodeNodeData( - title="Test Code Node", - variables=[], - 
code_language=CodeLanguage.PYTHON3, - code="def main(): return {'result': 42}", - outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)}, - ) - - assert data.title == "Test Code Node" - assert data.code_language == CodeLanguage.PYTHON3 - assert data.code == "def main(): return {'result': 42}" - assert "result" in data.outputs - assert data.dependencies is None - - def test_code_node_data_javascript(self): - """Test CodeNodeData with JavaScript language.""" - data = CodeNodeData( - title="JS Code Node", - variables=[], - code_language=CodeLanguage.JAVASCRIPT, - code="function main() { return { result: 'hello' }; }", - outputs={"result": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert data.code_language == CodeLanguage.JAVASCRIPT - assert "result" in data.outputs - assert data.outputs["result"].type == SegmentType.STRING - - def test_code_node_data_with_dependencies(self): - """Test CodeNodeData with dependencies.""" - data = CodeNodeData( - title="Code with Deps", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="import numpy as np\ndef main(): return {'sum': 10}", - outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)}, - dependencies=[ - CodeNodeData.Dependency(name="numpy", version="1.24.0"), - CodeNodeData.Dependency(name="pandas", version="2.0.0"), - ], - ) - - assert data.dependencies is not None - assert len(data.dependencies) == 2 - assert data.dependencies[0].name == "numpy" - assert data.dependencies[1].name == "pandas" - - def test_code_node_data_with_multiple_outputs(self): - """Test CodeNodeData with multiple outputs.""" - data = CodeNodeData( - title="Multi Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'name': 'test', 'count': 5, 'items': ['a', 'b']}", - outputs={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "count": CodeNodeData.Output(type=SegmentType.NUMBER), - "items": CodeNodeData.Output(type=SegmentType.ARRAY_STRING), - }, - ) - - assert len(data.outputs) == 3 - assert data.outputs["name"].type == SegmentType.STRING - assert data.outputs["count"].type == SegmentType.NUMBER - assert data.outputs["items"].type == SegmentType.ARRAY_STRING - - def test_code_node_data_with_object_output(self): - """Test CodeNodeData with nested object output.""" - data = CodeNodeData( - title="Object Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'user': {'name': 'John', 'age': 30}}", - outputs={ - "user": CodeNodeData.Output( - type=SegmentType.OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "age": CodeNodeData.Output(type=SegmentType.NUMBER), - }, - ), - }, - ) - - assert data.outputs["user"].type == SegmentType.OBJECT - assert data.outputs["user"].children is not None - assert len(data.outputs["user"].children) == 2 - - def test_code_node_data_with_array_object_output(self): - """Test CodeNodeData with array of objects output.""" - data = CodeNodeData( - title="Array Object Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'users': [{'name': 'A'}, {'name': 'B'}]}", - outputs={ - "users": CodeNodeData.Output( - type=SegmentType.ARRAY_OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - }, - ), - }, - ) - - assert data.outputs["users"].type == SegmentType.ARRAY_OBJECT - assert data.outputs["users"].children is not None - - def test_code_node_data_empty_code(self): - """Test CodeNodeData with empty code.""" - data = CodeNodeData( - title="Empty Code", - 
variables=[], - code_language=CodeLanguage.PYTHON3, - code="", - outputs={}, - ) - - assert data.code == "" - assert len(data.outputs) == 0 - - def test_code_node_data_multiline_code(self): - """Test CodeNodeData with multiline code.""" - multiline_code = """ -def main(): - result = 0 - for i in range(10): - result += i - return {'sum': result} -""" - data = CodeNodeData( - title="Multiline Code", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=multiline_code, - outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)}, - ) - - assert "for i in range(10)" in data.code - assert "result += i" in data.code - - def test_code_node_data_with_special_characters_in_code(self): - """Test CodeNodeData with special characters in code.""" - code_with_special = "def main(): return {'msg': 'Hello\\nWorld\\t!'}" - data = CodeNodeData( - title="Special Chars", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=code_with_special, - outputs={"msg": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert "\\n" in data.code - assert "\\t" in data.code - - def test_code_node_data_with_unicode_in_code(self): - """Test CodeNodeData with unicode characters in code.""" - unicode_code = "def main(): return {'greeting': 'ไฝ ๅฅฝไธ–็•Œ'}" - data = CodeNodeData( - title="Unicode Code", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=unicode_code, - outputs={"greeting": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert "ไฝ ๅฅฝไธ–็•Œ" in data.code - - def test_code_node_data_empty_dependencies_list(self): - """Test CodeNodeData with empty dependencies list.""" - data = CodeNodeData( - title="No Deps", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {}", - outputs={}, - dependencies=[], - ) - - assert data.dependencies is not None - assert len(data.dependencies) == 0 - - def test_code_node_data_with_boolean_array_output(self): - """Test CodeNodeData with boolean array output.""" - data = CodeNodeData( - title="Boolean Array", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'flags': [True, False, True]}", - outputs={"flags": CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN)}, - ) - - assert data.outputs["flags"].type == SegmentType.ARRAY_BOOLEAN - - def test_code_node_data_with_number_array_output(self): - """Test CodeNodeData with number array output.""" - data = CodeNodeData( - title="Number Array", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'values': [1, 2, 3, 4, 5]}", - outputs={"values": CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER)}, - ) - - assert data.outputs["values"].type == SegmentType.ARRAY_NUMBER diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 1d76067ec2b..fb03ae9998d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,7 +1,8 @@ +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent + from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent class _VarSeg: diff 
--git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py deleted file mode 100644 index f1a48f49b92..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py +++ /dev/null @@ -1,33 +0,0 @@ -from graphon.nodes.http_request import build_http_request_config - - -def test_build_http_request_config_uses_literal_defaults(): - config = build_http_request_config() - - assert config.max_connect_timeout == 10 - assert config.max_read_timeout == 600 - assert config.max_write_timeout == 600 - assert config.max_binary_size == 10 * 1024 * 1024 - assert config.max_text_size == 1 * 1024 * 1024 - assert config.ssl_verify is True - assert config.ssrf_default_max_retries == 3 - - -def test_build_http_request_config_supports_explicit_overrides(): - config = build_http_request_config( - max_connect_timeout=5, - max_read_timeout=30, - max_write_timeout=40, - max_binary_size=2048, - max_text_size=1024, - ssl_verify=False, - ssrf_default_max_retries=8, - ) - - assert config.max_connect_timeout == 5 - assert config.max_read_timeout == 30 - assert config.max_write_timeout == 40 - assert config.max_binary_size == 2048 - assert config.max_text_size == 1024 - assert config.ssl_verify is False - assert config.ssrf_default_max_retries == 8 diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py deleted file mode 100644 index 88895608d98..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py +++ /dev/null @@ -1,233 +0,0 @@ -import json -from unittest.mock import Mock, PropertyMock, patch - -import httpx -import pytest - -from graphon.nodes.http_request.entities import Response - - -@pytest.fixture -def mock_response(): - response = Mock(spec=httpx.Response) - response.headers = {} - return response - - -def test_is_file_with_attachment_disposition(mock_response): - """Test is_file when content-disposition header contains 'attachment'""" - mock_response.headers = {"content-disposition": "attachment; filename=test.pdf", "content-type": "application/pdf"} - response = Response(mock_response) - assert response.is_file - - -def test_is_file_with_filename_disposition(mock_response): - """Test is_file when content-disposition header contains filename parameter""" - mock_response.headers = {"content-disposition": "inline; filename=test.pdf", "content-type": "application/pdf"} - response = Response(mock_response) - assert response.is_file - - -@pytest.mark.parametrize("content_type", ["application/pdf", "image/jpeg", "audio/mp3", "video/mp4"]) -def test_is_file_with_file_content_types(mock_response, content_type): - """Test is_file with various file content types""" - mock_response.headers = {"content-type": content_type} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert response.is_file, f"Content type {content_type} should be identified as a file" - - -@pytest.mark.parametrize( - "content_type", - [ - "application/json", - "application/xml", - "application/javascript", - "application/x-www-form-urlencoded", - "application/yaml", - "application/graphql", - ], -) -def test_text_based_application_types(mock_response, content_type): - """Test common text-based application types are not identified as files""" - mock_response.headers = {"content-type": content_type} - response = 
Response(mock_response) - assert not response.is_file, f"Content type {content_type} should not be identified as a file" - - -@pytest.mark.parametrize( - ("content", "content_type"), - [ - (b'{"key": "value"}', "application/octet-stream"), - (b"[1, 2, 3]", "application/unknown"), - (b"function test() {}", "application/x-unknown"), - (b"test", "application/binary"), - (b"var x = 1;", "application/data"), - ], -) -def test_content_based_detection(mock_response, content, content_type): - """Test content-based detection for text-like content""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=content) - response = Response(mock_response) - assert not response.is_file, f"Content {content} with type {content_type} should not be identified as a file" - - -@pytest.mark.parametrize( - ("content", "content_type"), - [ - (bytes([0x00, 0xFF] * 512), "application/octet-stream"), - (bytes([0x89, 0x50, 0x4E, 0x47]), "application/unknown"), # PNG magic numbers - (bytes([0xFF, 0xD8, 0xFF]), "application/binary"), # JPEG magic numbers - ], -) -def test_binary_content_detection(mock_response, content, content_type): - """Test content-based detection for binary content""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=content) - response = Response(mock_response) - assert response.is_file, f"Binary content with type {content_type} should be identified as a file" - - -@pytest.mark.parametrize( - ("content_type", "expected_main_type"), - [ - ("x-world/x-vrml", "model"), # VRML 3D model - ("font/ttf", "application"), # TrueType font - ("text/csv", "text"), # CSV text file - ("unknown/xyz", None), # Unknown type - ], -) -def test_mimetype_based_detection(mock_response, content_type, expected_main_type): - """Test detection using mimetypes.guess_type for non-application content types""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=bytes([0x00])) # Dummy content - - with patch("graphon.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: - # Mock the return value based on expected_main_type - if expected_main_type: - mock_guess_type.return_value = (f"{expected_main_type}/subtype", None) - else: - mock_guess_type.return_value = (None, None) - - response = Response(mock_response) - - # Check if the result matches our expectation - if expected_main_type in ("application", "image", "audio", "video"): - assert response.is_file, f"Content type {content_type} should be identified as a file" - else: - assert not response.is_file, f"Content type {content_type} should not be identified as a file" - - # Verify that guess_type was called - mock_guess_type.assert_called_once() - - -def test_is_file_with_inline_disposition(mock_response): - """Test is_file when content-disposition is 'inline'""" - mock_response.headers = {"content-disposition": "inline", "content-type": "application/pdf"} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert response.is_file - - -def test_is_file_with_no_content_disposition(mock_response): - """Test is_file when no content-disposition header is present""" - mock_response.headers = {"content-type": "application/pdf"} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert 
response.is_file - - -# UTF-8 Encoding Tests -@pytest.mark.parametrize( - ("content_bytes", "expected_text", "description"), - [ - # Chinese UTF-8 bytes - ( - b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}', - '{"message": "ไฝ ๅฅฝไธ–็•Œ"}', - "Chinese characters UTF-8", - ), - # Japanese UTF-8 bytes - ( - b'{"message": "\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf"}', - '{"message": "ใ“ใ‚“ใซใกใฏ"}', - "Japanese characters UTF-8", - ), - # Korean UTF-8 bytes - ( - b'{"message": "\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94"}', - '{"message": "์•ˆ๋…•ํ•˜์„ธ์š”"}', - "Korean characters UTF-8", - ), - # Arabic UTF-8 - (b'{"text": "\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7"}', '{"text": "ู…ุฑุญุจุง"}', "Arabic characters UTF-8"), - # European characters UTF-8 - (b'{"text": "Caf\xc3\xa9 M\xc3\xbcnchen"}', '{"text": "Cafรฉ Mรผnchen"}', "European accented characters"), - # Simple ASCII - (b'{"text": "Hello World"}', '{"text": "Hello World"}', "Simple ASCII text"), - ], -) -def test_text_property_utf8_decoding(mock_response, content_bytes, expected_text, description): - """Test that Response.text properly decodes UTF-8 content with charset_normalizer""" - mock_response.headers = {"content-type": "application/json; charset=utf-8"} - type(mock_response).content = PropertyMock(return_value=content_bytes) - # Mock httpx response.text to return something different (simulating potential encoding issues) - mock_response.text = "incorrect-fallback-text" # To ensure we are not falling back to httpx's text property - - response = Response(mock_response) - - # Our enhanced text property should decode properly using charset_normalizer - assert response.text == expected_text, ( - f"Failed for {description}: got {repr(response.text)}, expected {repr(expected_text)}" - ) - - -def test_text_property_fallback_to_httpx(mock_response): - """Test that Response.text falls back to httpx.text when charset_normalizer fails""" - mock_response.headers = {"content-type": "application/json"} - - # Create malformed UTF-8 bytes - malformed_bytes = b'{"text": "\xff\xfe\x00\x00 invalid"}' - type(mock_response).content = PropertyMock(return_value=malformed_bytes) - - # Mock httpx.text to return some fallback value - fallback_text = '{"text": "fallback"}' - mock_response.text = fallback_text - - response = Response(mock_response) - - # Should fall back to httpx's text when charset_normalizer fails - assert response.text == fallback_text - - -@pytest.mark.parametrize( - ("json_content", "description"), - [ - # JSON with escaped Unicode (like Flask jsonify()) - ('{"message": "\\u4f60\\u597d\\u4e16\\u754c"}', "JSON with escaped Unicode"), - # JSON with mixed escape sequences and UTF-8 - ('{"mixed": "Hello \\u4f60\\u597d"}', "Mixed escaped and regular text"), - # JSON with complex escape sequences - ('{"complex": "\\ud83d\\ude00\\u4f60\\u597d"}', "Emoji and Chinese escapes"), - ], -) -def test_text_property_with_escaped_unicode(mock_response, json_content, description): - """Test Response.text with JSON containing Unicode escape sequences""" - mock_response.headers = {"content-type": "application/json"} - - content_bytes = json_content.encode("utf-8") - type(mock_response).content = PropertyMock(return_value=content_bytes) - mock_response.text = json_content # httpx would return the same for valid UTF-8 - - response = Response(mock_response) - - # Should preserve the escape sequences (valid JSON) - assert response.text == json_content, f"Failed for {description}" - - # The text 
should be valid JSON that can be parsed back to proper Unicode - parsed = json.loads(response.text) - assert isinstance(parsed, dict), f"Invalid JSON for {description}" diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index be7cc073dba..a5026b40cf6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,8 +1,4 @@ import pytest - -from configs import dify_config -from core.helper.ssrf_proxy import ssrf_proxy -from core.workflow.system_variables import default_system_variables from graphon.file.file_manager import file_manager from graphon.nodes.http_request import ( BodyData, @@ -16,6 +12,10 @@ from graphon.nodes.http_request.exc import AuthorizationConfigError from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool +from configs import dify_config +from core.helper.ssrf_proxy import ssrf_proxy +from core.workflow.system_variables import default_system_variables + HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index a3cadc06815..4705b3f76ec 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -3,17 +3,17 @@ from typing import Any import httpx import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response -from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index 1d6a4da7c4a..d16e1233ac9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,6 +1,7 @@ -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients from graphon.runtime import VariablePool +from core.workflow.human_input_compat import EmailDeliveryConfig, 
EmailRecipients + def test_render_body_template_replaces_variable_values(): config = EmailDeliveryConfig( diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 5f28a07606d..a2cdbbf132e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -2,14 +2,41 @@ Unit tests for human input node entities. """ +from collections.abc import Mapping +from dataclasses import dataclass, field +from datetime import datetime, timedelta from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock import pytest +from graphon.entities import GraphInitParams +from graphon.node_events import PauseRequestedEvent +from graphon.node_events.node import StreamCompletedEvent +from graphon.nodes.human_input.entities import ( + FormInput, + FormInputDefault, + HumanInputNodeData, + UserAction, +) +from graphon.nodes.human_input.enums import ( + ButtonStyle, + FormInputType, + HumanInputFormStatus, + PlaceholderType, + TimeoutUnit, +) +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.runtime import GraphRuntimeState, VariablePool from pydantic import ValidationError from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from core.repositories.human_input_repository import HumanInputFormRepository +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRecipientEntity, + HumanInputFormRepository, +) from core.workflow.human_input_compat import ( DeliveryMethodType, EmailDeliveryConfig, @@ -23,24 +50,90 @@ from core.workflow.human_input_compat import ( ) from core.workflow.node_runtime import DifyHumanInputNodeRuntime from core.workflow.system_variables import build_system_variables -from graphon.entities import GraphInitParams -from graphon.node_events import PauseRequestedEvent -from graphon.node_events.node import StreamCompletedEvent -from graphon.nodes.human_input.entities import ( - FormInput, - FormInputDefault, - HumanInputNodeData, - UserAction, -) -from graphon.nodes.human_input.enums import ( - ButtonStyle, - FormInputType, - PlaceholderType, - TimeoutUnit, -) -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.runtime import GraphRuntimeState, VariablePool -from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository +from libs.datetime_utils import naive_utc_now + + +@dataclass +class _InMemoryFormEntity(HumanInputFormEntity): + form_id: str + rendered: str + token: str | None = None + action_id: str | None = None + data: Mapping[str, Any] | None = None + is_submitted: bool = False + status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING + expiration: datetime = field(default_factory=lambda: naive_utc_now() + timedelta(days=1)) + + @property + def id(self) -> str: + return self.form_id + + @property + def submission_token(self) -> str | None: + return self.token + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: + return [] + + @property + def rendered_content(self) -> str: + return self.rendered + + @property + def selected_action_id(self) -> str | None: + return self.action_id + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return self.data + + @property + def submitted(self) -> bool: + return self.is_submitted + 
+ @property + def status(self) -> HumanInputFormStatus: + return self.status_value + + @property + def expiration_time(self) -> datetime: + return self.expiration + + +class InMemoryHumanInputFormRepository(HumanInputFormRepository): + """Minimal in-memory repository for Dify-owned HumanInputNode behavior tests.""" + + def __init__(self) -> None: + self._form_counter = 0 + self.created_params: list[FormCreateParams] = [] + self.created_forms: list[_InMemoryFormEntity] = [] + self._forms_by_node_id: dict[str, _InMemoryFormEntity] = {} + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + self.created_params.append(params) + self._form_counter += 1 + form_id = f"form-{self._form_counter}" + entity = _InMemoryFormEntity( + form_id=form_id, + rendered=params.rendered_content, + token=f"token-{form_id}", + ) + self.created_forms.append(entity) + self._forms_by_node_id[params.node_id] = entity + return entity + + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_node_id.get(node_id) + + def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: + if not self.created_forms: + raise AssertionError("no form has been created to attach submission data") + entity = self.created_forms[-1] + entity.action_id = action_id + entity.data = form_data or {} + entity.is_submitted = True + entity.status_value = HumanInputFormStatus.SUBMITTED class TestDeliveryMethod: diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index fc4497f010d..52802c7ce1e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,10 +1,7 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import default_system_variables -from graphon.entities.graph_init_params import GraphInitParams +from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes from graphon.graph_events import ( NodeRunHumanInputFormFilledEvent, @@ -14,6 +11,10 @@ from graphon.graph_events import ( from graphon.nodes.human_input.enums import HumanInputFormStatus from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import default_system_variables from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py deleted file mode 100644 index 8cc91bdb54a..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py +++ /dev/null @@ -1,339 +0,0 @@ -from graphon.nodes.iteration.entities import ( - ErrorHandleMode, - IterationNodeData, - IterationStartNodeData, - IterationState, -) - - -class TestErrorHandleMode: - """Test suite for ErrorHandleMode enum.""" - - def test_terminated_value(self): - """Test TERMINATED enum value.""" - assert 
ErrorHandleMode.TERMINATED == "terminated" - assert ErrorHandleMode.TERMINATED.value == "terminated" - - def test_continue_on_error_value(self): - """Test CONTINUE_ON_ERROR enum value.""" - assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error" - assert ErrorHandleMode.CONTINUE_ON_ERROR.value == "continue-on-error" - - def test_remove_abnormal_output_value(self): - """Test REMOVE_ABNORMAL_OUTPUT enum value.""" - assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT == "remove-abnormal-output" - assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT.value == "remove-abnormal-output" - - def test_error_handle_mode_is_str_enum(self): - """Test ErrorHandleMode is a string enum.""" - assert isinstance(ErrorHandleMode.TERMINATED, str) - assert isinstance(ErrorHandleMode.CONTINUE_ON_ERROR, str) - assert isinstance(ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, str) - - def test_error_handle_mode_comparison(self): - """Test ErrorHandleMode can be compared with strings.""" - assert ErrorHandleMode.TERMINATED == "terminated" - assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error" - - def test_all_error_handle_modes(self): - """Test all ErrorHandleMode values are accessible.""" - modes = list(ErrorHandleMode) - - assert len(modes) == 3 - assert ErrorHandleMode.TERMINATED in modes - assert ErrorHandleMode.CONTINUE_ON_ERROR in modes - assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT in modes - - -class TestIterationNodeData: - """Test suite for IterationNodeData model.""" - - def test_iteration_node_data_basic(self): - """Test IterationNodeData with basic configuration.""" - data = IterationNodeData( - title="Test Iteration", - iterator_selector=["node1", "output"], - output_selector=["iteration", "result"], - ) - - assert data.title == "Test Iteration" - assert data.iterator_selector == ["node1", "output"] - assert data.output_selector == ["iteration", "result"] - - def test_iteration_node_data_default_values(self): - """Test IterationNodeData default values.""" - data = IterationNodeData( - title="Default Test", - iterator_selector=["start", "items"], - output_selector=["iter", "out"], - ) - - assert data.parent_loop_id is None - assert data.is_parallel is False - assert data.parallel_nums == 10 - assert data.error_handle_mode == ErrorHandleMode.TERMINATED - assert data.flatten_output is True - - def test_iteration_node_data_parallel_mode(self): - """Test IterationNodeData with parallel mode enabled.""" - data = IterationNodeData( - title="Parallel Iteration", - iterator_selector=["node", "list"], - output_selector=["iter", "output"], - is_parallel=True, - parallel_nums=5, - ) - - assert data.is_parallel is True - assert data.parallel_nums == 5 - - def test_iteration_node_data_custom_parallel_nums(self): - """Test IterationNodeData with custom parallel numbers.""" - data = IterationNodeData( - title="Custom Parallel", - iterator_selector=["a", "b"], - output_selector=["c", "d"], - parallel_nums=20, - ) - - assert data.parallel_nums == 20 - - def test_iteration_node_data_continue_on_error(self): - """Test IterationNodeData with continue on error mode.""" - data = IterationNodeData( - title="Continue Error", - iterator_selector=["x", "y"], - output_selector=["z", "w"], - error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR, - ) - - assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR - - def test_iteration_node_data_remove_abnormal_output(self): - """Test IterationNodeData with remove abnormal output mode.""" - data = IterationNodeData( - title="Remove Abnormal", - iterator_selector=["input", 
"array"], - output_selector=["output", "result"], - error_handle_mode=ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, - ) - - assert data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT - - def test_iteration_node_data_flatten_output_disabled(self): - """Test IterationNodeData with flatten output disabled.""" - data = IterationNodeData( - title="No Flatten", - iterator_selector=["a"], - output_selector=["b"], - flatten_output=False, - ) - - assert data.flatten_output is False - - def test_iteration_node_data_with_parent_loop_id(self): - """Test IterationNodeData with parent loop ID.""" - data = IterationNodeData( - title="Nested Loop", - iterator_selector=["parent", "items"], - output_selector=["child", "output"], - parent_loop_id="parent_loop_123", - ) - - assert data.parent_loop_id == "parent_loop_123" - - def test_iteration_node_data_complex_selectors(self): - """Test IterationNodeData with complex selectors.""" - data = IterationNodeData( - title="Complex Selectors", - iterator_selector=["node1", "output", "data", "items"], - output_selector=["iteration", "result", "value"], - ) - - assert len(data.iterator_selector) == 4 - assert len(data.output_selector) == 3 - - def test_iteration_node_data_all_options(self): - """Test IterationNodeData with all options configured.""" - data = IterationNodeData( - title="Full Config", - iterator_selector=["start", "list"], - output_selector=["end", "result"], - parent_loop_id="outer_loop", - is_parallel=True, - parallel_nums=15, - error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR, - flatten_output=False, - ) - - assert data.title == "Full Config" - assert data.parent_loop_id == "outer_loop" - assert data.is_parallel is True - assert data.parallel_nums == 15 - assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR - assert data.flatten_output is False - - -class TestIterationStartNodeData: - """Test suite for IterationStartNodeData model.""" - - def test_iteration_start_node_data_basic(self): - """Test IterationStartNodeData basic creation.""" - data = IterationStartNodeData(title="Iteration Start") - - assert data.title == "Iteration Start" - - def test_iteration_start_node_data_with_description(self): - """Test IterationStartNodeData with description.""" - data = IterationStartNodeData( - title="Start Node", - desc="This is the start of iteration", - ) - - assert data.title == "Start Node" - assert data.desc == "This is the start of iteration" - - -class TestIterationState: - """Test suite for IterationState model.""" - - def test_iteration_state_default_values(self): - """Test IterationState default values.""" - state = IterationState() - - assert state.outputs == [] - assert state.current_output is None - - def test_iteration_state_with_outputs(self): - """Test IterationState with outputs.""" - state = IterationState(outputs=["result1", "result2", "result3"]) - - assert len(state.outputs) == 3 - assert state.outputs[0] == "result1" - assert state.outputs[2] == "result3" - - def test_iteration_state_with_current_output(self): - """Test IterationState with current output.""" - state = IterationState(current_output="current_value") - - assert state.current_output == "current_value" - - def test_iteration_state_get_last_output_with_outputs(self): - """Test get_last_output with outputs present.""" - state = IterationState(outputs=["first", "second", "last"]) - - result = state.get_last_output() - - assert result == "last" - - def test_iteration_state_get_last_output_empty(self): - """Test get_last_output with empty outputs.""" - 
state = IterationState(outputs=[]) - - result = state.get_last_output() - - assert result is None - - def test_iteration_state_get_last_output_single(self): - """Test get_last_output with single output.""" - state = IterationState(outputs=["only_one"]) - - result = state.get_last_output() - - assert result == "only_one" - - def test_iteration_state_get_current_output(self): - """Test get_current_output method.""" - state = IterationState(current_output={"key": "value"}) - - result = state.get_current_output() - - assert result == {"key": "value"} - - def test_iteration_state_get_current_output_none(self): - """Test get_current_output when None.""" - state = IterationState() - - result = state.get_current_output() - - assert result is None - - def test_iteration_state_with_complex_outputs(self): - """Test IterationState with complex output types.""" - state = IterationState( - outputs=[ - {"id": 1, "name": "first"}, - {"id": 2, "name": "second"}, - [1, 2, 3], - "string_output", - ] - ) - - assert len(state.outputs) == 4 - assert state.outputs[0] == {"id": 1, "name": "first"} - assert state.outputs[2] == [1, 2, 3] - - def test_iteration_state_with_none_outputs(self): - """Test IterationState with None values in outputs.""" - state = IterationState(outputs=["value1", None, "value3"]) - - assert len(state.outputs) == 3 - assert state.outputs[1] is None - - def test_iteration_state_get_last_output_with_none(self): - """Test get_last_output when last output is None.""" - state = IterationState(outputs=["first", None]) - - result = state.get_last_output() - - assert result is None - - def test_iteration_state_metadata_class(self): - """Test IterationState.MetaData class.""" - metadata = IterationState.MetaData(iterator_length=10) - - assert metadata.iterator_length == 10 - - def test_iteration_state_metadata_different_lengths(self): - """Test IterationState.MetaData with different lengths.""" - metadata1 = IterationState.MetaData(iterator_length=0) - metadata2 = IterationState.MetaData(iterator_length=100) - metadata3 = IterationState.MetaData(iterator_length=1000000) - - assert metadata1.iterator_length == 0 - assert metadata2.iterator_length == 100 - assert metadata3.iterator_length == 1000000 - - def test_iteration_state_outputs_modification(self): - """Test modifying IterationState outputs.""" - state = IterationState(outputs=[]) - - state.outputs.append("new_output") - state.outputs.append("another_output") - - assert len(state.outputs) == 2 - assert state.get_last_output() == "another_output" - - def test_iteration_state_current_output_update(self): - """Test updating current_output.""" - state = IterationState() - - state.current_output = "first_value" - assert state.get_current_output() == "first_value" - - state.current_output = "updated_value" - assert state.get_current_output() == "updated_value" - - def test_iteration_state_with_numeric_outputs(self): - """Test IterationState with numeric outputs.""" - state = IterationState(outputs=[1, 2, 3, 4, 5]) - - assert state.get_last_output() == 5 - assert len(state.outputs) == 5 - - def test_iteration_state_with_boolean_outputs(self): - """Test IterationState with boolean outputs.""" - state = IterationState(outputs=[True, False, True]) - - assert state.get_last_output() is True - assert state.outputs[1] is False diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py deleted file mode 100644 index 58b82aa8933..00000000000 --- 
a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py +++ /dev/null @@ -1,438 +0,0 @@ -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from graphon.nodes.iteration.exc import ( - InvalidIteratorValueError, - IterationGraphNotFoundError, - IterationIndexNotFoundError, - IterationNodeError, - IteratorVariableNotFoundError, - StartNodeIdNotFoundError, -) -from graphon.nodes.iteration.iteration_node import IterationNode - - -class TestIterationNodeExceptions: - """Test suite for iteration node exceptions.""" - - def test_iteration_node_error_is_value_error(self): - """Test IterationNodeError inherits from ValueError.""" - error = IterationNodeError("test error") - - assert isinstance(error, ValueError) - assert str(error) == "test error" - - def test_iterator_variable_not_found_error(self): - """Test IteratorVariableNotFoundError.""" - error = IteratorVariableNotFoundError("Iterator variable not found") - - assert isinstance(error, IterationNodeError) - assert isinstance(error, ValueError) - assert "Iterator variable not found" in str(error) - - def test_invalid_iterator_value_error(self): - """Test InvalidIteratorValueError.""" - error = InvalidIteratorValueError("Invalid iterator value") - - assert isinstance(error, IterationNodeError) - assert "Invalid iterator value" in str(error) - - def test_start_node_id_not_found_error(self): - """Test StartNodeIdNotFoundError.""" - error = StartNodeIdNotFoundError("Start node ID not found") - - assert isinstance(error, IterationNodeError) - assert "Start node ID not found" in str(error) - - def test_iteration_graph_not_found_error(self): - """Test IterationGraphNotFoundError.""" - error = IterationGraphNotFoundError("Iteration graph not found") - - assert isinstance(error, IterationNodeError) - assert "Iteration graph not found" in str(error) - - def test_iteration_index_not_found_error(self): - """Test IterationIndexNotFoundError.""" - error = IterationIndexNotFoundError("Iteration index not found") - - assert isinstance(error, IterationNodeError) - assert "Iteration index not found" in str(error) - - def test_exception_with_empty_message(self): - """Test exception with empty message.""" - error = IterationNodeError("") - - assert str(error) == "" - - def test_exception_with_detailed_message(self): - """Test exception with detailed message.""" - error = IteratorVariableNotFoundError("Variable 'items' not found in node 'start_node'") - - assert "items" in str(error) - assert "start_node" in str(error) - - def test_all_exceptions_inherit_from_base(self): - """Test all exceptions inherit from IterationNodeError.""" - exceptions = [ - IteratorVariableNotFoundError("test"), - InvalidIteratorValueError("test"), - StartNodeIdNotFoundError("test"), - IterationGraphNotFoundError("test"), - IterationIndexNotFoundError("test"), - ] - - for exc in exceptions: - assert isinstance(exc, IterationNodeError) - assert isinstance(exc, ValueError) - - -class TestIterationNodeClassAttributes: - """Test suite for IterationNode class attributes.""" - - def test_node_type(self): - """Test IterationNode node_type attribute.""" - assert IterationNode.node_type == BuiltinNodeTypes.ITERATION - - def test_version(self): - """Test IterationNode version method.""" - version = IterationNode.version() - - assert version == "1" - - -class TestIterationNodeDefaultConfig: - """Test suite for IterationNode 
get_default_config.""" - - def test_get_default_config_returns_dict(self): - """Test get_default_config returns a dictionary.""" - config = IterationNode.get_default_config() - - assert isinstance(config, dict) - - def test_get_default_config_type(self): - """Test get_default_config includes type.""" - config = IterationNode.get_default_config() - - assert config.get("type") == "iteration" - - def test_get_default_config_has_config_section(self): - """Test get_default_config has config section.""" - config = IterationNode.get_default_config() - - assert "config" in config - assert isinstance(config["config"], dict) - - def test_get_default_config_is_parallel_default(self): - """Test get_default_config is_parallel default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["is_parallel"] is False - - def test_get_default_config_parallel_nums_default(self): - """Test get_default_config parallel_nums default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["parallel_nums"] == 10 - - def test_get_default_config_error_handle_mode_default(self): - """Test get_default_config error_handle_mode default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["error_handle_mode"] == ErrorHandleMode.TERMINATED - - def test_get_default_config_flatten_output_default(self): - """Test get_default_config flatten_output default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["flatten_output"] is True - - def test_get_default_config_with_none_filters(self): - """Test get_default_config with None filters.""" - config = IterationNode.get_default_config(filters=None) - - assert config is not None - assert "type" in config - - def test_get_default_config_with_empty_filters(self): - """Test get_default_config with empty filters.""" - config = IterationNode.get_default_config(filters={}) - - assert config is not None - - -class TestIterationNodeInitialization: - """Test suite for IterationNode initialization.""" - - def test_init_node_data_basic(self): - """Test init_node_data with basic configuration.""" - node = IterationNode.__new__(IterationNode) - data = { - "title": "Test Iteration", - "iterator_selector": ["start", "items"], - "output_selector": ["iteration", "result"], - } - - node.init_node_data(data) - - assert node._node_data.title == "Test Iteration" - assert node._node_data.iterator_selector == ["start", "items"] - - def test_init_node_data_with_parallel(self): - """Test init_node_data with parallel configuration.""" - node = IterationNode.__new__(IterationNode) - data = { - "title": "Parallel Iteration", - "iterator_selector": ["node", "list"], - "output_selector": ["out", "result"], - "is_parallel": True, - "parallel_nums": 5, - } - - node.init_node_data(data) - - assert node._node_data.is_parallel is True - assert node._node_data.parallel_nums == 5 - - def test_init_node_data_with_error_handle_mode(self): - """Test init_node_data with error handle mode.""" - node = IterationNode.__new__(IterationNode) - data = { - "title": "Error Handle Test", - "iterator_selector": ["a", "b"], - "output_selector": ["c", "d"], - "error_handle_mode": "continue-on-error", - } - - node.init_node_data(data) - - assert node._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR - - def test_get_title(self): - """Test _get_title method.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="My Iteration", - iterator_selector=["x"], - 
output_selector=["y"], - ) - - assert node._get_title() == "My Iteration" - - def test_get_description_none(self): - """Test _get_description returns None when not set.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - assert node._get_description() is None - - def test_get_description_with_value(self): - """Test _get_description with value.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - desc="This is a description", - iterator_selector=["a"], - output_selector=["b"], - ) - - assert node._get_description() == "This is a description" - - def test_node_data_property(self): - """Test node_data property returns node data.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Base Test", - iterator_selector=["x"], - output_selector=["y"], - ) - - result = node.node_data - - assert result == node._node_data - - -class TestIterationNodeDataValidation: - """Test suite for IterationNodeData validation scenarios.""" - - def test_valid_iteration_node_data(self): - """Test valid IterationNodeData creation.""" - data = IterationNodeData( - title="Valid Iteration", - iterator_selector=["start", "items"], - output_selector=["end", "result"], - ) - - assert data.title == "Valid Iteration" - - def test_iteration_node_data_with_all_error_modes(self): - """Test IterationNodeData with all error handle modes.""" - modes = [ - ErrorHandleMode.TERMINATED, - ErrorHandleMode.CONTINUE_ON_ERROR, - ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, - ] - - for mode in modes: - data = IterationNodeData( - title=f"Test {mode}", - iterator_selector=["a"], - output_selector=["b"], - error_handle_mode=mode, - ) - assert data.error_handle_mode == mode - - def test_iteration_node_data_parallel_configuration(self): - """Test IterationNodeData parallel configuration combinations.""" - configs = [ - (False, 10), - (True, 1), - (True, 5), - (True, 20), - (True, 100), - ] - - for is_parallel, parallel_nums in configs: - data = IterationNodeData( - title="Parallel Test", - iterator_selector=["x"], - output_selector=["y"], - is_parallel=is_parallel, - parallel_nums=parallel_nums, - ) - assert data.is_parallel == is_parallel - assert data.parallel_nums == parallel_nums - - def test_iteration_node_data_flatten_output_options(self): - """Test IterationNodeData flatten_output options.""" - data_flatten = IterationNodeData( - title="Flatten True", - iterator_selector=["a"], - output_selector=["b"], - flatten_output=True, - ) - - data_no_flatten = IterationNodeData( - title="Flatten False", - iterator_selector=["a"], - output_selector=["b"], - flatten_output=False, - ) - - assert data_flatten.flatten_output is True - assert data_no_flatten.flatten_output is False - - def test_iteration_node_data_complex_selectors(self): - """Test IterationNodeData with complex selectors.""" - data = IterationNodeData( - title="Complex", - iterator_selector=["node1", "output", "data", "items", "list"], - output_selector=["iteration", "result", "value", "final"], - ) - - assert len(data.iterator_selector) == 5 - assert len(data.output_selector) == 4 - - def test_iteration_node_data_single_element_selectors(self): - """Test IterationNodeData with single element selectors.""" - data = IterationNodeData( - title="Single", - iterator_selector=["items"], - output_selector=["result"], - ) - - assert len(data.iterator_selector) == 1 - assert 
len(data.output_selector) == 1 - - -class TestIterationNodeErrorStrategies: - """Test suite for IterationNode error strategies.""" - - def test_get_error_strategy_default(self): - """Test _get_error_strategy with default value.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - result = node._get_error_strategy() - - assert result is None or result == node._node_data.error_strategy - - def test_get_retry_config(self): - """Test _get_retry_config method.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - result = node._get_retry_config() - - assert result is not None - - def test_get_default_value_dict(self): - """Test _get_default_value_dict method.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - result = node._get_default_value_dict() - - assert isinstance(result, dict) - - -def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: - seen_configs: list[object] = [] - original_validate_python = NodeConfigDictAdapter.validate_python - - def record_validate_python(value: object): - seen_configs.append(value) - return original_validate_python(value) - - monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) - - child_node_config = { - "id": "answer-node", - "data": { - "type": "answer", - "title": "Answer", - "answer": "", - "iteration_id": "iteration-node", - }, - } - - IterationNode._extract_variable_selector_to_variable_mapping( - graph_config={ - "nodes": [ - { - "id": "iteration-node", - "data": { - "type": "iteration", - "title": "Iteration", - "iterator_selector": ["start", "items"], - "output_selector": ["iteration", "result"], - }, - }, - child_node_config, - ], - "edges": [], - }, - node_id="iteration-node", - node_data=IterationNodeData( - title="Iteration", - iterator_selector=["start", "items"], - output_selector=["iteration", "result"], - ), - ) - - assert seen_configs == [child_node_config] diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py deleted file mode 100644 index 4c3ad85fcd7..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py +++ /dev/null @@ -1,201 +0,0 @@ -from threading import Event -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph_events import GraphRunAbortedEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import IterationFailedEvent, IterationStartedEvent, StreamCompletedEvent -from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from graphon.nodes.iteration.exc import ChildGraphAbortedError -from graphon.nodes.iteration.iteration_node import IterationNode -from tests.workflow_test_utils import build_test_variable_pool - - -def _usage_with_tokens(total_tokens: int) -> LLMUsage: - usage = LLMUsage.empty_usage() - usage.total_tokens = total_tokens - return usage - - -class _AbortOnRequestGraphEngine: - def __init__(self, *, index: int, total_tokens: int) -> None: - 
variable_pool = build_test_variable_pool() - variable_pool.add(["iteration-node", "index"], index) - - self.started = Event() - self.abort_requested = Event() - self.finished = Event() - self.abort_reason: str | None = None - self.graph_runtime_state = SimpleNamespace( - variable_pool=variable_pool, - llm_usage=_usage_with_tokens(total_tokens), - ) - - def request_abort(self, reason: str | None = None) -> None: - self.abort_reason = reason - self.abort_requested.set() - - def run(self): - self.started.set() - assert self.abort_requested.wait(1), "parallel sibling never received an abort request" - self.finished.set() - yield GraphRunAbortedEvent(reason=self.abort_reason) - - -def _build_immediate_abort_graph_engine( - *, - index: int, - total_tokens: int, - wait_before_abort: Event | None = None, -) -> SimpleNamespace: - variable_pool = build_test_variable_pool() - variable_pool.add(["iteration-node", "index"], index) - - started = Event() - finished = Event() - - def run(): - started.set() - if wait_before_abort is not None: - assert wait_before_abort.wait(1), "parallel sibling never started" - finished.set() - yield GraphRunAbortedEvent(reason="quota exceeded") - - return SimpleNamespace( - graph_runtime_state=SimpleNamespace( - variable_pool=variable_pool, - llm_usage=_usage_with_tokens(total_tokens), - ), - run=run, - request_abort=lambda reason=None: None, - started=started, - finished=finished, - ) - - -def _build_iteration_node( - *, - error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED, - is_parallel: bool = False, -) -> IterationNode: - node = IterationNode.__new__(IterationNode) - node._node_id = "iteration-node" - node._node_data = IterationNodeData( - title="Iteration", - iterator_selector=["start", "items"], - output_selector=["iteration-node", "output"], - start_node_id="child-start", - is_parallel=is_parallel, - parallel_nums=2, - error_handle_mode=error_handle_mode, - ) - - variable_pool = build_test_variable_pool() - variable_pool.add(["start", "items"], ["first", "second"]) - node.graph_runtime_state = SimpleNamespace( - variable_pool=variable_pool, - llm_usage=LLMUsage.empty_usage(), - ) - return node - - -def test_run_single_iter_raises_child_graph_aborted_error_on_abort_event() -> None: - node = _build_iteration_node() - variable_pool = build_test_variable_pool() - variable_pool.add(["iteration-node", "index"], 0) - graph_engine = SimpleNamespace( - run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]), - ) - - with pytest.raises(ChildGraphAbortedError, match="quota exceeded"): - list( - node._run_single_iter( - variable_pool=variable_pool, - outputs=[], - graph_engine=graph_engine, - ) - ) - - -def test_iteration_run_fails_on_sequential_child_abort() -> None: - node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR) - graph_engine = SimpleNamespace( - graph_runtime_state=SimpleNamespace( - variable_pool=build_test_variable_pool(), - llm_usage=LLMUsage.empty_usage(), - ) - ) - node._create_graph_engine = MagicMock(return_value=graph_engine) - node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded")) - - events = list(node._run()) - - assert isinstance(events[0], IterationStartedEvent) - assert isinstance(events[-2], IterationFailedEvent) - assert events[-2].error == "quota exceeded" - assert isinstance(events[-1], StreamCompletedEvent) - assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED - assert events[-1].node_run_result.error == "quota exceeded" - 
node._create_graph_engine.assert_called_once() - node._run_single_iter.assert_called_once() - - -def test_iteration_run_merges_child_usage_before_failing_on_sequential_child_abort() -> None: - node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR) - graph_engine = SimpleNamespace( - graph_runtime_state=SimpleNamespace( - variable_pool=build_test_variable_pool(), - llm_usage=_usage_with_tokens(7), - ) - ) - node._create_graph_engine = MagicMock(return_value=graph_engine) - node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded")) - - events = list(node._run()) - - assert isinstance(events[-1], StreamCompletedEvent) - assert events[-1].node_run_result.llm_usage.total_tokens == 7 - assert node.graph_runtime_state.llm_usage.total_tokens == 7 - - -@pytest.mark.parametrize( - "error_handle_mode", - [ - ErrorHandleMode.CONTINUE_ON_ERROR, - ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, - ], -) -def test_iteration_run_fails_on_parallel_child_abort_regardless_of_error_mode( - error_handle_mode: ErrorHandleMode, -) -> None: - node = _build_iteration_node( - error_handle_mode=error_handle_mode, - is_parallel=True, - ) - blocking_engine = _AbortOnRequestGraphEngine(index=1, total_tokens=5) - aborting_engine = _build_immediate_abort_graph_engine( - index=0, - total_tokens=3, - wait_before_abort=blocking_engine.started, - ) - node._create_graph_engine = MagicMock( - side_effect=lambda index, item: {0: aborting_engine, 1: blocking_engine}[index] - ) - - events = list(node._run()) - - assert isinstance(events[0], IterationStartedEvent) - assert isinstance(events[-2], IterationFailedEvent) - assert events[-2].error == "quota exceeded" - assert isinstance(events[-1], StreamCompletedEvent) - assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED - assert events[-1].node_run_result.error == "quota exceeded" - assert events[-1].node_run_result.llm_usage.total_tokens == 8 - assert node.graph_runtime_state.llm_usage.total_tokens == 8 - assert blocking_engine.started.is_set() - assert blocking_engine.abort_requested.is_set() - assert blocking_engine.finished.is_set() - assert blocking_engine.abort_reason == "quota exceeded" diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index 82cc734274b..bbfe350f7e4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -2,8 +2,6 @@ from collections.abc import Mapping from typing import Any import pytest - -from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams from graphon.nodes.iteration.exc import IterationGraphNotFoundError from graphon.nodes.iteration.iteration_node import IterationNode @@ -13,6 +11,8 @@ from graphon.runtime import ( GraphRuntimeState, VariablePool, ) + +from core.workflow.system_variables import default_system_variables from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py deleted file mode 100644 index 41d7c3193d7..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py +++ /dev/null @@ -1,67 
+0,0 @@ -import time -from datetime import UTC, datetime - -import pytest - -from graphon.enums import BuiltinNodeTypes -from graphon.graph_events import NodeRunSucceededEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from graphon.nodes.iteration.iteration_node import IterationNode - - -def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Parallel Iteration", - iterator_selector=["start", "items"], - output_selector=["iteration", "output"], - is_parallel=True, - parallel_nums=2, - error_handle_mode=ErrorHandleMode.TERMINATED, - ) - node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new) - - def fake_execute_tracked_iteration_parallel( - *, - index: int, - item: object, - started_child_engines: dict[int, object], - started_child_engines_lock: object, - ): - _ = started_child_engines - _ = started_child_engines_lock - return ( - 0.1 + (index * 0.1), - [ - NodeRunSucceededEvent( - id=f"exec-{index}", - node_id=f"llm-{index}", - node_type=BuiltinNodeTypes.LLM, - start_at=datetime.now(UTC).replace(tzinfo=None), - ), - ], - f"output-{item}", - LLMUsage.empty_usage(), - ) - - node._execute_tracked_iteration_parallel = fake_execute_tracked_iteration_parallel - - outputs: list[object] = [] - iter_run_map: dict[str, float] = {} - usage_accumulator = [LLMUsage.empty_usage()] - - generator = node._execute_parallel_iterations( - iterator_list_value=["a", "b"], - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - - for _ in generator: - # Simulate a slow consumer replaying buffered events. 
- time.sleep(0.02) - - assert outputs == ["output-a", "output-b"] - assert iter_run_map["0"] == pytest.approx(0.1) - assert iter_run_map["1"] == pytest.approx(0.2) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index a6fca1bfb40..f8802138b58 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -3,6 +3,9 @@ import uuid from unittest.mock import Mock import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.segments import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -16,9 +19,6 @@ from core.workflow.nodes.knowledge_index.protocols import ( SummaryIndexServiceProtocol, ) from core.workflow.system_variables import SystemVariableKey, build_system_variables -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables.segments import StringSegment from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 45e8ae7d208..ab64be59ad2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -3,6 +3,10 @@ import uuid from unittest.mock import Mock import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.nodes.knowledge_retrieval.entities import ( @@ -17,10 +21,6 @@ from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source from core.workflow.system_variables import build_system_variables -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import StringSegment from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index eca34f05beb..fdf1706765a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,14 +1,14 @@ from unittest.mock import MagicMock import pytest - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.nodes.list_operator.node 
import ListOperatorNode from graphon.runtime import GraphRuntimeState from graphon.variables import ArrayNumberSegment, ArrayStringSegment +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY + class TestListOperatorNode: """Comprehensive tests for ListOperatorNode.""" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py deleted file mode 100644 index 4f9ba0194af..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ /dev/null @@ -1,170 +0,0 @@ -import uuid -from typing import NamedTuple -from unittest import mock -from unittest.mock import MagicMock - -import httpx -import pytest - -from graphon.file import FileTransferMethod, FileType -from graphon.nodes.llm.file_saver import ( - FileSaverImpl, - _extract_content_type_and_extension, - _get_extension, - _validate_extension_override, -) -from graphon.nodes.protocols import ToolFileManagerProtocol - -_PNG_DATA = b"\x89PNG\r\n\x1a\n" - - -def _gen_id(): - return str(uuid.uuid4()) - - -class TestFileSaverImpl: - def test_save_binary_string(self, monkeypatch: pytest.MonkeyPatch): - file_type = FileType.IMAGE - mime_type = "image/png" - mock_tool_file = MagicMock() - mock_tool_file.id = _gen_id() - mock_tool_file.name = f"{_gen_id()}.png" - mock_tool_file.file_key = "test-file-key" - mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManagerProtocol) - mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file - file_reference = MagicMock() - file_reference_factory = MagicMock() - file_reference_factory.build_from_mapping.return_value = file_reference - http_client = MagicMock() - - file_saver = FileSaverImpl( - tool_file_manager=mocked_tool_file_manager, - file_reference_factory=file_reference_factory, - http_client=http_client, - ) - - file = file_saver.save_binary_string(_PNG_DATA, mime_type, file_type) - assert file is file_reference - - mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( - file_binary=_PNG_DATA, - mimetype=mime_type, - ) - file_reference_factory.build_from_mapping.assert_called_once_with( - mapping={ - "type": file_type, - "transfer_method": FileTransferMethod.TOOL_FILE, - "filename": mock_tool_file.name, - "extension": ".png", - "mime_type": mime_type, - "size": len(_PNG_DATA), - "tool_file_id": mock_tool_file.id, - "related_id": mock_tool_file.id, - "storage_key": mock_tool_file.file_key, - } - ) - - def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch): - _TEST_URL = "https://example.com/image.png" - mock_request = httpx.Request("GET", _TEST_URL) - mock_response = httpx.Response( - status_code=401, - request=mock_request, - ) - http_client = MagicMock() - http_client.get.return_value = mock_response - - file_saver = FileSaverImpl( - tool_file_manager=MagicMock(), - file_reference_factory=MagicMock(), - http_client=http_client, - ) - - with pytest.raises(httpx.HTTPStatusError) as exc: - file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) - http_client.get.assert_called_once_with(_TEST_URL) - assert exc.value.response.status_code == 401 - - def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch): - _TEST_URL = "https://example.com/image.png" - mime_type = "image/png" - - mock_request = httpx.Request("GET", _TEST_URL) - mock_response = httpx.Response( - status_code=200, - content=b"test-data", - headers={"Content-Type": mime_type}, - request=mock_request, - ) - http_client = MagicMock() - 
http_client.get.return_value = mock_response
-
-        file_saver = FileSaverImpl(
-            tool_file_manager=MagicMock(),
-            file_reference_factory=MagicMock(),
-            http_client=http_client,
-        )
-        expected_file = MagicMock()
-        mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=expected_file)
-        monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
-
-        file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
-        mock_save_binary_string.assert_called_once_with(
-            mock_response.content,
-            mime_type,
-            FileType.IMAGE,
-            extension_override=".png",
-        )
-        assert file is expected_file
-
-
-def test_validate_extension_override():
-    class TestCase(NamedTuple):
-        extension_override: str | None
-        expected: str | None
-
-    cases = [
-        TestCase(None, None),
-        TestCase("", ""),
-        TestCase(".png", ".png"),
-        TestCase(".tar.gz", ".tar.gz"),
-    ]
-
-    for case in cases:
-        assert case.expected == _validate_extension_override(case.extension_override)
-
-    for invalid_ext_override in ["png", "tar.gz"]:
-        with pytest.raises(ValueError):
-            _validate_extension_override(invalid_ext_override)
-
-
-class TestExtractContentTypeAndExtension:
-    def test_with_both_content_type_and_extension(self):
-        content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
-        assert content_type == "image/png"
-        assert extension == ".png"
-
-    def test_url_with_file_extension(self):
-        for given_content_type in [None, ""]:
-            content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", given_content_type)
-            assert content_type == "image/png"
-            assert extension == ".png"
-
-    def test_response_with_content_type(self):
-        content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
-        assert content_type == "image/png"
-        assert extension == ".png"
-
-    def test_no_content_type_and_no_extension(self):
-        for given_content_type in [None, ""]:
-            content_type, extension = _extract_content_type_and_extension("https://example.com/image", given_content_type)
-            assert content_type == "application/octet-stream"
-            assert extension == ".bin"
-
-
-class TestGetExtension:
-    def test_with_extension_override(self):
-        mime_type = "image/png"
-        for override in [".jpg", ""]:
-            extension = _get_extension(mime_type, override)
-            assert extension == override
-
-    def test_without_extension_override(self):
-        mime_type = "image/png"
-        extension = _get_extension(mime_type)
-        assert extension == ".png"
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
index dfc982f49cf..c784f805c01 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
@@ -1,10 +1,7 @@
 from unittest import mock
 
 import pytest
-
-from core.model_manager import ModelInstance
-from graphon.file import FileTransferMethod, FileType
-from graphon.file.models import File
+from graphon.file import File, FileTransferMethod, FileType
 from graphon.model_runtime.entities import (
     ImagePromptMessageContent,
     PromptMessageRole,
@@ -36,6 +33,8 @@ from graphon.nodes.llm.exc import (
 from graphon.runtime import VariablePool
 from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
 
+from core.model_manager import ModelInstance
+
 
 def _build_model_schema(
     *,
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py 
b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index a2fbc50392d..a215e9d350d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -4,19 +4,6 @@ from collections.abc import Sequence from unittest import mock import pytest - -from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom -from core.app.llm.model_access import ( - DifyCredentialsProvider, - DifyModelFactory, - build_dify_model_access, - fetch_model_config, -) -from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.common_entities import I18nObject @@ -79,6 +66,19 @@ from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from graphon.runtime import GraphRuntimeState, VariablePool from graphon.template_rendering import TemplateRenderError from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment + +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom +from core.app.llm.model_access import ( + DifyCredentialsProvider, + DifyModelFactory, + build_dify_model_access, + fetch_model_config, +) +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, SystemConfiguration +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.workflow.system_variables import default_system_variables from models.provider import ProviderType from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py deleted file mode 100644 index af1cff4e81f..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py +++ /dev/null @@ -1,25 +0,0 @@ -from collections.abc import Mapping, Sequence - -from pydantic import BaseModel, Field - -from graphon.file import File -from graphon.model_runtime.entities.message_entities import PromptMessage -from graphon.model_runtime.entities.model_entities import ModelFeature -from graphon.nodes.llm.entities import LLMNodeChatModelMessage - - -class LLMNodeTestScenario(BaseModel): - """Test scenario for LLM node testing.""" - - description: str = Field(..., description="Description of the test scenario") - sys_query: str = Field(..., description="User query input") - sys_files: Sequence[File] = Field(default_factory=list, description="List of user files") - vision_enabled: bool = Field(default=False, description="Whether vision is enabled") - vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") - features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") - window_size: int = Field(..., description="Window size for memory") - prompt_template: 
Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
-    file_variables: Mapping[str, File | Sequence[File]] = Field(
-        default_factory=dict, description="Mapping of file variables"
-    )
-    expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")
diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py
deleted file mode 100644
index ccf1077838a..00000000000
--- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from graphon.nodes.parameter_extractor.entities import ParameterConfig
-from graphon.variables.types import SegmentType
-
-
-class TestParameterConfig:
-    def test_select_type(self):
-        data = {
-            "name": "yes_or_no",
-            "type": "select",
-            "options": ["yes", "no"],
-            "description": "a simple select made of `yes` and `no`",
-            "required": True,
-        }
-
-        pc = ParameterConfig.model_validate(data)
-        assert pc.type == SegmentType.STRING
-        assert pc.options == data["options"]
-
-    def test_validate_bool_type(self):
-        data = {
-            "name": "boolean",
-            "type": "bool",
-            "description": "a simple boolean parameter",
-            "required": True,
-        }
-        pc = ParameterConfig.model_validate(data)
-        assert pc.type == SegmentType.BOOLEAN
diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py
index 8f8ec49f141..1c362a0a037 100644
--- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py
@@ -6,8 +6,6 @@ from dataclasses import dataclass
 from typing import Any
 
 import pytest
-
-from factories.variable_factory import build_segment_with_type
 from graphon.model_runtime.entities import LLMMode
 from graphon.nodes.llm import ModelConfig, VisionConfig
 from graphon.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData
@@ -20,6 +18,8 @@ from graphon.nodes.parameter_extractor.exc import (
 from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from graphon.variables.types import SegmentType
 
+from factories.variable_factory import build_segment_with_type
+
 
 @dataclass
 class ValidTestCase:
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py
deleted file mode 100644
index 01878ed692d..00000000000
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py
+++ /dev/null
@@ -1,225 +0,0 @@
-import pytest
-from pydantic import ValidationError
-
-from graphon.enums import ErrorStrategy
-from graphon.nodes.template_transform.entities import TemplateTransformNodeData
-
-
-class TestTemplateTransformNodeData:
-    """Test suite for TemplateTransformNodeData entity."""
-
-    def test_valid_template_transform_node_data(self):
-        """Test creating valid TemplateTransformNodeData."""
-        data = {
-            "title": "Template Transform",
-            "desc": "Transform data using Jinja2 template",
-            "variables": [
-                {"variable": "name", "value_selector": ["sys", "user_name"]},
-                {"variable": "age", "value_selector": ["sys", "user_age"]},
-            ],
-            "template": "Hello {{ name }}, you are {{ age }} years old!",
-        }
-
-        node_data = 
TemplateTransformNodeData.model_validate(data) - - assert node_data.title == "Template Transform" - assert node_data.desc == "Transform data using Jinja2 template" - assert len(node_data.variables) == 2 - assert node_data.variables[0].variable == "name" - assert node_data.variables[0].value_selector == ["sys", "user_name"] - assert node_data.variables[1].variable == "age" - assert node_data.variables[1].value_selector == ["sys", "user_age"] - assert node_data.template == "Hello {{ name }}, you are {{ age }} years old!" - - def test_template_transform_node_data_with_empty_variables(self): - """Test TemplateTransformNodeData with no variables.""" - data = { - "title": "Static Template", - "variables": [], - "template": "This is a static template with no variables.", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.title == "Static Template" - assert len(node_data.variables) == 0 - assert node_data.template == "This is a static template with no variables." - - def test_template_transform_node_data_with_complex_template(self): - """Test TemplateTransformNodeData with complex Jinja2 template.""" - data = { - "title": "Complex Template", - "variables": [ - {"variable": "items", "value_selector": ["sys", "item_list"]}, - {"variable": "total", "value_selector": ["sys", "total_count"]}, - ], - "template": ( - "{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}. Total: {{ total }}" - ), - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.title == "Complex Template" - assert len(node_data.variables) == 2 - assert "{% for item in items %}" in node_data.template - assert "{{ total }}" in node_data.template - - def test_template_transform_node_data_with_error_strategy(self): - """Test TemplateTransformNodeData with error handling strategy.""" - data = { - "title": "Template with Error Handling", - "variables": [{"variable": "value", "value_selector": ["sys", "input"]}], - "template": "{{ value }}", - "error_strategy": "fail-branch", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.error_strategy == ErrorStrategy.FAIL_BRANCH - - def test_template_transform_node_data_with_retry_config(self): - """Test TemplateTransformNodeData with retry configuration.""" - data = { - "title": "Template with Retry", - "variables": [{"variable": "data", "value_selector": ["sys", "data"]}], - "template": "{{ data }}", - "retry_config": {"enabled": True, "max_retries": 3, "retry_interval": 1000}, - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.retry_config.enabled is True - assert node_data.retry_config.max_retries == 3 - assert node_data.retry_config.retry_interval == 1000 - - def test_template_transform_node_data_missing_required_fields(self): - """Test that missing required fields raises ValidationError.""" - data = { - "title": "Incomplete Template", - # Missing 'variables' and 'template' - } - - with pytest.raises(ValidationError) as exc_info: - TemplateTransformNodeData.model_validate(data) - - errors = exc_info.value.errors() - assert len(errors) >= 2 - error_fields = {error["loc"][0] for error in errors} - assert "variables" in error_fields - assert "template" in error_fields - - def test_template_transform_node_data_invalid_variable_selector(self): - """Test that invalid variable selector format raises ValidationError.""" - data = { - "title": "Invalid Variable", - "variables": [ - {"variable": "name", "value_selector": 
"invalid_format"} # Should be list - ], - "template": "{{ name }}", - } - - with pytest.raises(ValidationError): - TemplateTransformNodeData.model_validate(data) - - def test_template_transform_node_data_with_default_value_dict(self): - """Test TemplateTransformNodeData with default value dictionary.""" - data = { - "title": "Template with Defaults", - "variables": [ - {"variable": "name", "value_selector": ["sys", "user_name"]}, - {"variable": "greeting", "value_selector": ["sys", "greeting"]}, - ], - "template": "{{ greeting }} {{ name }}!", - "default_value_dict": {"greeting": "Hello", "name": "Guest"}, - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.default_value_dict == {"greeting": "Hello", "name": "Guest"} - - def test_template_transform_node_data_with_nested_selectors(self): - """Test TemplateTransformNodeData with nested variable selectors.""" - data = { - "title": "Nested Selectors", - "variables": [ - {"variable": "user_info", "value_selector": ["sys", "user", "profile", "name"]}, - {"variable": "settings", "value_selector": ["sys", "config", "app", "theme"]}, - ], - "template": "User: {{ user_info }}, Theme: {{ settings }}", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert len(node_data.variables) == 2 - assert node_data.variables[0].value_selector == ["sys", "user", "profile", "name"] - assert node_data.variables[1].value_selector == ["sys", "config", "app", "theme"] - - def test_template_transform_node_data_with_multiline_template(self): - """Test TemplateTransformNodeData with multiline template.""" - data = { - "title": "Multiline Template", - "variables": [ - {"variable": "title", "value_selector": ["sys", "title"]}, - {"variable": "content", "value_selector": ["sys", "content"]}, - ], - "template": """ -# {{ title }} - -{{ content }} - ---- -Generated by Template Transform Node - """, - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert "# {{ title }}" in node_data.template - assert "{{ content }}" in node_data.template - assert "Generated by Template Transform Node" in node_data.template - - def test_template_transform_node_data_serialization(self): - """Test that TemplateTransformNodeData can be serialized and deserialized.""" - original_data = { - "title": "Serialization Test", - "desc": "Test serialization", - "variables": [{"variable": "test", "value_selector": ["sys", "test"]}], - "template": "{{ test }}", - } - - node_data = TemplateTransformNodeData.model_validate(original_data) - serialized = node_data.model_dump() - deserialized = TemplateTransformNodeData.model_validate(serialized) - - assert deserialized.title == node_data.title - assert deserialized.desc == node_data.desc - assert len(deserialized.variables) == len(node_data.variables) - assert deserialized.template == node_data.template - - def test_template_transform_node_data_with_special_characters(self): - """Test TemplateTransformNodeData with special characters in template.""" - data = { - "title": "Special Characters", - "variables": [{"variable": "text", "value_selector": ["sys", "input"]}], - "template": "Special: {{ text }} | Symbols: @#$%^&*() | Unicode: ไฝ ๅฅฝ ๐ŸŽ‰", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert "@#$%^&*()" in node_data.template - assert "ไฝ ๅฅฝ" in node_data.template - assert "๐ŸŽ‰" in node_data.template - - def test_template_transform_node_data_empty_template(self): - """Test TemplateTransformNodeData with empty template string.""" - data = { - "title": 
"Empty Template", - "variables": [], - "template": "", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.template == "" - assert len(node_data.variables) == 0 diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index bc44ececd88..d86e0efe023 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -1,8 +1,6 @@ from unittest.mock import MagicMock import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.nodes.base.entities import VariableSelector @@ -10,6 +8,8 @@ from graphon.nodes.template_transform.entities import TemplateTransformNodeData from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode from graphon.runtime import GraphRuntimeState from graphon.template_rendering import TemplateRenderError + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py index 636237e56e0..bd22a8e318c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -1,14 +1,14 @@ from unittest.mock import MagicMock import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.nodes.base.entities import VariableSelector from graphon.nodes.template_transform.template_transform_node import ( DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH, TemplateTransformNode, ) from graphon.runtime import GraphRuntimeState + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params from .template_transform_node_spec import TestTemplateTransformNode # noqa: F401 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 0522dd9d14a..e11ebf6eb8b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -1,16 +1,16 @@ from collections.abc import Mapping import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_runtime import resolve_dify_run_context -from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams from graphon.entities.base_node_data import BaseNodeData from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.node import Node from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_runtime import resolve_dify_run_context +from core.workflow.system_variables import build_system_variables from 
tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 87ec2d5bcee..555ff0c9452 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -4,8 +4,6 @@ from unittest.mock import Mock, patch import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod @@ -21,6 +19,8 @@ from graphon.nodes.document_extractor.node import ( from graphon.variables import ArrayFileSegment from graphon.variables.segments import ArrayStringSegment from graphon.variables.variables import StringVariable + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 782750e02e5..1b14f0ab133 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -3,11 +3,6 @@ import uuid from unittest.mock import MagicMock, Mock import pytest - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.graph import Graph @@ -16,6 +11,11 @@ from graphon.nodes.if_else.if_else_node import IfElseNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.utils.condition.entities import Condition, SubCondition, SubVariableCondition from graphon.variables import ArrayFileSegment + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index b217e4e8e7b..d28c3e01e5f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -1,8 +1,6 @@ from unittest.mock import MagicMock import pytest - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.nodes.list_operator.entities import ( @@ -18,6 +16,8 @@ from graphon.nodes.list_operator.exc import InvalidKeyError from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func from graphon.variables import ArrayFileSegment +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom + @pytest.fixture def list_operator_node(): diff 
--git a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py deleted file mode 100644 index d613ba154af..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py +++ /dev/null @@ -1,150 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph_events import GraphRunAbortedEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import LoopFailedEvent, LoopStartedEvent, StreamCompletedEvent -from graphon.nodes.loop.entities import LoopNodeData -from graphon.nodes.loop.loop_node import LoopNode -from tests.workflow_test_utils import build_test_variable_pool - - -def _usage_with_tokens(total_tokens: int) -> LLMUsage: - usage = LLMUsage.empty_usage() - usage.total_tokens = total_tokens - return usage - - -def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: - seen_configs: list[object] = [] - original_validate_python = NodeConfigDictAdapter.validate_python - - def record_validate_python(value: object): - seen_configs.append(value) - return original_validate_python(value) - - monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) - - child_node_config = { - "id": "answer-node", - "data": { - "type": "answer", - "title": "Answer", - "answer": "", - "loop_id": "loop-node", - }, - } - - LoopNode._extract_variable_selector_to_variable_mapping( - graph_config={ - "nodes": [ - { - "id": "loop-node", - "data": { - "type": "loop", - "title": "Loop", - "loop_count": 1, - "break_conditions": [], - "logical_operator": "and", - }, - }, - child_node_config, - ], - "edges": [], - }, - node_id="loop-node", - node_data=LoopNodeData( - title="Loop", - loop_count=1, - break_conditions=[], - logical_operator="and", - ), - ) - - assert seen_configs == [child_node_config] - - -def test_run_single_loop_raises_on_child_abort_event() -> None: - node = LoopNode.__new__(LoopNode) - node._node_id = "loop-node" - node._node_data = LoopNodeData( - title="Loop", - loop_count=1, - break_conditions=[], - logical_operator="and", - start_node_id="child-start", - ) - - graph_engine = SimpleNamespace( - run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]), - ) - - with pytest.raises(RuntimeError, match="quota exceeded"): - list(node._run_single_loop(graph_engine=graph_engine, current_index=0)) - - -def test_loop_run_fails_on_child_abort_and_stops_subsequent_rounds() -> None: - node = LoopNode.__new__(LoopNode) - node._node_id = "loop-node" - node._node_data = LoopNodeData( - title="Loop", - loop_count=2, - break_conditions=[], - logical_operator="and", - start_node_id="child-start", - ) - node.graph_config = {"nodes": [], "edges": []} - node.graph_runtime_state = SimpleNamespace( - variable_pool=build_test_variable_pool(), - llm_usage=LLMUsage.empty_usage(), - ) - - aborting_engine = SimpleNamespace( - graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=LLMUsage.empty_usage()), - ) - create_graph_engine = MagicMock(return_value=aborting_engine) - node._create_graph_engine = create_graph_engine - node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded")) - - events = list(node._run()) - - assert isinstance(events[0], LoopStartedEvent) - assert isinstance(events[1], 
LoopFailedEvent) - assert events[1].error == "quota exceeded" - assert isinstance(events[2], StreamCompletedEvent) - assert events[2].node_run_result.status == WorkflowNodeExecutionStatus.FAILED - assert events[2].node_run_result.error == "quota exceeded" - create_graph_engine.assert_called_once() - - -def test_loop_run_merges_child_usage_before_failing_on_child_abort() -> None: - node = LoopNode.__new__(LoopNode) - node._node_id = "loop-node" - node._node_data = LoopNodeData( - title="Loop", - loop_count=1, - break_conditions=[], - logical_operator="and", - start_node_id="child-start", - ) - node.graph_config = {"nodes": [], "edges": []} - node.graph_runtime_state = SimpleNamespace( - variable_pool=build_test_variable_pool(), - llm_usage=LLMUsage.empty_usage(), - ) - - aborting_engine = SimpleNamespace( - graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=_usage_with_tokens(7)), - ) - node._create_graph_engine = MagicMock(return_value=aborting_engine) - node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded")) - - events = list(node._run()) - - assert isinstance(events[-1], StreamCompletedEvent) - assert events[-1].node_run_result.llm_usage.total_tokens == 7 - assert node.graph_runtime_state.llm_usage.total_tokens == 7 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py deleted file mode 100644 index efbf786a558..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py +++ /dev/null @@ -1,126 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - -from graphon.model_runtime.entities import ImagePromptMessageContent -from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory -from graphon.nodes.protocols import HttpClientProtocol -from graphon.nodes.question_classifier import ( - QuestionClassifierNode, - QuestionClassifierNodeData, -) -from graphon.template_rendering import Jinja2TemplateRenderer -from tests.workflow_test_utils import build_test_graph_init_params - - -def test_init_question_classifier_node_data(): - data = { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - "memory": { - "role_prefix": {"user": "Human:", "assistant": "AI:"}, - "window": {"enabled": True, "size": 5}, - "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", - }, - "vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}}, - } - - node_data = QuestionClassifierNodeData.model_validate(data) - - assert node_data.query_variable_selector == ["id", "name"] - assert node_data.model.provider == "openai" - assert node_data.classes[0].id == "1" - assert node_data.instruction == "This is a test instruction" - assert node_data.memory is not None - assert node_data.memory.role_prefix is not None - assert node_data.memory.role_prefix.user == "Human:" - assert node_data.memory.role_prefix.assistant == "AI:" - assert node_data.memory.window.enabled == True - assert node_data.memory.window.size == 5 - assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" - assert node_data.vision.enabled == True - assert 
node_data.vision.configs.variable_selector == ["image"] - assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.LOW - - -def test_init_question_classifier_node_data_without_vision_config(): - data = { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - "memory": { - "role_prefix": {"user": "Human:", "assistant": "AI:"}, - "window": {"enabled": True, "size": 5}, - "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", - }, - } - - node_data = QuestionClassifierNodeData.model_validate(data) - - assert node_data.query_variable_selector == ["id", "name"] - assert node_data.model.provider == "openai" - assert node_data.classes[0].id == "1" - assert node_data.instruction == "This is a test instruction" - assert node_data.memory is not None - assert node_data.memory.role_prefix is not None - assert node_data.memory.role_prefix.user == "Human:" - assert node_data.memory.role_prefix.assistant == "AI:" - assert node_data.memory.window.enabled == True - assert node_data.memory.window.size == 5 - assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" - assert node_data.vision.enabled == False - assert node_data.vision.configs.variable_selector == ["sys", "files"] - assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH - - -def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch): - node_data = QuestionClassifierNodeData.model_validate( - { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - } - ) - template_renderer = MagicMock(spec=Jinja2TemplateRenderer) - node = QuestionClassifierNode( - id="node-id", - config={"id": "node-id", "data": node_data.model_dump(mode="json")}, - graph_init_params=build_test_graph_init_params( - workflow_id="workflow-id", - graph_config={}, - tenant_id="tenant-id", - app_id="app-id", - user_id="user-id", - ), - graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()), - credentials_provider=MagicMock(spec=CredentialsProvider), - model_factory=MagicMock(spec=ModelFactory), - model_instance=MagicMock(), - http_client=MagicMock(spec=HttpClientProtocol), - llm_file_saver=MagicMock(), - template_renderer=template_renderer, - ) - fetch_prompt_messages = MagicMock(return_value=([], None)) - monkeypatch.setattr( - "graphon.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages", - fetch_prompt_messages, - ) - monkeypatch.setattr( - "graphon.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema", - MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])), - ) - - node._calculate_rest_token( - node_data=node_data, - query="hello", - model_instance=MagicMock(stop=(), parameters={}), - context="", - ) - - assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 543f9878de1..833c3030521 100644 --- 
a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -2,16 +2,16 @@ import json import time import pytest -from pydantic import ValidationError as PydanticValidationError - -from core.workflow.system_variables import build_system_variables -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState from graphon.variables import build_segment, segment_to_variable from graphon.variables.input_entities import VariableEntity, VariableEntityType from graphon.variables.variables import Variable +from pydantic import ValidationError as PydanticValidationError + +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index c8061813401..15870148027 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -8,14 +8,14 @@ from typing import TYPE_CHECKING, Any from unittest.mock import MagicMock import pytest - -from core.workflow.system_variables import build_system_variables from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.segments import ArrayFileSegment + +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if TYPE_CHECKING: # pragma: no cover - imported for type checking only diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py index 438af211f36..c4dfc5a1792 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -6,6 +6,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType +from graphon.nodes.tool.exc import ToolRuntimeInvocationError +from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from graphon.runtime import VariablePool from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError @@ -17,11 +22,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.system_variables import build_system_variables -from graphon.model_runtime.entities.llm_entities import LLMUsage -from 
graphon.nodes.tool.entities import ToolNodeData, ToolProviderType -from graphon.nodes.tool.exc import ToolRuntimeInvocationError -from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage -from graphon.runtime import VariablePool from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index c8ddc53284b..952e798430f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -1,12 +1,13 @@ from collections.abc import Mapping -from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode -from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.runtime import GraphRuntimeState + +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py deleted file mode 100644 index fabc8df73e8..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ /dev/null @@ -1,312 +0,0 @@ -import time -import uuid -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool -from graphon.entities import GraphInitParams -from graphon.graph import Graph -from graphon.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent -from graphon.nodes.variable_assigner.common import helpers as common_helpers -from graphon.nodes.variable_assigner.v1 import VariableAssignerNode -from graphon.nodes.variable_assigner.v1.node_data import WriteMode -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import ArrayStringVariable, StringVariable - -DEFAULT_NODE_ID = "node_id" - - -def _build_variable_pool( - *, - conversation_id: str, - conversation_variables: list[StringVariable | ArrayStringVariable], -) -> VariablePool: - variable_pool = VariablePool() - add_variables_to_pool( - variable_pool, - build_bootstrap_variables( - system_variables=build_system_variables(conversation_id=conversation_id), - conversation_variables=conversation_variables, - ), - ) - return variable_pool - - -def test_overwrite_string_variable(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { 
- "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "over-write", - "input_variable_selector": ["node_id", "test_string_variable"], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = StringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value="the first value", - ) - - input_variable = StringVariable( - id=str(uuid4()), - name="test_string_variable", - value="the second value", - ) - conversation_id = str(uuid.uuid4()) - - # construct variable pool - variable_pool = _build_variable_pool( - conversation_id=conversation_id, - conversation_variables=[conversation_variable], - ) - - variable_pool.add( - [DEFAULT_NODE_ID, input_variable.name], - input_variable, - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.OVER_WRITE, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) - succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == input_variable.value - assert updated_event.variable.value == "the second value" - assert tuple(updated_event.variable.selector) == ("conversation", conversation_variable.name) - - -def test_append_variable_to_array(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "append", - "input_variable_selector": ["node_id", "test_string_variable"], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["the first value"], - ) - - input_variable = StringVariable( - id=str(uuid4()), - name="test_string_variable", - value="the second value", - ) - 
conversation_id = str(uuid.uuid4()) - - variable_pool = _build_variable_pool( - conversation_id=conversation_id, - conversation_variables=[conversation_variable], - ) - variable_pool.add( - [DEFAULT_NODE_ID, input_variable.name], - input_variable, - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.APPEND, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) - succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == ["the first value", "the second value"] - assert updated_event.variable.value == ["the first value", "the second value"] - - -def test_clear_array(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "clear", - "input_variable_selector": [], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["the first value"], - ) - - conversation_id = str(uuid.uuid4()) - variable_pool = _build_variable_pool( - conversation_id=conversation_id, - conversation_variables=[conversation_variable], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.CLEAR, - "input_variable_selector": [], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) - succeeded_event = next(event for event in events if isinstance(event, 
NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == [] - assert updated_event.variable.value == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py deleted file mode 100644 index 8b137891791..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py deleted file mode 100644 index 9ac8bbe9c2a..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py +++ /dev/null @@ -1,22 +0,0 @@ -from graphon.nodes.variable_assigner.v2.enums import Operation -from graphon.nodes.variable_assigner.v2.helpers import is_input_value_valid -from graphon.variables import SegmentType - - -def test_is_input_value_valid_overwrite_array_string(): - # Valid cases - assert is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["hello", "world"] - ) - assert is_input_value_valid(variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[]) - - # Invalid cases - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value="not an array" - ) - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[1, 2, 3] - ) - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["valid", 123, "invalid"] - ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py deleted file mode 100644 index 53346c4a90c..00000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ /dev/null @@ -1,430 +0,0 @@ -import time -import uuid -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool -from graphon.entities import GraphInitParams -from graphon.graph import Graph -from graphon.graph_events import NodeRunVariableUpdatedEvent -from graphon.nodes.variable_assigner.v2 import VariableAssignerNode -from graphon.nodes.variable_assigner.v2.enums import InputType, Operation -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import ArrayStringVariable - -DEFAULT_NODE_ID = "node_id" - - -def _build_variable_pool(*, conversation_variables: list[ArrayStringVariable]) -> VariablePool: - variable_pool = VariablePool() - add_variables_to_pool( - variable_pool, - build_bootstrap_variables( - system_variables=build_system_variables(conversation_id="conversation_id"), - conversation_variables=conversation_variables, - ), - ) - return variable_pool - - -def test_handle_item_directly(): - """Test the _handle_item method directly 
for remove operations.""" - # Create variables - variable1 = ArrayStringVariable( - id=str(uuid4()), - name="test_variable1", - value=["first", "second", "third"], - ) - - variable2 = ArrayStringVariable( - id=str(uuid4()), - name="test_variable2", - value=["first", "second", "third"], - ) - - # Create a mock class with just the _handle_item method - class MockNode: - def _handle_item(self, *, variable, operation, value): - match operation: - case Operation.REMOVE_FIRST: - if not variable.value: - return variable.value - return variable.value[1:] - case Operation.REMOVE_LAST: - if not variable.value: - return variable.value - return variable.value[:-1] - - node = MockNode() - - # Test remove-first - result1 = node._handle_item( - variable=variable1, - operation=Operation.REMOVE_FIRST, - value=None, - ) - - # Test remove-last - result2 = node._handle_item( - variable=variable2, - operation=Operation.REMOVE_LAST, - value=None, - ) - - # Check the results - assert result1 == ["second", "third"] - assert result2 == ["first", "second"] - - -def test_remove_first_from_array(): - """Test removing the first element from an array.""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["first", "second", "third"], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - # Run the node - result = list(node.run()) - - updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) - assert updated_event.variable.value == ["second", "third"] - - -def test_remove_last_from_array(): - """Test removing the last element from an array.""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - 
DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["first", "second", "third"], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - result = list(node.run()) - updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) - assert updated_event.variable.value == ["first", "second"] - - -def test_remove_first_from_empty_array(): - """Test removing the first element from an empty array (should do nothing).""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=[], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - result = list(node.run()) - updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) - assert updated_event.variable.value == [] - - -def test_remove_last_from_empty_array(): - """Test removing the last element from an empty array (should do nothing).""" - graph_config = { - "edges": [ - { - "id": 
"start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=[], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - result = list(node.run()) - updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) - assert updated_event.variable.value == [] - - -def test_node_factory_creates_variable_assigner_node(): - graph_config = { - "edges": [], - "nodes": [ - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - variable_pool = _build_variable_pool(conversation_variables=[]) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - - node = node_factory.create_node(graph_config["nodes"][0]) - - assert isinstance(node, VariableAssignerNode) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index 617554ee179..f1132af02b5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,4 +1,5 @@ import pytest +from graphon.entities.exc import BaseNodeError from core.workflow.nodes.trigger_webhook.exc import ( WebhookConfigError, @@ -6,7 +7,6 @@ from core.workflow.nodes.trigger_webhook.exc import ( WebhookNotFoundError, WebhookTimeoutError, ) -from graphon.entities.exc import BaseNodeError def test_webhook_node_error_inheritance(): diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py 
b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 6fbd26131df..cccd3fb6767 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -8,6 +8,10 @@ when passing files to downstream LLM nodes. from unittest.mock import Mock, patch +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.nodes.trigger_webhook.entities import ( ContentType, @@ -17,10 +21,6 @@ from core.workflow.nodes.trigger_webhook.entities import ( ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode from core.workflow.system_variables import default_system_variables -from graphon.entities.graph_init_params import GraphInitParams -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.runtime.graph_runtime_state import GraphRuntimeState -from graphon.runtime.variable_pool import VariablePool from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 9f954b20905..34c66a4f9f3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -1,6 +1,11 @@ from unittest.mock import patch import pytest +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import FileVariable, StringVariable from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE @@ -13,12 +18,6 @@ from core.workflow.nodes.trigger_webhook.entities import ( ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode from core.workflow.system_variables import default_system_variables -from graphon.entities.graph_init_params import GraphInitParams -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType -from graphon.runtime.graph_runtime_state import GraphRuntimeState -from graphon.runtime.variable_pool import VariablePool -from graphon.variables import FileVariable, StringVariable from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py deleted file mode 100644 index 453e0a8502e..00000000000 --- a/api/tests/unit_tests/core/workflow/test_enums.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Tests for workflow pause related enums and constants.""" - -from graphon.enums import ( - WorkflowExecutionStatus, -) - - -class TestWorkflowExecutionStatus: - """Test WorkflowExecutionStatus enum.""" - - def test_is_ended_method(self): - """Test is_ended method for different statuses.""" - # Test ended statuses - ended_statuses = [ - WorkflowExecutionStatus.SUCCEEDED, - WorkflowExecutionStatus.FAILED, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED, - WorkflowExecutionStatus.STOPPED, - ] - - for 
status in ended_statuses: - assert status.is_ended(), f"{status} should be considered ended" - - # Test non-ended statuses - non_ended_statuses = [ - WorkflowExecutionStatus.SCHEDULED, - WorkflowExecutionStatus.RUNNING, - WorkflowExecutionStatus.PAUSED, - ] - - for status in non_ended_statuses: - assert not status.is_ended(), f"{status} should not be considered ended" - - def test_ended_values(self): - """Test ended_values returns the expected status values.""" - assert set(WorkflowExecutionStatus.ended_values()) == { - WorkflowExecutionStatus.SUCCEEDED.value, - WorkflowExecutionStatus.FAILED.value, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, - WorkflowExecutionStatus.STOPPED.value, - } diff --git a/api/tests/unit_tests/core/workflow/test_human_input_compat.py b/api/tests/unit_tests/core/workflow/test_human_input_compat.py index 0623800b30c..cd41c43e4ad 100644 --- a/api/tests/unit_tests/core/workflow/test_human_input_compat.py +++ b/api/tests/unit_tests/core/workflow/test_human_input_compat.py @@ -1,5 +1,6 @@ from types import SimpleNamespace +from graphon.enums import BuiltinNodeTypes from pydantic import BaseModel from core.workflow.human_input_compat import ( @@ -15,7 +16,6 @@ from core.workflow.human_input_compat import ( normalize_node_data_for_graph, parse_human_input_delivery_methods, ) -from graphon.enums import BuiltinNodeTypes def test_email_delivery_config_helpers_render_and_sanitize_text() -> None: diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index 1db848a010d..bc0b339fec5 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -2,15 +2,15 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch, sentinel import pytest +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.code.entities import CodeLanguage +from graphon.variables.segments import StringSegment from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom from core.workflow import node_factory from core.workflow import template_rendering as workflow_template_rendering from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.code.entities import CodeLanguage -from graphon.variables.segments import StringSegment def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None: diff --git a/api/tests/unit_tests/core/workflow/test_node_runtime.py b/api/tests/unit_tests/core/workflow/test_node_runtime.py index 71a2afb28af..4f9c1dad599 100644 --- a/api/tests/unit_tests/core/workflow/test_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/test_node_runtime.py @@ -2,6 +2,10 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, sentinel import pytest +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.nodes.human_input.entities import HumanInputNodeData from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom from core.llm_generator.output_parser.errors import 
OutputParserError @@ -26,10 +30,6 @@ from core.workflow.node_runtime import ( build_dify_llm_file_saver, resolve_dify_run_context, ) -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.nodes.human_input.entities import HumanInputNodeData from tests.workflow_test_utils import build_test_run_context diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py index 72a0557b7c0..05ea3dc3117 100644 --- a/api/tests/unit_tests/core/workflow/test_system_variable.py +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -1,14 +1,14 @@ from types import SimpleNamespace +from graphon.file import File, FileTransferMethod, FileType +from graphon.nodes import BuiltinNodeTypes + from core.workflow.system_variables import ( build_system_variables, default_system_variables, get_node_creation_preload_selectors, system_variables_to_mapping, ) -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File -from graphon.nodes import BuiltinNodeTypes def test_build_system_variables_normalizes_workflow_execution_id(): diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index dddd6eb00c4..e7b2b2914a0 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -2,15 +2,6 @@ import uuid from collections import defaultdict import pytest - -from core.workflow.system_variables import build_system_variables, system_variables_to_mapping -from core.workflow.variable_pool_initializer import add_variables_to_pool -from core.workflow.variable_prefixes import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, - SYSTEM_VARIABLE_NODE_ID, -) -from factories.variable_factory import build_segment, segment_to_variable from graphon.file import File, FileTransferMethod, FileType from graphon.runtime import VariablePool from graphon.variables import FileSegment, StringSegment @@ -36,6 +27,15 @@ from graphon.variables.variables import ( Variable, ) +from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from core.workflow.variable_pool_initializer import add_variables_to_pool +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) +from factories.variable_factory import build_segment, segment_to_variable + @pytest.fixture def pool(): diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 4ae6ed1659a..d8361d06c47 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -1,6 +1,12 @@ from types import SimpleNamespace import pytest +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.file import File, FileTransferMethod, FileType +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.runtime import VariablePool +from graphon.variables.variables import StringVariable from configs import dify_config from core.helper.code_executor.code_executor import CodeLanguage @@ -10,13 +16,6 @@ from 
core.workflow.variable_prefixes import ( ENVIRONMENT_VARIABLE_NODE_ID, ) from core.workflow.workflow_entry import WorkflowEntry -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.file.enums import FileType -from graphon.file.models import File, FileTransferMethod -from graphon.nodes.code.code_node import CodeNode -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.runtime import VariablePool -from graphon.variables.variables import StringVariable @pytest.fixture(autouse=True) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index 456ab5da413..879c0bb7218 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -4,18 +4,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch, sentinel import pytest - -from core.app.apps.exc import GenerateTaskStoppedError -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.model_manager import ModelInstance -from core.workflow import workflow_entry -from core.workflow.system_variables import default_system_variables from graphon.entities.base_node_data import BaseNodeData from graphon.entities.graph_config import NodeConfigDictAdapter from graphon.enums import NodeType, WorkflowNodeExecutionStatus from graphon.errors import WorkflowNodeRunFailedError -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File +from graphon.file import File, FileTransferMethod, FileType from graphon.graph import Graph from graphon.graph_events import GraphRunFailedEvent from graphon.model_runtime.entities.llm_entities import LLMUsage @@ -24,6 +17,12 @@ from graphon.nodes import BuiltinNodeTypes from graphon.nodes.base.node import Node from graphon.runtime import ChildGraphNotFoundError, VariablePool from graphon.variables.variables import StringVariable + +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.model_manager import ModelInstance +from core.workflow import workflow_entry +from core.workflow.system_variables import default_system_variables from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py index b3ecfe4bc93..4b2f98aefff 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py @@ -2,10 +2,11 @@ from unittest.mock import MagicMock, patch +from graphon.graph_engine.command_channels import RedisChannel +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.workflow_entry import WorkflowEntry -from graphon.graph_engine.command_channels.redis_channel import RedisChannel -from graphon.runtime import GraphRuntimeState, VariablePool class TestWorkflowEntryRedisChannel: diff --git a/api/tests/unit_tests/core/workflow/utils/test_condition.py b/api/tests/unit_tests/core/workflow/utils/test_condition.py deleted file mode 100644 index f4c86aa77a0..00000000000 --- a/api/tests/unit_tests/core/workflow/utils/test_condition.py +++ /dev/null @@ -1,52 +0,0 @@ -from 
graphon.runtime import VariablePool -from graphon.utils.condition.entities import Condition -from graphon.utils.condition.processor import ConditionProcessor - - -def test_number_formatting(): - condition_processor = ConditionProcessor() - variable_pool = VariablePool() - variable_pool.add(["test_node_id", "zone"], 0) - variable_pool.add(["test_node_id", "one"], 1) - variable_pool.add(["test_node_id", "one_one"], 1.1) - # 0 <= 0.95 - assert ( - condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=[Condition(variable_selector=["test_node_id", "zone"], comparison_operator="≤", value="0.95")], - operator="or", - ).final_result - == True - ) - - # 1 >= 0.95 - assert ( - condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=[Condition(variable_selector=["test_node_id", "one"], comparison_operator="≥", value="0.95")], - operator="or", - ).final_result - == True - ) - - # 1.1 >= 0.95 - assert ( - condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=[ - Condition(variable_selector=["test_node_id", "one_one"], comparison_operator="≥", value="0.95") - ], - operator="or", - ).final_result - == True - ) - - # 1.1 > 0 - assert ( - condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=[Condition(variable_selector=["test_node_id", "one_one"], comparison_operator=">", value="0")], - operator="or", - ).final_result - == True - ) diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py deleted file mode 100644 index 009c860f160..00000000000 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py +++ /dev/null @@ -1,48 +0,0 @@ -import dataclasses - -from graphon.nodes.base import variable_template_parser -from graphon.nodes.base.entities import VariableSelector - - -def test_extract_selectors_from_template(): - template = ( - "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
- ) - selectors = variable_template_parser.extract_selectors_from_template(template) - assert selectors == [ - VariableSelector(variable="#sys.user_id#", value_selector=["sys", "user_id"]), - VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]), - VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]), - ] - - -def test_invalid_references(): - @dataclasses.dataclass - class TestCase: - name: str - template: str - - cases = [ - TestCase( - name="lack of closing brace", - template="Hello, {{#sys.user_id#", - ), - TestCase( - name="lack of opening brace", - template="Hello, #sys.user_id#}}", - ), - TestCase( - name="lack selector name", - template="Hello, {{#sys#}}", - ), - TestCase( - name="empty node name part", - template="Hello, {{#.user_id#}}", - ), - ] - for idx, c in enumerate(cases, 1): - fail_msg = f"Test case {c.name} failed, index={idx}" - selectors = variable_template_parser.extract_selectors_from_template(c.template) - assert selectors == [], fail_msg - parser = variable_template_parser.VariableTemplateParser(c.template) - assert parser.extract_variable_selectors() == [], fail_msg diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index 511192001ef..4fe3f2cb283 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -2,13 +2,13 @@ import uuid from unittest.mock import MagicMock, patch import pytest +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig from httpx import Response from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope from core.workflow.file_reference import build_file_reference, parse_file_reference, resolve_file_record_id from factories.file_factory.builders import build_from_mapping as _build_from_mapping -from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig from models import ToolFile, UploadFile # Test Data diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index 70d7d8c5750..8d573b11543 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -4,11 +4,6 @@ from typing import Any from uuid import uuid4 import pytest -from hypothesis import HealthCheck, given, settings -from hypothesis import strategies as st - -from factories import variable_factory -from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type from graphon.file import File, FileTransferMethod, FileType from graphon.variables import ( ArrayNumberVariable, @@ -36,6 +31,11 @@ from graphon.variables.segments import ( StringSegment, ) from graphon.variables.types import SegmentType +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st + +from factories import variable_factory +from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type def test_string_variable(): diff --git a/api/tests/unit_tests/fields/test_file_fields.py b/api/tests/unit_tests/fields/test_file_fields.py index 9d9f626b9e4..0e848d6ef5b 100644 --- a/api/tests/unit_tests/fields/test_file_fields.py +++ b/api/tests/unit_tests/fields/test_file_fields.py @@ -4,11 +4,11 @@ from 
datetime import datetime from types import SimpleNamespace import pytest +from graphon.file import File, FileTransferMethod, FileType from core.workflow.file_reference import build_file_reference from fields import conversation_fields, message_fields from fields.file_fields import FileResponse, FileWithSignedUrl, RemoteFileInfo, UploadConfig -from graphon.file import File, FileTransferMethod, FileType def test_file_response_serializes_datetime() -> None: diff --git a/api/tests/unit_tests/graphon/file/test_file_factory.py b/api/tests/unit_tests/graphon/file/test_file_factory.py deleted file mode 100644 index eeb537c28f5..00000000000 --- a/api/tests/unit_tests/graphon/file/test_file_factory.py +++ /dev/null @@ -1,18 +0,0 @@ -from graphon.file import FileType -from graphon.file.file_factory import get_file_type_by_mime_type, standardize_file_type - - -def test_standardize_file_type_recognizes_case_insensitive_extension(): - assert standardize_file_type(extension=".PNG") == FileType.IMAGE - - -def test_standardize_file_type_recognizes_document_extension(): - assert standardize_file_type(extension=".txt") == FileType.DOCUMENT - - -def test_standardize_file_type_falls_back_to_mime_type(): - assert standardize_file_type(mime_type="video/mp4") == FileType.VIDEO - - -def test_get_file_type_by_mime_type_returns_custom_for_unknown_type(): - assert get_file_type_by_mime_type("application/octet-stream") == FileType.CUSTOM diff --git a/api/tests/unit_tests/graphon/file/test_file_manager.py b/api/tests/unit_tests/graphon/file/test_file_manager.py deleted file mode 100644 index 1eebb13f4e8..00000000000 --- a/api/tests/unit_tests/graphon/file/test_file_manager.py +++ /dev/null @@ -1,133 +0,0 @@ -import base64 -from unittest.mock import MagicMock - -import pytest - -from core.workflow.file_reference import build_file_reference -from graphon.file import File, FileTransferMethod, FileType -from graphon.file.file_manager import download, to_prompt_message_content -from graphon.file.runtime import get_workflow_file_runtime, set_workflow_file_runtime -from graphon.model_runtime.entities import ( - DocumentPromptMessageContent, - ImagePromptMessageContent, - TextPromptMessageContent, -) - - -def _build_file( - *, - transfer_method: FileTransferMethod, - file_type: FileType = FileType.IMAGE, - reference: str | None = None, - remote_url: str | None = None, - filename: str = "image.png", - extension: str = ".png", - mime_type: str = "image/png", -) -> File: - return File( - id="file-id", - type=file_type, - transfer_method=transfer_method, - reference=reference, - remote_url=remote_url, - filename=filename, - extension=extension, - mime_type=mime_type, - size=128, - ) - - -@pytest.fixture -def workflow_file_runtime(): - previous_runtime = get_workflow_file_runtime() - runtime = MagicMock() - set_workflow_file_runtime(runtime) - try: - yield runtime - finally: - set_workflow_file_runtime(previous_runtime) - - -@pytest.mark.parametrize( - "transfer_method", - [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.TOOL_FILE, - FileTransferMethod.DATASOURCE_FILE, - ], -) -def test_download_delegates_storage_backed_files_to_runtime_loader(workflow_file_runtime, transfer_method) -> None: - workflow_file_runtime.load_file_bytes.return_value = b"payload" - file = _build_file( - transfer_method=transfer_method, - reference=build_file_reference(record_id="file-id", storage_key="files/payload.bin"), - ) - - assert download(file) == b"payload" - workflow_file_runtime.load_file_bytes.assert_called_once_with(file=file) - - -def 
test_download_remote_url_uses_runtime_http_get(workflow_file_runtime) -> None: - response = MagicMock() - response.content = b"remote-payload" - workflow_file_runtime.http_get.return_value = response - file = _build_file( - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url="https://example.com/image.png", - ) - - assert download(file) == b"remote-payload" - workflow_file_runtime.http_get.assert_called_once_with("https://example.com/image.png", follow_redirects=True) - response.raise_for_status.assert_called_once_with() - - -def test_to_prompt_message_content_uses_runtime_url_resolution_for_images(workflow_file_runtime) -> None: - workflow_file_runtime.multimodal_send_format = "url" - workflow_file_runtime.resolve_file_url.return_value = "https://cdn.example.com/image.png" - file = _build_file( - transfer_method=FileTransferMethod.LOCAL_FILE, - reference=build_file_reference(record_id="upload-file-id", storage_key="files/image.png"), - ) - - content = to_prompt_message_content(file, image_detail_config=ImagePromptMessageContent.DETAIL.HIGH) - - assert isinstance(content, ImagePromptMessageContent) - assert content.url == "https://cdn.example.com/image.png" - assert content.base64_data == "" - assert content.detail == ImagePromptMessageContent.DETAIL.HIGH - - -def test_to_prompt_message_content_uses_runtime_file_loader_for_base64_documents(workflow_file_runtime) -> None: - workflow_file_runtime.multimodal_send_format = "base64" - workflow_file_runtime.load_file_bytes.return_value = b"document-bytes" - file = _build_file( - transfer_method=FileTransferMethod.TOOL_FILE, - file_type=FileType.DOCUMENT, - reference=build_file_reference(record_id="tool-file-id", storage_key="docs/report.pdf"), - filename="report.pdf", - extension=".pdf", - mime_type="application/pdf", - ) - - content = to_prompt_message_content(file) - - assert isinstance(content, DocumentPromptMessageContent) - assert content.base64_data == base64.b64encode(b"document-bytes").decode("utf-8") - assert content.url == "" - workflow_file_runtime.load_file_bytes.assert_called_once_with(file=file) - - -def test_to_prompt_message_content_returns_text_placeholder_for_custom_files() -> None: - file = _build_file( - transfer_method=FileTransferMethod.REMOTE_URL, - file_type=FileType.CUSTOM, - remote_url="https://example.com/archive.bin", - filename="archive.bin", - extension=".bin", - mime_type="application/octet-stream", - ) - - content = to_prompt_message_content(file) - - assert isinstance(content, TextPromptMessageContent) - assert content.data == "[Unsupported file type: archive.bin (custom)]" diff --git a/api/tests/unit_tests/graphon/file/test_models.py b/api/tests/unit_tests/graphon/file/test_models.py deleted file mode 100644 index 17d244da5f3..00000000000 --- a/api/tests/unit_tests/graphon/file/test_models.py +++ /dev/null @@ -1,54 +0,0 @@ -from core.workflow.file_reference import build_file_reference -from graphon.file import File, FileTransferMethod, FileType, helpers - - -def _build_local_file(*, reference: str, storage_key: str | None = None) -> File: - return File( - id="file-id", - type=FileType.DOCUMENT, - transfer_method=FileTransferMethod.LOCAL_FILE, - reference=reference, - filename="report.pdf", - extension=".pdf", - mime_type="application/pdf", - size=128, - storage_key=storage_key, - ) - - -def test_file_exposes_legacy_aliases_from_opaque_reference() -> None: - reference = build_file_reference(record_id="upload-file-id", storage_key="files/report.pdf") - - file = _build_local_file(reference=reference) - - 
assert file.reference == reference - assert file.related_id == "upload-file-id" - assert file.storage_key == "files/report.pdf" - - -def test_file_falls_back_to_raw_reference_when_opaque_reference_is_invalid() -> None: - file = _build_local_file(reference="dify-file-ref:not-base64", storage_key="fallback-key") - - assert file.related_id == "dify-file-ref:not-base64" - assert file.storage_key == "fallback-key" - - -def test_file_to_dict_keeps_reference_and_legacy_related_id(monkeypatch) -> None: - reference = build_file_reference(record_id="upload-file-id", storage_key="files/report.pdf") - file = _build_local_file(reference=reference) - monkeypatch.setattr(helpers, "resolve_file_url", lambda _file, for_external=True: "https://example.com/report.pdf") - - serialized = file.to_dict() - - assert serialized["reference"] == reference - assert serialized["related_id"] == "upload-file-id" - assert serialized["url"] == "https://example.com/report.pdf" - - -def test_file_related_id_setter_updates_reference_alias() -> None: - file = _build_local_file(reference="upload-file-id", storage_key="files/report.pdf") - - file.related_id = "replacement-upload-id" - - assert file.reference == "replacement-upload-id" - assert file.related_id == "replacement-upload-id" diff --git a/api/tests/unit_tests/graphon/model_runtime/__base/__init__.py b/api/tests/unit_tests/graphon/model_runtime/__base/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/unit_tests/graphon/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/graphon/model_runtime/__base/test_increase_tool_call.py deleted file mode 100644 index 7b4fc5a04ce..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/__base/test_increase_tool_call.py +++ /dev/null @@ -1,114 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage -from graphon.model_runtime.model_providers.__base.large_language_model import _increase_tool_call - -ToolCall = AssistantPromptMessage.ToolCall - -# CASE 1: Single tool call -INPUTS_CASE_1 = [ - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), -] -EXPECTED_CASE_1 = [ - ToolCall( - id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') - ), -] - -# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...) 
-INPUTS_CASE_2 = [ - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), - ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), -] -EXPECTED_CASE_2 = [ - ToolCall( - id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') - ), - ToolCall( - id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}') - ), -] - -# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...) -INPUTS_CASE_3 = [ - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), - ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), - ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), - ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), -] -EXPECTED_CASE_3 = [ - ToolCall( - id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') - ), - ToolCall( - id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}') - ), -] - -# CASE 4: Tool call sequences with no IDs -INPUTS_CASE_4 = [ - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), -] -EXPECTED_CASE_4 = [ - ToolCall( - id="RANDOM_ID_1", - type="function", - function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'), - ), - ToolCall( - id="RANDOM_ID_2", - type="function", - function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'), - ), -] - - -def _run_case(inputs: list[ToolCall], expected: list[ToolCall]): - actual = [] - _increase_tool_call(inputs, actual) - assert actual == expected - - -def test__increase_tool_call(): - # case 1: - _run_case(INPUTS_CASE_1, EXPECTED_CASE_1) - - # case 2: - _run_case(INPUTS_CASE_2, EXPECTED_CASE_2) - - # case 3: - _run_case(INPUTS_CASE_3, EXPECTED_CASE_3) - - # case 4: - mock_id_generator = MagicMock() - mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4] - with patch( - "graphon.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator - ): - 
_run_case(INPUTS_CASE_4, EXPECTED_CASE_4) - - -def test__increase_tool_call__no_id_no_name_first_delta_should_raise(): - inputs = [ - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='"value"}')), - ] - actual: list[ToolCall] = [] - with patch("graphon.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", MagicMock()): - with pytest.raises(ValueError): - _increase_tool_call(inputs, actual) diff --git a/api/tests/unit_tests/graphon/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/graphon/model_runtime/__base/test_large_language_model_non_stream_parsing.py deleted file mode 100644 index c922fbaa605..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/__base/test_large_language_model_non_stream_parsing.py +++ /dev/null @@ -1,126 +0,0 @@ -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_runtime_result - - -def _make_chunk( - *, - model: str = "test-model", - content: str | list[TextPromptMessageContent] | None, - tool_calls: list[AssistantPromptMessage.ToolCall] | None = None, - usage: LLMUsage | None = None, - system_fingerprint: str | None = None, -) -> LLMResultChunk: - message = AssistantPromptMessage(content=content, tool_calls=tool_calls or []) - delta = LLMResultChunkDelta(index=0, message=message, usage=usage) - return LLMResultChunk(model=model, delta=delta, system_fingerprint=system_fingerprint) - - -def test__normalize_non_stream_runtime_result__from_first_chunk_str_content_and_tool_calls(): - prompt_messages = [UserPromptMessage(content="hi")] - - tool_calls = [ - AssistantPromptMessage.ToolCall( - id="1", - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments=""), - ), - AssistantPromptMessage.ToolCall( - id="", - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='{"arg1": '), - ), - AssistantPromptMessage.ToolCall( - id="", - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='"value"}'), - ), - ] - - usage = LLMUsage.empty_usage().model_copy(update={"prompt_tokens": 1, "total_tokens": 1}) - chunk = _make_chunk(content="hello", tool_calls=tool_calls, usage=usage, system_fingerprint="fp-1") - - result = _normalize_non_stream_runtime_result( - model="test-model", prompt_messages=prompt_messages, result=iter([chunk]) - ) - - assert result.model == "test-model" - assert result.prompt_messages == prompt_messages - assert result.message.content == "hello" - assert result.usage.prompt_tokens == 1 - assert result.system_fingerprint == "fp-1" - assert result.message.tool_calls == [ - AssistantPromptMessage.ToolCall( - id="1", - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'), - ) - ] - - -def test__normalize_non_stream_runtime_result__from_first_chunk_list_content(): - prompt_messages = [UserPromptMessage(content="hi")] - - content_list = [TextPromptMessageContent(data="a"), TextPromptMessageContent(data="b")] - chunk = 
_make_chunk(content=content_list, usage=LLMUsage.empty_usage()) - - result = _normalize_non_stream_runtime_result( - model="test-model", prompt_messages=prompt_messages, result=iter([chunk]) - ) - - assert result.message.content == content_list - - -def test__normalize_non_stream_runtime_result__passthrough_llm_result(): - prompt_messages = [UserPromptMessage(content="hi")] - llm_result = LLMResult( - model="test-model", - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content="ok"), - usage=LLMUsage.empty_usage(), - ) - - assert ( - _normalize_non_stream_runtime_result(model="test-model", prompt_messages=prompt_messages, result=llm_result) - == llm_result - ) - - -def test__normalize_non_stream_runtime_result__empty_iterator_defaults(): - prompt_messages = [UserPromptMessage(content="hi")] - - result = _normalize_non_stream_runtime_result(model="test-model", prompt_messages=prompt_messages, result=iter([])) - - assert result.model == "test-model" - assert result.prompt_messages == prompt_messages - assert result.message.content == [] - assert result.message.tool_calls == [] - assert result.usage == LLMUsage.empty_usage() - assert result.system_fingerprint is None - - -def test__normalize_non_stream_runtime_result__accumulates_all_chunks(): - """All chunks are accumulated from the iterator.""" - prompt_messages = [UserPromptMessage(content="hi")] - - closed: list[bool] = [] - - def _chunk_iter(): - try: - yield _make_chunk(content="hello", usage=LLMUsage.empty_usage()) - yield _make_chunk(content=" world", usage=LLMUsage.empty_usage()) - finally: - closed.append(True) - - result = _normalize_non_stream_runtime_result( - model="test-model", - prompt_messages=prompt_messages, - result=_chunk_iter(), - ) - - assert result.message.content == "hello world" - assert closed == [True] diff --git a/api/tests/unit_tests/graphon/model_runtime/__init__.py b/api/tests/unit_tests/graphon/model_runtime/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/unit_tests/graphon/model_runtime/callbacks/test_base_callback.py b/api/tests/unit_tests/graphon/model_runtime/callbacks/test_base_callback.py deleted file mode 100644 index 776fc230cbb..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/callbacks/test_base_callback.py +++ /dev/null @@ -1,964 +0,0 @@ -"""Comprehensive unit tests for core/model_runtime/callbacks/base_callback.py""" - -from unittest.mock import MagicMock, patch - -import pytest - -from graphon.model_runtime.callbacks.base_callback import ( - _TEXT_COLOR_MAPPING, - Callback, -) -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool - -# --------------------------------------------------------------------------- -# Concrete implementation of the abstract Callback for testing -# --------------------------------------------------------------------------- - - -class ConcreteCallback(Callback): - """A minimal concrete subclass that satisfies all abstract methods.""" - - def __init__(self, raise_error: bool = False): - self.raise_error = raise_error - # Track invocations - self.before_invoke_calls: list[dict] = [] - self.new_chunk_calls: list[dict] = [] - self.after_invoke_calls: list[dict] = [] - self.invoke_error_calls: list[dict] = [] - - def on_before_invoke( - self, - llm_instance, - model, - credentials, - prompt_messages, - model_parameters, - tools=None, - stop=None, - stream=True, - user=None, - ): - 
self.before_invoke_calls.append( - { - "llm_instance": llm_instance, - "model": model, - "credentials": credentials, - "prompt_messages": prompt_messages, - "model_parameters": model_parameters, - "tools": tools, - "stop": stop, - "stream": stream, - "user": user, - } - ) - # To cover the 'raise NotImplementedError()' in the base class - try: - super().on_before_invoke( - llm_instance, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user - ) - except NotImplementedError: - pass - - def on_new_chunk( - self, - llm_instance, - chunk, - model, - credentials, - prompt_messages, - model_parameters, - tools=None, - stop=None, - stream=True, - user=None, - ): - self.new_chunk_calls.append( - { - "llm_instance": llm_instance, - "chunk": chunk, - "model": model, - "credentials": credentials, - "prompt_messages": prompt_messages, - "model_parameters": model_parameters, - "tools": tools, - "stop": stop, - "stream": stream, - "user": user, - } - ) - try: - super().on_new_chunk( - llm_instance, chunk, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user - ) - except NotImplementedError: - pass - - def on_after_invoke( - self, - llm_instance, - result, - model, - credentials, - prompt_messages, - model_parameters, - tools=None, - stop=None, - stream=True, - user=None, - ): - self.after_invoke_calls.append( - { - "llm_instance": llm_instance, - "result": result, - "model": model, - "credentials": credentials, - "prompt_messages": prompt_messages, - "model_parameters": model_parameters, - "tools": tools, - "stop": stop, - "stream": stream, - "user": user, - } - ) - try: - super().on_after_invoke( - llm_instance, result, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user - ) - except NotImplementedError: - pass - - def on_invoke_error( - self, - llm_instance, - ex, - model, - credentials, - prompt_messages, - model_parameters, - tools=None, - stop=None, - stream=True, - user=None, - ): - self.invoke_error_calls.append( - { - "llm_instance": llm_instance, - "ex": ex, - "model": model, - "credentials": credentials, - "prompt_messages": prompt_messages, - "model_parameters": model_parameters, - "tools": tools, - "stop": stop, - "stream": stream, - "user": user, - } - ) - try: - super().on_invoke_error( - llm_instance, ex, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user - ) - except NotImplementedError: - pass - - -# --------------------------------------------------------------------------- -# A subclass that deliberately leaves abstract methods un-implemented, -# used to verify that instantiation raises TypeError. 
-# --------------------------------------------------------------------------- - - -# =========================================================================== -# Tests for _TEXT_COLOR_MAPPING module-level constant -# =========================================================================== - - -class TestTextColorMapping: - """Tests for the module-level _TEXT_COLOR_MAPPING dictionary.""" - - def test_contains_all_expected_colors(self): - expected_keys = {"blue", "yellow", "pink", "green", "red"} - assert set(_TEXT_COLOR_MAPPING.keys()) == expected_keys - - def test_blue_escape_code(self): - assert _TEXT_COLOR_MAPPING["blue"] == "36;1" - - def test_yellow_escape_code(self): - assert _TEXT_COLOR_MAPPING["yellow"] == "33;1" - - def test_pink_escape_code(self): - assert _TEXT_COLOR_MAPPING["pink"] == "38;5;200" - - def test_green_escape_code(self): - assert _TEXT_COLOR_MAPPING["green"] == "32;1" - - def test_red_escape_code(self): - assert _TEXT_COLOR_MAPPING["red"] == "31;1" - - def test_mapping_is_dict(self): - assert isinstance(_TEXT_COLOR_MAPPING, dict) - - def test_all_values_are_strings(self): - for key, value in _TEXT_COLOR_MAPPING.items(): - assert isinstance(value, str), f"Value for {key!r} should be str" - - -# =========================================================================== -# Tests for the Callback ABC itself -# =========================================================================== - - -class TestCallbackAbstract: - """Tests verifying Callback is a proper ABC.""" - - def test_cannot_instantiate_abstract_class_directly(self): - """Callback cannot be instantiated since it has abstract methods.""" - with pytest.raises(TypeError): - Callback() # type: ignore[abstract] - - def test_concrete_subclass_can_be_instantiated(self): - cb = ConcreteCallback() - assert isinstance(cb, Callback) - - def test_default_raise_error_is_false(self): - cb = ConcreteCallback() - assert cb.raise_error is False - - def test_raise_error_can_be_set_to_true(self): - cb = ConcreteCallback(raise_error=True) - assert cb.raise_error is True - - def test_subclass_missing_on_before_invoke_raises_type_error(self): - """A subclass missing any single abstract method cannot be instantiated.""" - - class IncompleteCallback(Callback): - def on_new_chunk(self, *a, **kw): ... - def on_after_invoke(self, *a, **kw): ... - def on_invoke_error(self, *a, **kw): ... - - with pytest.raises(TypeError): - IncompleteCallback() # type: ignore[abstract] - - def test_subclass_missing_on_new_chunk_raises_type_error(self): - class IncompleteCallback(Callback): - def on_before_invoke(self, *a, **kw): ... - def on_after_invoke(self, *a, **kw): ... - def on_invoke_error(self, *a, **kw): ... - - with pytest.raises(TypeError): - IncompleteCallback() # type: ignore[abstract] - - def test_subclass_missing_on_after_invoke_raises_type_error(self): - class IncompleteCallback(Callback): - def on_before_invoke(self, *a, **kw): ... - def on_new_chunk(self, *a, **kw): ... - def on_invoke_error(self, *a, **kw): ... - - with pytest.raises(TypeError): - IncompleteCallback() # type: ignore[abstract] - - def test_subclass_missing_on_invoke_error_raises_type_error(self): - class IncompleteCallback(Callback): - def on_before_invoke(self, *a, **kw): ... - def on_new_chunk(self, *a, **kw): ... - def on_after_invoke(self, *a, **kw): ... 
- - with pytest.raises(TypeError): - IncompleteCallback() # type: ignore[abstract] - - -# =========================================================================== -# Tests for on_before_invoke -# =========================================================================== - - -class TestOnBeforeInvoke: - """Tests for the on_before_invoke callback method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - self.llm_instance = MagicMock() - self.model = "gpt-4" - self.credentials = {"api_key": "sk-test"} - self.prompt_messages = [MagicMock(spec=PromptMessage)] - self.model_parameters = {"temperature": 0.7} - - def test_on_before_invoke_called_with_required_args(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.before_invoke_calls) == 1 - call = self.cb.before_invoke_calls[0] - assert call["llm_instance"] is self.llm_instance - assert call["model"] == self.model - assert call["credentials"] == self.credentials - assert call["prompt_messages"] is self.prompt_messages - assert call["model_parameters"] is self.model_parameters - - def test_on_before_invoke_defaults_tools_none(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.before_invoke_calls[0]["tools"] is None - - def test_on_before_invoke_defaults_stop_none(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.before_invoke_calls[0]["stop"] is None - - def test_on_before_invoke_defaults_stream_true(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.before_invoke_calls[0]["stream"] is True - - def test_on_before_invoke_defaults_user_none(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.before_invoke_calls[0]["user"] is None - - def test_on_before_invoke_with_all_optional_args(self): - tools = [MagicMock(spec=PromptMessageTool)] - stop = ["stop1", "stop2"] - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - tools=tools, - stop=stop, - stream=False, - user="user-123", - ) - call = self.cb.before_invoke_calls[0] - assert call["tools"] is tools - assert call["stop"] == stop - assert call["stream"] is False - assert call["user"] == "user-123" - - def test_on_before_invoke_called_multiple_times(self): - for i in range(3): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=f"model-{i}", - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.before_invoke_calls) == 3 - assert self.cb.before_invoke_calls[2]["model"] == "model-2" - - -# =========================================================================== -# Tests for on_new_chunk -# 
=========================================================================== - - -class TestOnNewChunk: - """Tests for the on_new_chunk callback method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - self.llm_instance = MagicMock() - self.chunk = MagicMock(spec=LLMResultChunk) - self.model = "gpt-3.5-turbo" - self.credentials = {"api_key": "sk-test"} - self.prompt_messages = [MagicMock(spec=PromptMessage)] - self.model_parameters = {"max_tokens": 256} - - def test_on_new_chunk_called_with_required_args(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.new_chunk_calls) == 1 - call = self.cb.new_chunk_calls[0] - assert call["llm_instance"] is self.llm_instance - assert call["chunk"] is self.chunk - assert call["model"] == self.model - assert call["credentials"] == self.credentials - - def test_on_new_chunk_defaults_tools_none(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.new_chunk_calls[0]["tools"] is None - - def test_on_new_chunk_defaults_stop_none(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.new_chunk_calls[0]["stop"] is None - - def test_on_new_chunk_defaults_stream_true(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.new_chunk_calls[0]["stream"] is True - - def test_on_new_chunk_defaults_user_none(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.new_chunk_calls[0]["user"] is None - - def test_on_new_chunk_with_all_optional_args(self): - tools = [MagicMock(spec=PromptMessageTool)] - stop = ["END"] - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - tools=tools, - stop=stop, - stream=False, - user="chunk-user", - ) - call = self.cb.new_chunk_calls[0] - assert call["tools"] is tools - assert call["stop"] == stop - assert call["stream"] is False - assert call["user"] == "chunk-user" - - def test_on_new_chunk_called_multiple_times(self): - for i in range(5): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.new_chunk_calls) == 5 - - -# =========================================================================== -# Tests for on_after_invoke -# =========================================================================== - - -class TestOnAfterInvoke: - """Tests for the on_after_invoke callback method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - self.llm_instance = 
MagicMock() - self.result = MagicMock(spec=LLMResult) - self.model = "claude-3" - self.credentials = {"api_key": "anthropic-key"} - self.prompt_messages = [MagicMock(spec=PromptMessage)] - self.model_parameters = {"temperature": 1.0} - - def test_on_after_invoke_called_with_required_args(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.after_invoke_calls) == 1 - call = self.cb.after_invoke_calls[0] - assert call["llm_instance"] is self.llm_instance - assert call["result"] is self.result - assert call["model"] == self.model - assert call["credentials"] is self.credentials - - def test_on_after_invoke_defaults_tools_none(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.after_invoke_calls[0]["tools"] is None - - def test_on_after_invoke_defaults_stop_none(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.after_invoke_calls[0]["stop"] is None - - def test_on_after_invoke_defaults_stream_true(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.after_invoke_calls[0]["stream"] is True - - def test_on_after_invoke_defaults_user_none(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.after_invoke_calls[0]["user"] is None - - def test_on_after_invoke_with_all_optional_args(self): - tools = [MagicMock(spec=PromptMessageTool)] - stop = ["STOP"] - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - tools=tools, - stop=stop, - stream=False, - user="after-user", - ) - call = self.cb.after_invoke_calls[0] - assert call["tools"] is tools - assert call["stop"] == stop - assert call["stream"] is False - assert call["user"] == "after-user" - - -# =========================================================================== -# Tests for on_invoke_error -# =========================================================================== - - -class TestOnInvokeError: - """Tests for the on_invoke_error callback method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - self.llm_instance = MagicMock() - self.ex = ValueError("something went wrong") - self.model = "gemini-pro" - self.credentials = {"api_key": "google-key"} - self.prompt_messages = [MagicMock(spec=PromptMessage)] - self.model_parameters = {"top_p": 0.9} - - def test_on_invoke_error_called_with_required_args(self): - self.cb.on_invoke_error( - llm_instance=self.llm_instance, - ex=self.ex, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - 
assert len(self.cb.invoke_error_calls) == 1
-        call = self.cb.invoke_error_calls[0]
-        assert call["llm_instance"] is self.llm_instance
-        assert call["ex"] is self.ex
-        assert call["model"] == self.model
-        assert call["credentials"] is self.credentials
-
-    def test_on_invoke_error_defaults_tools_none(self):
-        self.cb.on_invoke_error(
-            llm_instance=self.llm_instance,
-            ex=self.ex,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=self.prompt_messages,
-            model_parameters=self.model_parameters,
-        )
-        assert self.cb.invoke_error_calls[0]["tools"] is None
-
-    def test_on_invoke_error_defaults_stop_none(self):
-        self.cb.on_invoke_error(
-            llm_instance=self.llm_instance,
-            ex=self.ex,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=self.prompt_messages,
-            model_parameters=self.model_parameters,
-        )
-        assert self.cb.invoke_error_calls[0]["stop"] is None
-
-    def test_on_invoke_error_defaults_stream_true(self):
-        self.cb.on_invoke_error(
-            llm_instance=self.llm_instance,
-            ex=self.ex,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=self.prompt_messages,
-            model_parameters=self.model_parameters,
-        )
-        assert self.cb.invoke_error_calls[0]["stream"] is True
-
-    def test_on_invoke_error_defaults_user_none(self):
-        self.cb.on_invoke_error(
-            llm_instance=self.llm_instance,
-            ex=self.ex,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=self.prompt_messages,
-            model_parameters=self.model_parameters,
-        )
-        assert self.cb.invoke_error_calls[0]["user"] is None
-
-    def test_on_invoke_error_with_all_optional_args(self):
-        tools = [MagicMock(spec=PromptMessageTool)]
-        stop = ["HALT"]
-        self.cb.on_invoke_error(
-            llm_instance=self.llm_instance,
-            ex=self.ex,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=self.prompt_messages,
-            model_parameters=self.model_parameters,
-            tools=tools,
-            stop=stop,
-            stream=False,
-            user="error-user",
-        )
-        call = self.cb.invoke_error_calls[0]
-        assert call["tools"] is tools
-        assert call["stop"] == stop
-        assert call["stream"] is False
-        assert call["user"] == "error-user"
-
-    def test_on_invoke_error_accepts_various_exception_types(self):
-        for exc in [RuntimeError("r"), KeyError("k"), Exception("e")]:
-            self.cb.on_invoke_error(
-                llm_instance=self.llm_instance,
-                ex=exc,
-                model=self.model,
-                credentials=self.credentials,
-                prompt_messages=self.prompt_messages,
-                model_parameters=self.model_parameters,
-            )
-        assert len(self.cb.invoke_error_calls) == 3
-
-
-# ===========================================================================
-# Tests for print_text (concrete method on Callback)
-# ===========================================================================
-
-
-class TestPrintText:
-    """Tests for the concrete print_text method."""
-
-    def setup_method(self):
-        self.cb = ConcreteCallback()
-
-    def test_print_text_without_color_prints_plain_text(self, capsys):
-        self.cb.print_text("hello world")
-        captured = capsys.readouterr()
-        assert captured.out == "hello world"
-
-    def test_print_text_with_color_prints_colored_text(self, capsys):
-        self.cb.print_text("colored text", color="blue")
-        captured = capsys.readouterr()
-        # Should contain ANSI escape sequences
-        assert "colored text" in captured.out
-        assert "\u001b[" in captured.out or "\033[" in captured.out or "\x1b[" in captured.out
-
-    def test_print_text_without_color_no_ansi(self, capsys):
-        self.cb.print_text("plain text", color=None)
-        captured = capsys.readouterr()
-        assert captured.out == "plain text"
-        # No ANSI
escape sequences - assert "\x1b" not in captured.out - - def test_print_text_default_end_is_empty_string(self, capsys): - self.cb.print_text("no newline") - captured = capsys.readouterr() - assert not captured.out.endswith("\n") - - def test_print_text_with_custom_end(self, capsys): - self.cb.print_text("with newline", end="\n") - captured = capsys.readouterr() - assert captured.out.endswith("\n") - - def test_print_text_with_empty_string(self, capsys): - self.cb.print_text("", color=None) - captured = capsys.readouterr() - assert captured.out == "" - - @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) - def test_print_text_all_colors_work(self, color, capsys): - """Verify no KeyError is thrown for any valid color.""" - self.cb.print_text("test", color=color) - captured = capsys.readouterr() - assert "test" in captured.out - - def test_print_text_calls_get_colored_text_when_color_given(self): - with patch.object(self.cb, "_get_colored_text", return_value="[COLORED]") as mock_gct: - with patch("builtins.print") as mock_print: - self.cb.print_text("hello", color="green") - mock_gct.assert_called_once_with("hello", "green") - mock_print.assert_called_once_with("[COLORED]", end="") - - def test_print_text_does_not_call_get_colored_text_when_no_color(self): - with patch.object(self.cb, "_get_colored_text") as mock_gct: - with patch("builtins.print"): - self.cb.print_text("hello", color=None) - mock_gct.assert_not_called() - - def test_print_text_passes_end_to_print(self): - with patch("builtins.print") as mock_print: - self.cb.print_text("text", end="---") - mock_print.assert_called_once_with("text", end="---") - - -# =========================================================================== -# Tests for _get_colored_text (private helper method) -# =========================================================================== - - -class TestGetColoredText: - """Tests for the _get_colored_text private method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - - @pytest.mark.parametrize(("color", "expected_code"), list(_TEXT_COLOR_MAPPING.items())) - def test_get_colored_text_uses_correct_escape_code(self, color, expected_code): - result = self.cb._get_colored_text("text", color) - assert expected_code in result - - @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) - def test_get_colored_text_contains_input_text(self, color): - result = self.cb._get_colored_text("hello", color) - assert "hello" in result - - @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) - def test_get_colored_text_starts_with_escape(self, color): - result = self.cb._get_colored_text("text", color) - # Should start with an ANSI escape (\x1b or \u001b) - assert result.startswith("\x1b[") or result.startswith("\u001b[") - - @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) - def test_get_colored_text_ends_with_reset(self, color): - result = self.cb._get_colored_text("text", color) - # Should end with the ANSI reset code - assert result.endswith("\x1b[0m") or result.endswith("\u001b[0m") - - def test_get_colored_text_returns_string(self): - result = self.cb._get_colored_text("text", "blue") - assert isinstance(result, str) - - def test_get_colored_text_blue_exact_format(self): - result = self.cb._get_colored_text("hello", "blue") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['blue']}m\033[1;3mhello\u001b[0m" - assert result == expected - - def test_get_colored_text_red_exact_format(self): - result = 
self.cb._get_colored_text("error", "red") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['red']}m\033[1;3merror\u001b[0m" - assert result == expected - - def test_get_colored_text_green_exact_format(self): - result = self.cb._get_colored_text("ok", "green") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['green']}m\033[1;3mok\u001b[0m" - assert result == expected - - def test_get_colored_text_yellow_exact_format(self): - result = self.cb._get_colored_text("warn", "yellow") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['yellow']}m\033[1;3mwarn\u001b[0m" - assert result == expected - - def test_get_colored_text_pink_exact_format(self): - result = self.cb._get_colored_text("info", "pink") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['pink']}m\033[1;3minfo\u001b[0m" - assert result == expected - - def test_get_colored_text_empty_string(self): - result = self.cb._get_colored_text("", "blue") - assert isinstance(result, str) - # Empty text should still have escape codes - assert _TEXT_COLOR_MAPPING["blue"] in result - - def test_get_colored_text_invalid_color_raises_key_error(self): - with pytest.raises(KeyError): - self.cb._get_colored_text("text", "purple") - - def test_get_colored_text_with_special_characters(self): - special = "hello\nworld\ttab" - result = self.cb._get_colored_text(special, "blue") - assert special in result - - def test_get_colored_text_with_long_text(self): - long_text = "a" * 10000 - result = self.cb._get_colored_text(long_text, "green") - assert long_text in result - - -# =========================================================================== -# Integration-style tests: full workflow through a ConcreteCallback -# =========================================================================== - - -class TestConcreteCallbackIntegration: - """End-to-end workflow tests using ConcreteCallback.""" - - def test_full_invocation_lifecycle(self): - """Simulate a complete LLM invocation lifecycle through all callbacks.""" - cb = ConcreteCallback() - llm_instance = MagicMock() - model = "gpt-4o" - credentials = {"api_key": "sk-xyz"} - prompt_messages = [MagicMock(spec=PromptMessage)] - model_parameters = {"temperature": 0.5} - tools = [MagicMock(spec=PromptMessageTool)] - stop = [""] - user = "user-abc" - - # 1. Before invoke - cb.on_before_invoke( - llm_instance=llm_instance, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=True, - user=user, - ) - - # 2. Multiple chunks during streaming - for i in range(3): - chunk = MagicMock(spec=LLMResultChunk) - cb.on_new_chunk( - llm_instance=llm_instance, - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=True, - user=user, - ) - - # 3. 
After invoke - result = MagicMock(spec=LLMResult) - cb.on_after_invoke( - llm_instance=llm_instance, - result=result, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=True, - user=user, - ) - - assert len(cb.before_invoke_calls) == 1 - assert len(cb.new_chunk_calls) == 3 - assert len(cb.after_invoke_calls) == 1 - assert len(cb.invoke_error_calls) == 0 - - def test_error_lifecycle(self): - """Simulate an invoke that results in an error.""" - cb = ConcreteCallback() - llm_instance = MagicMock() - model = "gpt-4" - credentials = {} - prompt_messages = [] - model_parameters = {} - - cb.on_before_invoke( - llm_instance=llm_instance, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - ) - - ex = RuntimeError("API timeout") - cb.on_invoke_error( - llm_instance=llm_instance, - ex=ex, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - ) - - assert len(cb.before_invoke_calls) == 1 - assert len(cb.invoke_error_calls) == 1 - assert cb.invoke_error_calls[0]["ex"] is ex - assert len(cb.after_invoke_calls) == 0 - - def test_print_text_with_color_in_integration(self, capsys): - """verify print_text works correctly in a concrete instance.""" - cb = ConcreteCallback() - cb.print_text("SUCCESS", color="green", end="\n") - captured = capsys.readouterr() - assert "SUCCESS" in captured.out - assert "\n" in captured.out - - def test_print_text_no_color_in_integration(self, capsys): - cb = ConcreteCallback() - cb.print_text("plain output") - captured = capsys.readouterr() - assert captured.out == "plain output" diff --git a/api/tests/unit_tests/graphon/model_runtime/callbacks/test_logging_callback.py b/api/tests/unit_tests/graphon/model_runtime/callbacks/test_logging_callback.py deleted file mode 100644 index df9215826c5..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/callbacks/test_logging_callback.py +++ /dev/null @@ -1,700 +0,0 @@ -""" -Comprehensive unit tests for core/model_runtime/callbacks/logging_callback.py - -Coverage targets: - - LoggingCallback.on_before_invoke (all branches: stop, tools, user, stream, - prompt_message.name, model_parameters) - - LoggingCallback.on_new_chunk (writes to stdout) - - LoggingCallback.on_after_invoke (all branches: tool_calls present / absent) - - LoggingCallback.on_invoke_error (logs exception via logger.exception) -""" - -from __future__ import annotations - -import json -from collections.abc import Sequence -from decimal import Decimal -from unittest.mock import MagicMock, patch - -import pytest - -from graphon.model_runtime.callbacks.logging_callback import LoggingCallback -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMUsage, -) -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessageTool, - SystemPromptMessage, - UserPromptMessage, -) - -# --------------------------------------------------------------------------- -# Shared helpers -# --------------------------------------------------------------------------- - - -def _make_usage() -> LLMUsage: - """Return a minimal LLMUsage instance.""" - return LLMUsage( - prompt_tokens=10, - prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal("0.001"), - prompt_price=Decimal("0.01"), - completion_tokens=20, - completion_unit_price=Decimal("0.002"), - 
completion_price_unit=Decimal("0.002"), - completion_price=Decimal("0.04"), - total_tokens=30, - total_price=Decimal("0.05"), - currency="USD", - latency=0.5, - ) - - -def _make_llm_result( - content: str = "hello world", - tool_calls: list | None = None, - model: str = "gpt-4", - system_fingerprint: str | None = "fp-abc", -) -> LLMResult: - """Return an LLMResult with an AssistantPromptMessage.""" - assistant_msg = AssistantPromptMessage( - content=content, - tool_calls=tool_calls or [], - ) - return LLMResult( - model=model, - message=assistant_msg, - usage=_make_usage(), - system_fingerprint=system_fingerprint, - ) - - -def _make_chunk(content: str = "chunk-text") -> LLMResultChunk: - """Return a minimal LLMResultChunk.""" - return LLMResultChunk( - model="gpt-4", - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=content), - ), - ) - - -def _make_user_prompt(content: str = "Hello!", name: str | None = None) -> UserPromptMessage: - return UserPromptMessage(content=content, name=name) - - -def _make_system_prompt(content: str = "You are helpful.") -> SystemPromptMessage: - return SystemPromptMessage(content=content) - - -def _make_tool(name: str = "my_tool") -> PromptMessageTool: - return PromptMessageTool(name=name, description="A tool", parameters={}) - - -def _make_tool_call( - call_id: str = "call-1", - func_name: str = "some_func", - arguments: str = '{"key": "value"}', -) -> AssistantPromptMessage.ToolCall: - return AssistantPromptMessage.ToolCall( - id=call_id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=func_name, arguments=arguments), - ) - - -# --------------------------------------------------------------------------- -# Fixture: shared LoggingCallback instance (no heavy state) -# --------------------------------------------------------------------------- - - -@pytest.fixture -def cb() -> LoggingCallback: - return LoggingCallback() - - -@pytest.fixture -def llm_instance() -> MagicMock: - return MagicMock() - - -# =========================================================================== -# Tests for on_before_invoke -# =========================================================================== - - -class TestOnBeforeInvoke: - """Tests for LoggingCallback.on_before_invoke.""" - - def _invoke( - self, - cb: LoggingCallback, - llm_instance: MagicMock, - *, - model: str = "gpt-4", - credentials: dict | None = None, - prompt_messages: list | None = None, - model_parameters: dict | None = None, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - cb.on_before_invoke( - llm_instance=llm_instance, - model=model, - credentials=credentials or {}, - prompt_messages=prompt_messages or [], - model_parameters=model_parameters or {}, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - - def test_minimal_call_does_not_raise(self, cb: LoggingCallback, llm_instance: MagicMock): - """Calling with bare-minimum args should not raise.""" - self._invoke(cb, llm_instance) - - def test_model_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """The model name must appear in print_text calls.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, model="claude-3") - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "claude-3" in calls_text - - def test_model_parameters_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """Each key-value pair of 
model_parameters must be printed.""" - params = {"temperature": 0.7, "max_tokens": 512} - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, model_parameters=params) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "temperature" in calls_text - assert "0.7" in calls_text - assert "max_tokens" in calls_text - assert "512" in calls_text - - def test_empty_model_parameters(self, cb: LoggingCallback, llm_instance: MagicMock): - """Empty model_parameters dict should not raise.""" - self._invoke(cb, llm_instance, model_parameters={}) - - # ------------------------------------------------------------------ - # stop branch - # ------------------------------------------------------------------ - - def test_stop_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): - """stop words must appear in output when provided.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stop=["STOP", "END"]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "stop" in calls_text - - def test_stop_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): - """When stop=None the stop line must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stop=None) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "\tstop:" not in calls_text - - def test_stop_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock): - """When stop=[] (falsy) the stop line must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stop=[]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "\tstop:" not in calls_text - - # ------------------------------------------------------------------ - # tools branch - # ------------------------------------------------------------------ - - def test_tools_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): - """Tool names must appear in output when tools are provided.""" - tools = [_make_tool("search"), _make_tool("calculate")] - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, tools=tools) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "search" in calls_text - assert "calculate" in calls_text - - def test_tools_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): - """When tools=None the Tools section must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, tools=None) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "Tools:" not in calls_text - - def test_tools_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock): - """When tools=[] (falsy) the Tools section must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, tools=[]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "Tools:" not in calls_text - - # ------------------------------------------------------------------ - # user branch - # ------------------------------------------------------------------ - - def test_user_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): - """User string must appear in output when provided.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, 
llm_instance, user="alice") - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "alice" in calls_text - - def test_user_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): - """When user=None the User line must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, user=None) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "User:" not in calls_text - - # ------------------------------------------------------------------ - # stream branch - # ------------------------------------------------------------------ - - def test_stream_true_prints_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock): - """When stream=True the [on_llm_new_chunk] marker must be printed.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stream=True) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "[on_llm_new_chunk]" in calls_text - - def test_stream_false_no_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock): - """When stream=False the [on_llm_new_chunk] marker must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stream=False) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "[on_llm_new_chunk]" not in calls_text - - # ------------------------------------------------------------------ - # prompt_messages branch - # ------------------------------------------------------------------ - - def test_prompt_message_with_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """When a PromptMessage has a name it must be printed.""" - msg = _make_user_prompt("hi", name="bob") - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, prompt_messages=[msg]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "bob" in calls_text - - def test_prompt_message_without_name_skips_name_line(self, cb: LoggingCallback, llm_instance: MagicMock): - """When a PromptMessage has no name the name line must NOT appear.""" - msg = _make_user_prompt("hi", name=None) - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, prompt_messages=[msg]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "\tname:" not in calls_text - - def test_prompt_message_role_and_content_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """Role and content of each PromptMessage must appear in output.""" - msg = _make_system_prompt("Be concise.") - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, prompt_messages=[msg]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "system" in calls_text - assert "Be concise." 
in calls_text
-
-    def test_multiple_prompt_messages_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """All entries in prompt_messages are iterated and printed."""
-        msgs = [
-            _make_system_prompt("sys"),
-            _make_user_prompt("user msg"),
-        ]
-        with patch.object(cb, "print_text") as mock_print:
-            self._invoke(cb, llm_instance, prompt_messages=msgs)
-            calls_text = " ".join(str(c) for c in mock_print.call_args_list)
-            assert "sys" in calls_text
-            assert "user msg" in calls_text
-
-    # ------------------------------------------------------------------
-    # Combination: everything provided
-    # ------------------------------------------------------------------
-
-    def test_all_optional_fields_combined(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """Supply stop, tools, user, multiple params, named message – no exception."""
-        msgs = [_make_user_prompt("question", name="alice")]
-        tools = [_make_tool("tool_a")]
-        with patch.object(cb, "print_text"):
-            self._invoke(
-                cb,
-                llm_instance,
-                model="gpt-3.5",
-                model_parameters={"temperature": 1.0, "top_p": 0.9},
-                tools=tools,
-                stop=["DONE"],
-                stream=True,
-                user="alice",
-                prompt_messages=msgs,
-            )
-
-
-# ===========================================================================
-# Tests for on_new_chunk
-# ===========================================================================
-
-
-class TestOnNewChunk:
-    """Tests for LoggingCallback.on_new_chunk."""
-
-    def test_chunk_content_written_to_stdout(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """on_new_chunk must write the chunk's text content to sys.stdout."""
-        chunk = _make_chunk("hello from LLM")
-        written = []
-
-        with patch("sys.stdout") as mock_stdout:
-            mock_stdout.write.side_effect = written.append
-            cb.on_new_chunk(
-                llm_instance=llm_instance,
-                chunk=chunk,
-                model="gpt-4",
-                credentials={},
-                prompt_messages=[],
-                model_parameters={},
-            )
-            mock_stdout.write.assert_called_once_with("hello from LLM")
-            mock_stdout.flush.assert_called_once()
-
-    def test_chunk_content_empty_string(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """Works correctly even when the chunk content is an empty string."""
-        chunk = _make_chunk("")
-        with patch("sys.stdout") as mock_stdout:
-            cb.on_new_chunk(
-                llm_instance=llm_instance,
-                chunk=chunk,
-                model="gpt-4",
-                credentials={},
-                prompt_messages=[],
-                model_parameters={},
-            )
-            mock_stdout.write.assert_called_once_with("")
-            mock_stdout.flush.assert_called_once()
-
-    def test_chunk_passes_all_optional_params(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """All optional parameters are accepted without errors."""
-        chunk = _make_chunk("data")
-        with patch("sys.stdout"):
-            cb.on_new_chunk(
-                llm_instance=llm_instance,
-                chunk=chunk,
-                model="gpt-4",
-                credentials={"key": "secret"},
-                prompt_messages=[_make_user_prompt("q")],
-                model_parameters={"temperature": 0.5},
-                tools=[_make_tool("t1")],
-                stop=["EOS"],
-                stream=True,
-                user="bob",
-            )
-
-
-# ===========================================================================
-# Tests for on_after_invoke
-# ===========================================================================
-
-
-class TestOnAfterInvoke:
-    """Tests for LoggingCallback.on_after_invoke."""
-
-    def _invoke(
-        self,
-        cb: LoggingCallback,
-        llm_instance: MagicMock,
-        result: LLMResult,
-        **kwargs,
-    ):
-        cb.on_after_invoke(
-            llm_instance=llm_instance,
-            result=result,
-            model=result.model,
-            credentials={},
-            prompt_messages=[],
-            model_parameters={},
-            **kwargs,
-        )
-
-    def 
test_basic_result_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """After-invoke header, content, model, usage, fingerprint must be printed.""" - result = _make_llm_result() - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "[on_llm_after_invoke]" in calls_text - assert "hello world" in calls_text - assert "gpt-4" in calls_text - assert "fp-abc" in calls_text - - def test_no_tool_calls_skips_tool_call_block(self, cb: LoggingCallback, llm_instance: MagicMock): - """When there are no tool_calls the 'Tool calls:' block must NOT appear.""" - result = _make_llm_result(tool_calls=[]) - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "Tool calls:" not in calls_text - - def test_with_tool_calls_prints_all_fields(self, cb: LoggingCallback, llm_instance: MagicMock): - """When tool_calls exist their id, name, and JSON arguments must be printed.""" - tc = _make_tool_call( - call_id="call-xyz", - func_name="fetch_data", - arguments='{"url": "https://example.com"}', - ) - result = _make_llm_result(tool_calls=[tc]) - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "Tool calls:" in calls_text - assert "call-xyz" in calls_text - assert "fetch_data" in calls_text - # arguments should be JSON-dumped - assert "https://example.com" in calls_text - - def test_multiple_tool_calls_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """All tool calls in the list must be iterated.""" - tcs = [ - _make_tool_call("id-1", "func_a", '{"a": 1}'), - _make_tool_call("id-2", "func_b", '{"b": 2}'), - ] - result = _make_llm_result(tool_calls=tcs) - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "id-1" in calls_text - assert "func_a" in calls_text - assert "id-2" in calls_text - assert "func_b" in calls_text - - def test_system_fingerprint_none_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """When system_fingerprint is None it should still be printed (as None).""" - result = _make_llm_result(system_fingerprint=None) - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "System Fingerprint: None" in calls_text - - def test_usage_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """The usage object must appear in the printed output.""" - result = _make_llm_result() - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "Usage:" in calls_text - - def test_tool_call_arguments_are_json_dumped(self, cb: LoggingCallback, llm_instance: MagicMock): - """Verify json.dumps is applied to the arguments field (a string).""" - raw_args = '{"x": 42}' - tc = _make_tool_call(arguments=raw_args) - result = _make_llm_result(tool_calls=[tc]) - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, result) - - # Check if any call to print_text included the expected (json-encoded) arguments - # json.dumps(raw_args) produces a string starting and ending with 
quotes - expected_substring = json.dumps(raw_args) - found = any(expected_substring in str(call.args[0]) for call in mock_print.call_args_list) - assert found, f"Expected {expected_substring} to be printed in one of the calls" - - def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock): - """All optional parameters should be accepted without error.""" - result = _make_llm_result() - cb.on_after_invoke( - llm_instance=llm_instance, - result=result, - model=result.model, - credentials={"key": "secret"}, - prompt_messages=[_make_user_prompt("q")], - model_parameters={"temperature": 0.9}, - tools=[_make_tool("t")], - stop=[""], - stream=False, - user="carol", - ) - - -# =========================================================================== -# Tests for on_invoke_error -# =========================================================================== - - -class TestOnInvokeError: - """Tests for LoggingCallback.on_invoke_error.""" - - def _invoke_error( - self, - cb: LoggingCallback, - llm_instance: MagicMock, - ex: Exception, - **kwargs, - ): - cb.on_invoke_error( - llm_instance=llm_instance, - ex=ex, - model="gpt-4", - credentials={}, - prompt_messages=[], - model_parameters={}, - **kwargs, - ) - - def test_prints_error_header(self, cb: LoggingCallback, llm_instance: MagicMock): - """The [on_llm_invoke_error] banner must be printed.""" - with patch.object(cb, "print_text") as mock_print: - with patch("graphon.model_runtime.callbacks.logging_callback.logger") as mock_logger: - self._invoke_error(cb, llm_instance, RuntimeError("boom")) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "[on_llm_invoke_error]" in calls_text - - def test_exception_logged_via_logger_exception(self, cb: LoggingCallback, llm_instance: MagicMock): - """logger.exception must be called with the exception.""" - ex = ValueError("something went wrong") - with patch.object(cb, "print_text"): - with patch("graphon.model_runtime.callbacks.logging_callback.logger") as mock_logger: - self._invoke_error(cb, llm_instance, ex) - mock_logger.exception.assert_called_once_with(ex) - - def test_exception_type_variety(self, cb: LoggingCallback, llm_instance: MagicMock): - """Works with any exception type (TypeError, IOError, etc.).""" - for exc_cls in (TypeError, IOError, KeyError, Exception): - ex = exc_cls("error") - with patch.object(cb, "print_text"): - with patch("graphon.model_runtime.callbacks.logging_callback.logger") as mock_logger: - self._invoke_error(cb, llm_instance, ex) - mock_logger.exception.assert_called_once_with(ex) - - def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock): - """All optional parameters should be accepted without error.""" - ex = RuntimeError("fail") - with patch.object(cb, "print_text"): - with patch("graphon.model_runtime.callbacks.logging_callback.logger"): - cb.on_invoke_error( - llm_instance=llm_instance, - ex=ex, - model="gpt-4", - credentials={"key": "secret"}, - prompt_messages=[_make_user_prompt("q")], - model_parameters={"temperature": 0.7}, - tools=[_make_tool("t")], - stop=["STOP"], - stream=True, - user="dave", - ) - - -# =========================================================================== -# Tests for print_text (inherited from Callback, exercised through LoggingCallback) -# =========================================================================== - - -class TestPrintText: - """Verify that print_text from the Callback base class works correctly.""" - - def test_print_text_with_color(self, cb: 
LoggingCallback, capsys):
-        """print_text with a known colour should emit an ANSI escape sequence."""
-        cb.print_text("hello", color="blue")
-        captured = capsys.readouterr()
-        assert "hello" in captured.out
-        # ANSI escape codes should be present
-        assert "\x1b[" in captured.out
-
-    def test_print_text_without_color(self, cb: LoggingCallback, capsys):
-        """print_text without colour should print plain text."""
-        cb.print_text("plain text")
-        captured = capsys.readouterr()
-        assert "plain text" in captured.out
-
-    def test_print_text_all_colours(self, cb: LoggingCallback, capsys):
-        """Verify all supported colour keys don't raise."""
-        for colour in ("blue", "yellow", "pink", "green", "red"):
-            cb.print_text("x", color=colour)
-        captured = capsys.readouterr()
-        # All outputs should contain 'x' (5 calls)
-        assert captured.out.count("x") >= 5
-
-
-# ===========================================================================
-# Integration-style test: real print_text called (no mocking)
-# ===========================================================================
-
-
-class TestLoggingCallbackIntegration:
-    """Light integration tests – real print_text calls, just checking no exceptions."""
-
-    def test_on_before_invoke_full_run(self, capsys):
-        """Full on_before_invoke run with all optional fields – verifies real output."""
-        cb = LoggingCallback()
-        llm = MagicMock()
-        msgs = [_make_user_prompt("Who are you?", name="tester")]
-        tools = [_make_tool("calculator")]
-        cb.on_before_invoke(
-            llm_instance=llm,
-            model="gpt-4-turbo",
-            credentials={"api_key": "sk-xxx"},
-            prompt_messages=msgs,
-            model_parameters={"temperature": 0.8},
-            tools=tools,
-            stop=["STOP"],
-            stream=True,
-            user="test_user",
-        )
-        captured = capsys.readouterr()
-        assert "gpt-4-turbo" in captured.out
-        assert "calculator" in captured.out
-        assert "test_user" in captured.out
-        assert "STOP" in captured.out
-        assert "tester" in captured.out
-
-    def test_on_new_chunk_full_run(self, capsys):
-        """Full on_new_chunk run – verifies real stdout write."""
-        cb = LoggingCallback()
-        chunk = _make_chunk("streaming token")
-        cb.on_new_chunk(
-            llm_instance=MagicMock(),
-            chunk=chunk,
-            model="gpt-4",
-            credentials={},
-            prompt_messages=[],
-            model_parameters={},
-        )
-        captured = capsys.readouterr()
-        assert "streaming token" in captured.out
-
-    def test_on_after_invoke_full_run_with_tool_calls(self, capsys):
-        """Full on_after_invoke run with tool calls – verifies real output."""
-        cb = LoggingCallback()
-        tc = _make_tool_call("call-99", "do_thing", '{"n": 5}')
-        result = _make_llm_result(content="result content", tool_calls=[tc], system_fingerprint="fp-xyz")
-        cb.on_after_invoke(
-            llm_instance=MagicMock(),
-            result=result,
-            model=result.model,
-            credentials={},
-            prompt_messages=[],
-            model_parameters={},
-        )
-        captured = capsys.readouterr()
-        assert "result content" in captured.out
-        assert "call-99" in captured.out
-        assert "do_thing" in captured.out
-        assert "fp-xyz" in captured.out
-
-    def test_on_invoke_error_full_run(self, capsys):
-        """Full on_invoke_error run – just verifies no exception is raised."""
-        cb = LoggingCallback()
-        ex = RuntimeError("something bad happened")
-        # logger.exception writes to stderr; we just confirm it doesn't crash
-        cb.on_invoke_error(
-            llm_instance=MagicMock(),
-            ex=ex,
-            model="gpt-4",
-            credentials={},
-            prompt_messages=[],
-            model_parameters={},
-        )
-        captured = capsys.readouterr()
-        assert "[on_llm_invoke_error]" in captured.out
diff --git 
a/api/tests/unit_tests/graphon/model_runtime/entities/test_common_entities.py b/api/tests/unit_tests/graphon/model_runtime/entities/test_common_entities.py
deleted file mode 100644
index 7d6255c37aa..00000000000
--- a/api/tests/unit_tests/graphon/model_runtime/entities/test_common_entities.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from graphon.model_runtime.entities.common_entities import I18nObject
-
-
-class TestI18nObject:
-    def test_i18n_object_with_both_languages(self):
-        """
-        Test I18nObject when both zh_Hans and en_US are provided.
-        """
-        i18n = I18nObject(zh_Hans="你好", en_US="Hello")
-        assert i18n.zh_Hans == "你好"
-        assert i18n.en_US == "Hello"
-
-    def test_i18n_object_fallback_to_en_us(self):
-        """
-        Test I18nObject when zh_Hans is missing, it should fallback to en_US.
-        """
-        i18n = I18nObject(en_US="Hello")
-        assert i18n.zh_Hans == "Hello"
-        assert i18n.en_US == "Hello"
-
-    def test_i18n_object_with_none_zh_hans(self):
-        """
-        Test I18nObject when zh_Hans is None, it should fallback to en_US.
-        """
-        i18n = I18nObject(zh_Hans=None, en_US="Hello")
-        assert i18n.zh_Hans == "Hello"
-        assert i18n.en_US == "Hello"
-
-    def test_i18n_object_with_empty_zh_hans(self):
-        """
-        Test I18nObject when zh_Hans is an empty string, it should fallback to en_US.
-        """
-        i18n = I18nObject(zh_Hans="", en_US="Hello")
-        assert i18n.zh_Hans == "Hello"
-        assert i18n.en_US == "Hello"
diff --git a/api/tests/unit_tests/graphon/model_runtime/entities/test_llm_entities.py b/api/tests/unit_tests/graphon/model_runtime/entities/test_llm_entities.py
deleted file mode 100644
index 51a6c38fa90..00000000000
--- a/api/tests/unit_tests/graphon/model_runtime/entities/test_llm_entities.py
+++ /dev/null
@@ -1,148 +0,0 @@
-"""Tests for LLMUsage entity."""
-
-from decimal import Decimal
-
-from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
-
-
-class TestLLMUsage:
-    """Test cases for LLMUsage class."""
-
-    def test_from_metadata_with_all_tokens(self):
-        """Test from_metadata when all token types are provided."""
-        metadata: LLMUsageMetadata = {
-            "prompt_tokens": 100,
-            "completion_tokens": 50,
-            "total_tokens": 150,
-            "prompt_unit_price": 0.001,
-            "completion_unit_price": 0.002,
-            "total_price": 0.2,
-            "currency": "USD",
-            "latency": 1.5,
-        }
-
-        usage = LLMUsage.from_metadata(metadata)
-
-        assert usage.prompt_tokens == 100
-        assert usage.completion_tokens == 50
-        assert usage.total_tokens == 150
-        assert usage.prompt_unit_price == Decimal("0.001")
-        assert usage.completion_unit_price == Decimal("0.002")
-        assert usage.total_price == Decimal("0.2")
-        assert usage.currency == "USD"
-        assert usage.latency == 1.5
-
-    def test_from_metadata_with_prompt_tokens_only(self):
-        """Test from_metadata when only prompt_tokens is provided."""
-        metadata: LLMUsageMetadata = {
-            "prompt_tokens": 100,
-            "total_tokens": 100,
-        }
-
-        usage = LLMUsage.from_metadata(metadata)
-
-        assert usage.prompt_tokens == 100
-        assert usage.completion_tokens == 0
-        assert usage.total_tokens == 100
-
-    def test_from_metadata_with_completion_tokens_only(self):
-        """Test from_metadata when only completion_tokens is provided."""
-        metadata: LLMUsageMetadata = {
-            "completion_tokens": 50,
-            "total_tokens": 50,
-        }
-
-        usage = LLMUsage.from_metadata(metadata)
-
-        assert usage.prompt_tokens == 0
-        assert usage.completion_tokens == 50
-        assert usage.total_tokens == 50
-
-    def test_from_metadata_calculates_total_when_missing(self):
-        """Test from_metadata calculates total_tokens when not provided."""
-        metadata: LLMUsageMetadata 
= { - "prompt_tokens": 100, - "completion_tokens": 50, - } - - usage = LLMUsage.from_metadata(metadata) - - assert usage.prompt_tokens == 100 - assert usage.completion_tokens == 50 - assert usage.total_tokens == 150 # Should be calculated - - def test_from_metadata_with_total_but_no_completion(self): - """ - Test from_metadata when total_tokens is provided but completion_tokens is 0. - This tests the fix for issue #24360 - prompt tokens should NOT be assigned to completion_tokens. - """ - metadata: LLMUsageMetadata = { - "prompt_tokens": 479, - "completion_tokens": 0, - "total_tokens": 521, - } - - usage = LLMUsage.from_metadata(metadata) - - # This is the key fix - prompt tokens should remain as prompt tokens - assert usage.prompt_tokens == 479 - assert usage.completion_tokens == 0 - assert usage.total_tokens == 521 - - def test_from_metadata_with_empty_metadata(self): - """Test from_metadata with empty metadata.""" - metadata: LLMUsageMetadata = {} - - usage = LLMUsage.from_metadata(metadata) - - assert usage.prompt_tokens == 0 - assert usage.completion_tokens == 0 - assert usage.total_tokens == 0 - assert usage.currency == "USD" - assert usage.latency == 0.0 - - def test_from_metadata_preserves_zero_completion_tokens(self): - """ - Test that zero completion_tokens are preserved when explicitly set. - This is important for agent nodes that only use prompt tokens. - """ - metadata: LLMUsageMetadata = { - "prompt_tokens": 1000, - "completion_tokens": 0, - "total_tokens": 1000, - "prompt_unit_price": 0.15, - "completion_unit_price": 0.60, - "prompt_price": 0.00015, - "completion_price": 0, - "total_price": 0.00015, - } - - usage = LLMUsage.from_metadata(metadata) - - assert usage.prompt_tokens == 1000 - assert usage.completion_tokens == 0 - assert usage.total_tokens == 1000 - assert usage.prompt_price == Decimal("0.00015") - assert usage.completion_price == Decimal(0) - assert usage.total_price == Decimal("0.00015") - - def test_from_metadata_with_decimal_values(self): - """Test from_metadata handles decimal values correctly.""" - metadata: LLMUsageMetadata = { - "prompt_tokens": 100, - "completion_tokens": 50, - "total_tokens": 150, - "prompt_unit_price": "0.001", - "completion_unit_price": "0.002", - "prompt_price": "0.1", - "completion_price": "0.1", - "total_price": "0.2", - } - - usage = LLMUsage.from_metadata(metadata) - - assert usage.prompt_unit_price == Decimal("0.001") - assert usage.completion_unit_price == Decimal("0.002") - assert usage.prompt_price == Decimal("0.1") - assert usage.completion_price == Decimal("0.1") - assert usage.total_price == Decimal("0.2") diff --git a/api/tests/unit_tests/graphon/model_runtime/entities/test_message_entities.py b/api/tests/unit_tests/graphon/model_runtime/entities/test_message_entities.py deleted file mode 100644 index 1918c324ccd..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/entities/test_message_entities.py +++ /dev/null @@ -1,210 +0,0 @@ -import pytest - -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - AudioPromptMessageContent, - DocumentPromptMessageContent, - ImagePromptMessageContent, - PromptMessageContent, - PromptMessageContentType, - PromptMessageFunction, - PromptMessageRole, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, - VideoPromptMessageContent, -) - - -class TestPromptMessageRole: - def test_value_of(self): - assert PromptMessageRole.value_of("system") == PromptMessageRole.SYSTEM - assert 
PromptMessageRole.value_of("user") == PromptMessageRole.USER - assert PromptMessageRole.value_of("assistant") == PromptMessageRole.ASSISTANT - assert PromptMessageRole.value_of("tool") == PromptMessageRole.TOOL - - with pytest.raises(ValueError, match="invalid prompt message type value invalid"): - PromptMessageRole.value_of("invalid") - - -class TestPromptMessageEntities: - def test_prompt_message_tool(self): - tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"}) - assert tool.name == "test_tool" - assert tool.description == "test desc" - assert tool.parameters == {"foo": "bar"} - - def test_prompt_message_function(self): - tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"}) - func = PromptMessageFunction(function=tool) - assert func.type == "function" - assert func.function == tool - - -class TestPromptMessageContent: - def test_text_content(self): - content = TextPromptMessageContent(data="hello") - assert content.type == PromptMessageContentType.TEXT - assert content.data == "hello" - - def test_image_content(self): - content = ImagePromptMessageContent( - format="jpg", base64_data="abc", mime_type="image/jpeg", detail=ImagePromptMessageContent.DETAIL.HIGH - ) - assert content.type == PromptMessageContentType.IMAGE - assert content.detail == ImagePromptMessageContent.DETAIL.HIGH - assert content.data == "data:image/jpeg;base64,abc" - - def test_image_content_url(self): - content = ImagePromptMessageContent(format="jpg", url="https://example.com/image.jpg", mime_type="image/jpeg") - assert content.data == "https://example.com/image.jpg" - - def test_audio_content(self): - content = AudioPromptMessageContent(format="mp3", base64_data="abc", mime_type="audio/mpeg") - assert content.type == PromptMessageContentType.AUDIO - assert content.data == "data:audio/mpeg;base64,abc" - - def test_video_content(self): - content = VideoPromptMessageContent(format="mp4", base64_data="abc", mime_type="video/mp4") - assert content.type == PromptMessageContentType.VIDEO - assert content.data == "data:video/mp4;base64,abc" - - def test_document_content(self): - content = DocumentPromptMessageContent(format="pdf", base64_data="abc", mime_type="application/pdf") - assert content.type == PromptMessageContentType.DOCUMENT - assert content.data == "data:application/pdf;base64,abc" - - -class TestPromptMessages: - def test_user_prompt_message(self): - msg = UserPromptMessage(content="hello") - assert msg.role == PromptMessageRole.USER - assert msg.content == "hello" - assert msg.is_empty() is False - assert msg.get_text_content() == "hello" - - def test_user_prompt_message_complex_content(self): - content = [TextPromptMessageContent(data="hello "), TextPromptMessageContent(data="world")] - msg = UserPromptMessage(content=content) - assert msg.get_text_content() == "hello world" - - # Test validation from dict - msg2 = UserPromptMessage(content=[{"type": "text", "data": "hi"}]) - assert isinstance(msg2.content[0], TextPromptMessageContent) - assert msg2.content[0].data == "hi" - - def test_prompt_message_empty(self): - msg = UserPromptMessage(content=None) - assert msg.is_empty() is True - assert msg.get_text_content() == "" - - def test_assistant_prompt_message(self): - msg = AssistantPromptMessage(content="thinking...") - assert msg.role == PromptMessageRole.ASSISTANT - assert msg.is_empty() is False - - tool_call = AssistantPromptMessage.ToolCall( - id="call_1", - type="function", - 
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"), - ) - msg_with_tools = AssistantPromptMessage(content=None, tool_calls=[tool_call]) - assert msg_with_tools.is_empty() is False - assert msg_with_tools.role == PromptMessageRole.ASSISTANT - - def test_assistant_tool_call_id_transform(self): - tool_call = AssistantPromptMessage.ToolCall( - id=123, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"), - ) - assert tool_call.id == "123" - - def test_system_prompt_message(self): - msg = SystemPromptMessage(content="you are a bot") - assert msg.role == PromptMessageRole.SYSTEM - assert msg.content == "you are a bot" - - def test_tool_prompt_message(self): - # Case 1: Both content and tool_call_id are present - msg = ToolPromptMessage(content="result", tool_call_id="call_1") - assert msg.role == PromptMessageRole.TOOL - assert msg.tool_call_id == "call_1" - assert msg.is_empty() is False - - # Case 2: Content is present, but tool_call_id is empty - msg_content_only = ToolPromptMessage(content="result", tool_call_id="") - assert msg_content_only.is_empty() is False - - # Case 3: Content is None, but tool_call_id is present - msg_id_only = ToolPromptMessage(content=None, tool_call_id="call_1") - assert msg_id_only.is_empty() is False - - # Case 4: Both content and tool_call_id are empty - msg_empty = ToolPromptMessage(content=None, tool_call_id="") - assert msg_empty.is_empty() is True - - def test_prompt_message_validation_errors(self): - with pytest.raises(KeyError): - # Invalid content type in list - UserPromptMessage(content=[{"type": "invalid", "data": "foo"}]) - - with pytest.raises(ValueError, match="invalid prompt message"): - # Not a dict or PromptMessageContent - UserPromptMessage(content=[123]) - - def test_prompt_message_serialization(self): - # Case: content is None - assert UserPromptMessage(content=None).serialize_content(None) is None - - # Case: content is str - assert UserPromptMessage(content="hello").serialize_content("hello") == "hello" - - # Case: content is list of dict - content_list = [{"type": "text", "data": "hi"}] - msg = UserPromptMessage(content=content_list) - assert msg.serialize_content(msg.content) == [{"type": PromptMessageContentType.TEXT, "data": "hi"}] - - # Case: content is Sequence but not list (e.g. tuple) - # To hit line 204, we can call serialize_content manually or - # try to pass a type that pydantic doesn't convert to list in its internal state. - # Actually, let's just call it manually on the instance. - msg = UserPromptMessage(content="test") - content_tuple = (TextPromptMessageContent(data="hi"),) - assert msg.serialize_content(content_tuple) == content_tuple - - def test_prompt_message_mixed_content_validation(self): - # Test branch: isinstance(prompt, PromptMessageContent) - # but not (TextPromptMessageContent | MultiModalPromptMessageContent) - # Line 187: prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump()) - - # We need a PromptMessageContent that is NOT Text or MultiModal. - # But PromptMessageContentUnionTypes discriminator handles this usually. - # We can bypass high-level validation by passing the object directly in a list. 
- - class MockContent(PromptMessageContent): - type: PromptMessageContentType = PromptMessageContentType.TEXT - data: str - - mock_item = MockContent(data="test") - msg = UserPromptMessage(content=[mock_item]) - # It should hit line 187 and convert to TextPromptMessageContent - assert isinstance(msg.content[0], TextPromptMessageContent) - assert msg.content[0].data == "test" - - def test_prompt_message_get_text_content_branches(self): - # content is None - msg_none = UserPromptMessage(content=None) - assert msg_none.get_text_content() == "" - - # content is list but no text content - image = ImagePromptMessageContent(format="jpg", base64_data="abc", mime_type="image/jpeg") - msg_image = UserPromptMessage(content=[image]) - assert msg_image.get_text_content() == "" - - # content is list with mixed - text = TextPromptMessageContent(data="hello") - msg_mixed = UserPromptMessage(content=[text, image]) - assert msg_mixed.get_text_content() == "hello" diff --git a/api/tests/unit_tests/graphon/model_runtime/entities/test_model_entities.py b/api/tests/unit_tests/graphon/model_runtime/entities/test_model_entities.py deleted file mode 100644 index 1988709faa8..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/entities/test_model_entities.py +++ /dev/null @@ -1,220 +0,0 @@ -from decimal import Decimal - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ( - AIModelEntity, - DefaultParameterName, - FetchFrom, - ModelFeature, - ModelPropertyKey, - ModelType, - ModelUsage, - ParameterRule, - ParameterType, - PriceConfig, - PriceInfo, - PriceType, - ProviderModel, -) - - -class TestModelType: - def test_value_of(self): - assert ModelType.value_of("text-generation") == ModelType.LLM - assert ModelType.value_of(ModelType.LLM) == ModelType.LLM - assert ModelType.value_of("embeddings") == ModelType.TEXT_EMBEDDING - assert ModelType.value_of(ModelType.TEXT_EMBEDDING) == ModelType.TEXT_EMBEDDING - assert ModelType.value_of("reranking") == ModelType.RERANK - assert ModelType.value_of(ModelType.RERANK) == ModelType.RERANK - assert ModelType.value_of("speech2text") == ModelType.SPEECH2TEXT - assert ModelType.value_of(ModelType.SPEECH2TEXT) == ModelType.SPEECH2TEXT - assert ModelType.value_of("tts") == ModelType.TTS - assert ModelType.value_of(ModelType.TTS) == ModelType.TTS - assert ModelType.value_of(ModelType.MODERATION) == ModelType.MODERATION - - with pytest.raises(ValueError, match="invalid origin model type invalid"): - ModelType.value_of("invalid") - - def test_to_origin_model_type(self): - assert ModelType.LLM.to_origin_model_type() == "text-generation" - assert ModelType.TEXT_EMBEDDING.to_origin_model_type() == "embeddings" - assert ModelType.RERANK.to_origin_model_type() == "reranking" - assert ModelType.SPEECH2TEXT.to_origin_model_type() == "speech2text" - assert ModelType.TTS.to_origin_model_type() == "tts" - assert ModelType.MODERATION.to_origin_model_type() == "moderation" - - # Testing the else branch in to_origin_model_type - # Since it's a StrEnum, it's hard to get an invalid value here unless we mock or Force it. - # But if we look at the implementation: - # if self == self.LLM: ... elif ... else: raise ValueError - # We can try to create a "dummy" member if possible, or just skip it if we have 100% coverage otherwise. - # Actually, adding a new member to an enum at runtime is possible but messy. - # Let's see if we can trigger it. 
- - -class TestFetchFrom: - def test_values(self): - assert FetchFrom.PREDEFINED_MODEL == "predefined-model" - assert FetchFrom.CUSTOMIZABLE_MODEL == "customizable-model" - - -class TestModelFeature: - def test_values(self): - assert ModelFeature.TOOL_CALL == "tool-call" - assert ModelFeature.MULTI_TOOL_CALL == "multi-tool-call" - assert ModelFeature.AGENT_THOUGHT == "agent-thought" - assert ModelFeature.VISION == "vision" - assert ModelFeature.STREAM_TOOL_CALL == "stream-tool-call" - assert ModelFeature.DOCUMENT == "document" - assert ModelFeature.VIDEO == "video" - assert ModelFeature.AUDIO == "audio" - assert ModelFeature.STRUCTURED_OUTPUT == "structured-output" - - -class TestDefaultParameterName: - def test_value_of(self): - assert DefaultParameterName.value_of("temperature") == DefaultParameterName.TEMPERATURE - assert DefaultParameterName.value_of("top_p") == DefaultParameterName.TOP_P - - with pytest.raises(ValueError, match="invalid parameter name invalid"): - DefaultParameterName.value_of("invalid") - - -class TestParameterType: - def test_values(self): - assert ParameterType.FLOAT == "float" - assert ParameterType.INT == "int" - assert ParameterType.STRING == "string" - assert ParameterType.BOOLEAN == "boolean" - assert ParameterType.TEXT == "text" - - -class TestModelPropertyKey: - def test_values(self): - assert ModelPropertyKey.MODE == "mode" - assert ModelPropertyKey.CONTEXT_SIZE == "context_size" - - -class TestProviderModel: - def test_provider_model(self): - model = ProviderModel( - model="gpt-4", - label=I18nObject(en_US="GPT-4"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, - ) - assert model.model == "gpt-4" - assert model.support_structure_output is False - - model_with_features = ProviderModel( - model="gpt-4", - label=I18nObject(en_US="GPT-4"), - model_type=ModelType.LLM, - features=[ModelFeature.STRUCTURED_OUTPUT], - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, - ) - assert model_with_features.support_structure_output is True - - -class TestParameterRule: - def test_parameter_rule(self): - rule = ParameterRule( - name="temperature", - label=I18nObject(en_US="Temperature"), - type=ParameterType.FLOAT, - default=0.7, - min=0.0, - max=1.0, - precision=2, - ) - assert rule.name == "temperature" - assert rule.default == 0.7 - - -class TestPriceConfig: - def test_price_config(self): - config = PriceConfig(input=Decimal("0.01"), output=Decimal("0.02"), unit=Decimal("0.001"), currency="USD") - assert config.input == Decimal("0.01") - assert config.output == Decimal("0.02") - - -class TestAIModelEntity: - def test_ai_model_entity_no_json_schema(self): - entity = AIModelEntity( - model="gpt-4", - label=I18nObject(en_US="GPT-4"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, - parameter_rules=[ - ParameterRule(name="temperature", label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT) - ], - ) - assert ModelFeature.STRUCTURED_OUTPUT not in (entity.features or []) - - def test_ai_model_entity_with_json_schema(self): - # Case: json_schema in parameter rules, features is None - entity = AIModelEntity( - model="gpt-4", - label=I18nObject(en_US="GPT-4"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, - parameter_rules=[ - ParameterRule(name="json_schema", 
label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) - ], - ) - assert ModelFeature.STRUCTURED_OUTPUT in entity.features - - def test_ai_model_entity_with_json_schema_and_features_empty(self): - # Case: json_schema in parameter rules, features is empty list - entity = AIModelEntity( - model="gpt-4", - label=I18nObject(en_US="GPT-4"), - model_type=ModelType.LLM, - features=[], - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, - parameter_rules=[ - ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) - ], - ) - assert ModelFeature.STRUCTURED_OUTPUT in entity.features - - def test_ai_model_entity_with_json_schema_and_other_features(self): - # Case: json_schema in parameter rules, features has other things - entity = AIModelEntity( - model="gpt-4", - label=I18nObject(en_US="GPT-4"), - model_type=ModelType.LLM, - features=[ModelFeature.VISION], - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, - parameter_rules=[ - ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) - ], - ) - assert ModelFeature.STRUCTURED_OUTPUT in entity.features - assert ModelFeature.VISION in entity.features - - -class TestModelUsage: - def test_model_usage(self): - usage = ModelUsage() - assert isinstance(usage, ModelUsage) - - -class TestPriceType: - def test_values(self): - assert PriceType.INPUT == "input" - assert PriceType.OUTPUT == "output" - - -class TestPriceInfo: - def test_price_info(self): - info = PriceInfo(unit_price=Decimal("0.01"), unit=Decimal(1000), total_amount=Decimal("0.05"), currency="USD") - assert info.total_amount == Decimal("0.05") diff --git a/api/tests/unit_tests/graphon/model_runtime/errors/test_invoke.py b/api/tests/unit_tests/graphon/model_runtime/errors/test_invoke.py deleted file mode 100644 index 20048222307..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/errors/test_invoke.py +++ /dev/null @@ -1,63 +0,0 @@ -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) - - -class TestInvokeErrors: - def test_invoke_error_with_description(self): - error = InvokeError("Custom description") - assert error.description == "Custom description" - assert str(error) == "Custom description" - assert isinstance(error, ValueError) - - def test_invoke_error_without_description(self): - error = InvokeError() - assert error.description is None - assert str(error) == "InvokeError" - - def test_invoke_connection_error(self): - # Now preserves class-level description - error = InvokeConnectionError() - assert error.description == "Connection Error" - assert str(error) == "Connection Error" - assert isinstance(error, InvokeError) - - # Test with explicit description - error_with_desc = InvokeConnectionError("Connection Error") - assert error_with_desc.description == "Connection Error" - assert str(error_with_desc) == "Connection Error" - - def test_invoke_server_unavailable_error(self): - error = InvokeServerUnavailableError() - assert error.description == "Server Unavailable Error" - assert str(error) == "Server Unavailable Error" - assert isinstance(error, InvokeError) - - def test_invoke_rate_limit_error(self): - error = InvokeRateLimitError() - assert error.description == "Rate Limit Error" - assert str(error) == "Rate Limit Error" - assert 
isinstance(error, InvokeError) - - def test_invoke_authorization_error(self): - error = InvokeAuthorizationError() - assert error.description == "Incorrect model credentials provided, please check and try again. " - assert str(error) == "Incorrect model credentials provided, please check and try again. " - assert isinstance(error, InvokeError) - - def test_invoke_bad_request_error(self): - error = InvokeBadRequestError() - assert error.description == "Bad Request Error" - assert str(error) == "Bad Request Error" - assert isinstance(error, InvokeError) - - def test_invoke_error_inheritance(self): - # Test that we can override the default description in subclasses - error = InvokeBadRequestError("Overridden Error") - assert error.description == "Overridden Error" - assert str(error) == "Overridden Error" diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_ai_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_ai_model.py deleted file mode 100644 index 64edd69789b..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_ai_model.py +++ /dev/null @@ -1,254 +0,0 @@ -import decimal -from unittest.mock import MagicMock, patch - -import pytest -from pydantic import BaseModel - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ( - AIModelEntity, - DefaultParameterName, - FetchFrom, - ModelPropertyKey, - ModelType, - ParameterRule, - ParameterType, - PriceConfig, - PriceType, -) -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class _ConcreteAIModel(AIModel): - model_type = ModelType.LLM - - -class TestAIModel: - @pytest.fixture - def provider_schema(self) -> ProviderEntity: - return ProviderEntity( - provider="langgenius/openai/openai", - provider_name="openai", - label=I18nObject(en_US="OpenAI"), - supported_model_types=[ModelType.LLM], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - @pytest.fixture - def model_runtime(self) -> MagicMock: - return MagicMock() - - @pytest.fixture - def ai_model(self, provider_schema: ProviderEntity, model_runtime: MagicMock) -> AIModel: - return _ConcreteAIModel( - provider_schema=provider_schema, - model_runtime=model_runtime, - ) - - def test_init_stores_runtime_state_and_is_not_pydantic_model( - self, ai_model: AIModel, provider_schema: ProviderEntity, model_runtime: MagicMock - ) -> None: - assert ai_model.model_type == ModelType.LLM - assert ai_model.provider_schema is provider_schema - assert ai_model.model_runtime is model_runtime - assert ai_model.provider == "langgenius/openai/openai" - assert ai_model.provider_display_name == "OpenAI" - assert ai_model.started_at == 0 - assert not isinstance(ai_model, BaseModel) - - def test_direct_base_class_requires_subclass_model_type( - self, provider_schema: ProviderEntity, model_runtime: MagicMock - ) -> None: - with pytest.raises(TypeError, match="subclasses must define model_type"): - AIModel(provider_schema=provider_schema, model_runtime=model_runtime) - - def test_subclass_uses_class_level_model_type( - self, provider_schema: ProviderEntity, model_runtime: MagicMock - ) -> None: - model = 
_ConcreteAIModel(provider_schema=provider_schema, model_runtime=model_runtime) - assert model.model_type == ModelType.LLM - - def test_invoke_error_mapping(self, ai_model: AIModel) -> None: - mapping = ai_model._invoke_error_mapping - assert InvokeConnectionError in mapping - assert InvokeServerUnavailableError in mapping - assert InvokeRateLimitError in mapping - assert InvokeAuthorizationError in mapping - assert InvokeBadRequestError in mapping - assert ValueError in mapping - - def test_transform_invoke_error(self, ai_model: AIModel) -> None: - err = Exception("Original error") - - with patch.object(AIModel, "_invoke_error_mapping", {InvokeAuthorizationError: [Exception]}): - transformed = ai_model._transform_invoke_error(err) - assert isinstance(transformed, InvokeAuthorizationError) - assert "Incorrect model credentials provided" in str(transformed.description) - - class CustomNonInvokeError(Exception): - pass - - with patch.object(AIModel, "_invoke_error_mapping", {CustomNonInvokeError: [Exception]}): - transformed = ai_model._transform_invoke_error(err) - assert transformed == err - - transformed = ai_model._transform_invoke_error(Exception("Unmapped")) - assert isinstance(transformed, InvokeError) - assert transformed.description == "[OpenAI] Error: Unmapped" - - def test_get_price(self, ai_model: AIModel) -> None: - model_name = "test_model" - credentials = {"key": "value"} - - mock_schema = MagicMock(spec=AIModelEntity) - mock_schema.pricing = PriceConfig( - input=decimal.Decimal("0.002"), - output=decimal.Decimal("0.004"), - unit=decimal.Decimal(1000), - currency="USD", - ) - - with patch.object(AIModel, "get_model_schema", return_value=mock_schema): - price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 2000) - assert price_info.unit_price == decimal.Decimal("0.002") - - price_info = ai_model.get_price(model_name, credentials, PriceType.OUTPUT, 2000) - assert price_info.unit_price == decimal.Decimal("0.004") - - mock_schema.pricing = None - with patch.object(AIModel, "get_model_schema", return_value=mock_schema): - price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 1000) - assert price_info.total_amount == decimal.Decimal("0.0") - - def test_get_price_no_price_config_error(self, ai_model: AIModel) -> None: - class ChangingPriceConfig: - def __init__(self) -> None: - self.input = decimal.Decimal("0.01") - self.unit = decimal.Decimal(1) - self.currency = "USD" - self.called = 0 - - def __bool__(self) -> bool: - self.called += 1 - return self.called <= 2 - - mock_schema = MagicMock() - mock_schema.pricing = ChangingPriceConfig() - - with patch.object(AIModel, "get_model_schema", return_value=mock_schema): - with pytest.raises(ValueError, match="Price config not found"): - ai_model.get_price("test_model", {}, PriceType.INPUT, 1000) - - def test_get_model_schema_delegates_to_runtime( - self, ai_model: AIModel, model_runtime: MagicMock, provider_schema: ProviderEntity - ) -> None: - model_name = "test_model" - credentials = {"api_key": "abc"} - - mock_schema = AIModelEntity( - model="test_model", - label=I18nObject(en_US="Test Model"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, - parameter_rules=[], - ) - model_runtime.get_model_schema.return_value = mock_schema - - schema = ai_model.get_model_schema(model_name, credentials) - - assert schema == mock_schema - model_runtime.get_model_schema.assert_called_once_with( - provider=provider_schema.provider, - 
model_type=ModelType.LLM, - model=model_name, - credentials=credentials, - ) - - def test_get_customizable_model_schema_from_credentials_template_mapping_value_error( - self, ai_model: AIModel - ) -> None: - mock_schema = AIModelEntity( - model="test_model", - label=I18nObject(en_US="Test Model"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, - parameter_rules=[ - ParameterRule( - name="invalid", - use_template="invalid_template_name", - label=I18nObject(en_US="Invalid"), - type=ParameterType.FLOAT, - ) - ], - ) - - with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema): - schema = ai_model.get_customizable_model_schema_from_credentials("test_model", {}) - assert schema is not None - assert schema.parameter_rules[0].use_template == "invalid_template_name" - - def test_get_customizable_model_schema_from_credentials(self, ai_model: AIModel) -> None: - mock_schema = AIModelEntity( - model="test_model", - label=I18nObject(en_US="Test Model"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, - parameter_rules=[ - ParameterRule( - name="temp", use_template="temperature", label=I18nObject(en_US="Temp"), type=ParameterType.FLOAT - ), - ParameterRule( - name="top_p", - use_template="top_p", - label=I18nObject(en_US="Top P"), - type=ParameterType.FLOAT, - help=I18nObject(en_US=""), - ), - ParameterRule( - name="max_tokens", - use_template="max_tokens", - label=I18nObject(en_US="Max Tokens"), - type=ParameterType.INT, - help=I18nObject(en_US="", zh_Hans=""), - ), - ParameterRule(name="custom", label=I18nObject(en_US="Custom"), type=ParameterType.STRING), - ], - ) - - with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema): - schema = ai_model.get_customizable_model_schema_from_credentials("test_model", {}) - - assert schema is not None - assert schema.parameter_rules[0].max == 1.0 - assert schema.parameter_rules[1].help is not None - assert schema.parameter_rules[1].help.en_US != "" - assert schema.parameter_rules[2].help is not None - assert schema.parameter_rules[2].help.zh_Hans != "" - assert schema.parameter_rules[3].use_template is None - - def test_get_customizable_model_schema_from_credentials_none(self, ai_model: AIModel) -> None: - with patch.object(AIModel, "get_customizable_model_schema", return_value=None): - schema = ai_model.get_customizable_model_schema_from_credentials("model", {}) - assert schema is None - - def test_get_customizable_model_schema_default(self, ai_model: AIModel) -> None: - assert ai_model.get_customizable_model_schema("model", {}) is None - - def test_get_default_parameter_rule_variable_map(self, ai_model: AIModel) -> None: - result = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE) - assert result["default"] == 0.0 - - with pytest.raises(Exception, match="Invalid model parameter rule name"): - ai_model._get_default_parameter_rule_variable_map("invalid_name") diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_large_language_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_large_language_model.py deleted file mode 100644 index 668a7e34767..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_large_language_model.py +++ /dev/null @@ -1,452 +0,0 @@ -import logging -from collections.abc import Generator, Iterator, 
Sequence -from dataclasses import dataclass, field -from decimal import Decimal -from types import SimpleNamespace -from typing import Any -from unittest.mock import MagicMock - -import pytest - -import graphon.model_runtime.model_providers.__base.large_language_model as llm_module - -# Access large_language_model members via llm_module to avoid partial import issues in CI -from graphon.model_runtime.callbacks.base_callback import Callback -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMUsage, -) -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelType, PriceInfo -from graphon.model_runtime.model_providers.__base.large_language_model import _build_llm_result_from_chunks - - -def _usage(prompt_tokens: int = 1, completion_tokens: int = 2) -> LLMUsage: - return LLMUsage( - prompt_tokens=prompt_tokens, - prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal(1), - prompt_price=Decimal(prompt_tokens) * Decimal("0.001"), - completion_tokens=completion_tokens, - completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal(1), - completion_price=Decimal(completion_tokens) * Decimal("0.002"), - total_tokens=prompt_tokens + completion_tokens, - total_price=Decimal(prompt_tokens) * Decimal("0.001") + Decimal(completion_tokens) * Decimal("0.002"), - currency="USD", - latency=0.0, - ) - - -def _tool_call_delta( - *, - tool_call_id: str, - tool_type: str = "function", - function_name: str = "", - function_arguments: str = "", -) -> AssistantPromptMessage.ToolCall: - return AssistantPromptMessage.ToolCall( - id=tool_call_id, - type=tool_type, - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=function_name, arguments=function_arguments), - ) - - -def _chunk( - *, - model: str = "test-model", - content: str | list[Any] | None = None, - tool_calls: list[AssistantPromptMessage.ToolCall] | None = None, - usage: LLMUsage | None = None, - system_fingerprint: str | None = None, -) -> LLMResultChunk: - return LLMResultChunk( - model=model, - system_fingerprint=system_fingerprint, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=content, tool_calls=tool_calls or []), - usage=usage, - ), - ) - - -@dataclass -class SpyCallback(Callback): - raise_error: bool = False - before: list[dict[str, Any]] = field(default_factory=list) - new_chunk: list[dict[str, Any]] = field(default_factory=list) - after: list[dict[str, Any]] = field(default_factory=list) - error: list[dict[str, Any]] = field(default_factory=list) - - def on_before_invoke(self, **kwargs: Any) -> None: # type: ignore[override] - self.before.append(kwargs) - - def on_new_chunk(self, **kwargs: Any) -> None: # type: ignore[override] - self.new_chunk.append(kwargs) - - def on_after_invoke(self, **kwargs: Any) -> None: # type: ignore[override] - self.after.append(kwargs) - - def on_invoke_error(self, **kwargs: Any) -> None: # type: ignore[override] - self.error.append(kwargs) - - -class _TestLLM(llm_module.LargeLanguageModel): - def get_price(self, model: str, credentials: dict, price_type: Any, tokens: int) -> PriceInfo: # type: ignore[override] - return PriceInfo( - unit_price=Decimal("0.01"), - unit=Decimal(1), - total_amount=Decimal(tokens) * Decimal("0.01"), - currency="USD", - ) - - def _transform_invoke_error(self, error: Exception) -> Exception: # type: 
ignore[override] - return RuntimeError(f"transformed: {error}") - - -@pytest.fixture -def llm() -> _TestLLM: - provider_schema = SimpleNamespace(provider="provider", label=SimpleNamespace(en_US="Provider")) - model_runtime = MagicMock() - model_runtime.get_llm_num_tokens.return_value = 0 - return _TestLLM(provider_schema=provider_schema, model_runtime=model_runtime, started_at=1.0) - - -def test_gen_tool_call_id_is_uuid_based(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="abc123")) - assert llm_module._gen_tool_call_id() == "chatcmpl-tool-abc123" - - -def test_run_callbacks_no_callbacks_noop() -> None: - invoked: list[int] = [] - llm_module._run_callbacks(None, event="x", invoke=lambda _: invoked.append(1)) - llm_module._run_callbacks([], event="x", invoke=lambda _: invoked.append(1)) - assert invoked == [] - - -def test_run_callbacks_swallows_error_when_raise_error_false(caplog: pytest.LogCaptureFixture) -> None: - class Boom: - raise_error = False - - caplog.set_level(logging.WARNING) - llm_module._run_callbacks( - [Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom")) - ) - assert any("Callback" in record.message and "failed with error" in record.message for record in caplog.records) - - -def test_run_callbacks_reraises_when_raise_error_true() -> None: - class Boom: - raise_error = True - - with pytest.raises(ValueError, match="boom"): - llm_module._run_callbacks( - [Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom")) - ) - - -def test_get_or_create_tool_call_empty_id_returns_last() -> None: - calls = [ - _tool_call_delta(tool_call_id="id1", function_name="a"), - _tool_call_delta(tool_call_id="id2", function_name="b"), - ] - assert llm_module._get_or_create_tool_call(calls, "") is calls[-1] - - -def test_get_or_create_tool_call_empty_id_without_existing_raises() -> None: - with pytest.raises(ValueError, match="tool_call_id is empty"): - llm_module._get_or_create_tool_call([], "") - - -def test_get_or_create_tool_call_creates_if_missing() -> None: - calls: list[AssistantPromptMessage.ToolCall] = [] - tool_call = llm_module._get_or_create_tool_call(calls, "new-id") - assert tool_call.id == "new-id" - assert tool_call.function.name == "" - assert tool_call.function.arguments == "" - assert calls == [tool_call] - - -def test_get_or_create_tool_call_returns_existing_when_found() -> None: - existing = _tool_call_delta(tool_call_id="same-id", function_name="fn", function_arguments="{}") - calls = [existing] - assert llm_module._get_or_create_tool_call(calls, "same-id") is existing - - -def test_merge_tool_call_delta_updates_fields_and_appends_arguments() -> None: - tool_call = _tool_call_delta(tool_call_id="id", tool_type="function", function_name="x", function_arguments="{") - delta = _tool_call_delta(tool_call_id="id2", tool_type="function", function_name="y", function_arguments="}") - llm_module._merge_tool_call_delta(tool_call, delta) - assert tool_call.id == "id2" - assert tool_call.type == "function" - assert tool_call.function.name == "y" - assert tool_call.function.arguments == "{}" - - -def test_increase_tool_call_generates_id_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="fixed")) - delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{") - existing: list[AssistantPromptMessage.ToolCall] = [] - 
llm_module._increase_tool_call([delta], existing) - assert len(existing) == 1 - assert existing[0].id == "chatcmpl-tool-fixed" - assert existing[0].function.name == "fn" - assert existing[0].function.arguments == "{" - - -def test_increase_tool_call_merges_incremental_arguments() -> None: - existing: list[AssistantPromptMessage.ToolCall] = [] - llm_module._increase_tool_call( - [_tool_call_delta(tool_call_id="id", function_name="fn", function_arguments="{")], existing - ) - llm_module._increase_tool_call( - [_tool_call_delta(tool_call_id="id", function_name="", function_arguments="}")], existing - ) - assert len(existing) == 1 - assert existing[0].function.name == "fn" - assert existing[0].function.arguments == "{}" - - -@pytest.mark.parametrize( - ("content", "expected_type"), - [ - ("hello", str), - ([TextPromptMessageContent(data="hello")], list), - ], -) -def test_build_llm_result_from_chunks_accumulates_and_raises_error( - content: str | list[TextPromptMessageContent], - expected_type: type, - monkeypatch: pytest.MonkeyPatch, - caplog: pytest.LogCaptureFixture, -) -> None: - monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="drain")) - caplog.set_level(logging.DEBUG) - - tool_delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{}") - first = _chunk(content=content, tool_calls=[tool_delta], usage=_usage(3, 4), system_fingerprint="fp1") - - def iter_with_error() -> Iterator[LLMResultChunk]: - yield first - raise RuntimeError("drain boom") - - with pytest.raises(RuntimeError, match="drain boom"): - _build_llm_result_from_chunks( - model="m", prompt_messages=[UserPromptMessage(content="u")], chunks=iter_with_error() - ) - - assert any("Error while consuming non-stream plugin chunk iterator" in record.message for record in caplog.records) - - -def test_build_llm_result_from_chunks_empty_iterator() -> None: - def empty() -> Iterator[LLMResultChunk]: - if False: # pragma: no cover - yield _chunk() - return - - result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=empty()) - assert result.message.content == [] - assert result.usage.total_tokens == 0 - assert result.system_fingerprint is None - - -def test_build_llm_result_from_chunks_accumulates_all_chunks() -> None: - chunks = iter([_chunk(content="first"), _chunk(content="second")]) - result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=chunks) - assert result.message.content == "firstsecond" - - -def test_invoke_llm_via_runtime_passes_list_converted_stop(llm: _TestLLM) -> None: - llm.model_runtime = MagicMock() - prompt_messages: Sequence[PromptMessage] = (UserPromptMessage(content="hi"),) - result = llm_module._invoke_llm_via_runtime( - llm_model=llm, - provider="prov", - model="m", - credentials={"k": "v"}, - model_parameters={"temp": 1}, - prompt_messages=prompt_messages, - tools=None, - stop=("a", "b"), - stream=True, - ) - - llm.model_runtime.invoke_llm.assert_called_once_with( - provider="prov", - model="m", - credentials={"k": "v"}, - model_parameters={"temp": 1}, - prompt_messages=list(prompt_messages), - tools=None, - stop=("a", "b"), - stream=True, - ) - assert result is llm.model_runtime.invoke_llm.return_value - - -def test_normalize_non_stream_runtime_result_passthrough_llmresult() -> None: - llm_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()) - assert ( - llm_module._normalize_non_stream_runtime_result(model="m", prompt_messages=[], result=llm_result) is llm_result - ) - - -def 
test_normalize_non_stream_runtime_result_builds_from_chunks() -> None: - chunks = iter([_chunk(content="hello", usage=_usage(1, 1))]) - result = llm_module._normalize_non_stream_runtime_result( - model="m", prompt_messages=[UserPromptMessage(content="u")], result=chunks - ) - assert isinstance(result, LLMResult) - assert result.message.content == "hello" - - -def test_invoke_non_stream_normalizes_and_sets_prompt_messages(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: - plugin_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()) - monkeypatch.setattr( - "graphon.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime", - lambda **_: plugin_result, - ) - cb = SpyCallback() - prompt_messages = [UserPromptMessage(content="hi")] - result = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=False, callbacks=[cb]) - assert isinstance(result, LLMResult) - assert result.prompt_messages == prompt_messages - assert len(cb.before) == 1 - assert len(cb.after) == 1 - assert cb.after[0]["result"].prompt_messages == prompt_messages - - -def test_invoke_stream_wraps_generator_and_triggers_callbacks(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: - plugin_chunks = iter( - [ - _chunk(model="m1", content="a"), - _chunk( - model="m2", content=[TextPromptMessageContent(data="b")], usage=_usage(2, 3), system_fingerprint="fp" - ), - _chunk(model="m3", content=None), - ] - ) - monkeypatch.setattr( - "graphon.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime", - lambda **_: plugin_chunks, - ) - - cb = SpyCallback() - prompt_messages = [UserPromptMessage(content="hi")] - gen = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=True, callbacks=[cb]) - - assert isinstance(gen, Generator) - chunks = list(gen) - assert len(chunks) == 3 - assert all(chunk.prompt_messages == prompt_messages for chunk in chunks) - assert len(cb.before) == 1 - assert len(cb.new_chunk) == 3 - assert len(cb.after) == 1 - final_result: LLMResult = cb.after[0]["result"] - assert final_result.model == "m3" - assert final_result.system_fingerprint == "fp" - assert isinstance(final_result.message.content, list) - assert [c.data for c in final_result.message.content] == ["a", "b"] - assert final_result.usage.total_tokens == 5 - - -def test_invoke_triggers_error_callbacks_and_raises_transformed(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: - def boom(**_: Any) -> Any: - raise ValueError("plugin down") - - monkeypatch.setattr( - "graphon.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime", boom - ) - cb = SpyCallback() - with pytest.raises(RuntimeError, match="transformed: plugin down"): - llm.invoke( - model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False, callbacks=[cb] - ) - assert len(cb.error) == 1 - assert isinstance(cb.error[0]["ex"], ValueError) - - -def test_invoke_raises_not_implemented_for_unsupported_result_type( - llm: _TestLLM, monkeypatch: pytest.MonkeyPatch -) -> None: - monkeypatch.setattr(llm_module, "_invoke_llm_via_runtime", lambda **_: "not-a-result") - monkeypatch.setattr(llm_module, "_normalize_non_stream_runtime_result", lambda **_: "not-a-result") - with pytest.raises(NotImplementedError, match="unsupported invoke result type"): - llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False) - - -def 
test_invoke_appends_logging_callback_in_debug(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: - captured_callbacks: list[list[Callback]] = [] - - class FakeLoggingCallback(SpyCallback): - pass - - monkeypatch.setattr(llm_module, "LoggingCallback", FakeLoggingCallback) - monkeypatch.setattr(llm_module.logger, "isEnabledFor", lambda level: level == logging.DEBUG) - monkeypatch.setattr( - "graphon.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime", - lambda **_: LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()), - ) - - original_trigger = llm._trigger_before_invoke_callbacks - - def spy_trigger(*args: Any, **kwargs: Any) -> None: - captured_callbacks.append(list(kwargs["callbacks"])) - original_trigger(*args, **kwargs) - - monkeypatch.setattr(llm, "_trigger_before_invoke_callbacks", spy_trigger) - llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False) - assert any(isinstance(cb, FakeLoggingCallback) for cb in captured_callbacks[0]) - - -def test_get_num_tokens_returns_0_when_runtime_returns_0(llm: _TestLLM) -> None: - llm.model_runtime.get_llm_num_tokens.return_value = 0 - assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 0 - - -def test_get_num_tokens_uses_runtime(llm: _TestLLM) -> None: - llm.model_runtime.get_llm_num_tokens.return_value = 42 - assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 42 - llm.model_runtime.get_llm_num_tokens.assert_called_once_with( - provider="provider", - model_type=ModelType.LLM, - model="m", - credentials={}, - prompt_messages=[UserPromptMessage(content="x")], - tools=None, - ) - - -def test_calc_response_usage_uses_prices_and_latency(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(llm_module.time, "perf_counter", lambda: 4.5) - llm.started_at = 1.0 - usage = llm.calc_response_usage(model="m", credentials={}, prompt_tokens=10, completion_tokens=5) - assert usage.total_tokens == 15 - assert usage.total_price == Decimal("0.15") - assert usage.latency == 3.5 - - -def test_invoke_result_generator_raises_transformed_on_iteration_error(llm: _TestLLM) -> None: - def broken() -> Iterator[LLMResultChunk]: - yield _chunk(content="ok") - raise ValueError("chunk stream broken") - - gen = llm._invoke_result_generator( - model="m", - result=broken(), - credentials={}, - prompt_messages=[UserPromptMessage(content="u")], - model_parameters={}, - callbacks=[SpyCallback()], - ) - - with pytest.raises(RuntimeError, match="transformed: chunk stream broken"): - list(gen) diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_moderation_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_moderation_model.py deleted file mode 100644 index a42a9308066..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_moderation_model.py +++ /dev/null @@ -1,56 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.errors.invoke import InvokeError -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel - - -@pytest.fixture -def 
provider_schema() -> ProviderEntity: - return ProviderEntity( - provider="test_provider", - label=I18nObject(en_US="test_provider"), - supported_model_types=[ModelType.MODERATION], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -@pytest.fixture -def model_runtime() -> MagicMock: - return MagicMock() - - -@pytest.fixture -def moderation_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> ModerationModel: - return ModerationModel(provider_schema=provider_schema, model_runtime=model_runtime) - - -def test_model_type(moderation_model: ModerationModel) -> None: - assert moderation_model.model_type == ModelType.MODERATION - - -def test_invoke_success(moderation_model: ModerationModel, model_runtime: MagicMock) -> None: - with patch("time.perf_counter", return_value=1.0): - model_runtime.invoke_moderation.return_value = True - - result = moderation_model.invoke(model="test_model", credentials={"api_key": "abc"}, text="test text") - - assert result is True - assert moderation_model.started_at == 1.0 - model_runtime.invoke_moderation.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - text="test text", - ) - - -def test_invoke_exception(moderation_model: ModerationModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_moderation.side_effect = Exception("Test error") - - with pytest.raises(InvokeError, match="Test error"): - moderation_model.invoke(model="test_model", credentials={"api_key": "abc"}, text="test text") diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_rerank_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_rerank_model.py deleted file mode 100644 index 9650ed2db75..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_rerank_model.py +++ /dev/null @@ -1,110 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult -from graphon.model_runtime.errors.invoke import InvokeError -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel - - -@pytest.fixture -def provider_schema() -> ProviderEntity: - return ProviderEntity( - provider="test_provider", - label=I18nObject(en_US="test_provider"), - supported_model_types=[ModelType.RERANK], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -@pytest.fixture -def model_runtime() -> MagicMock: - return MagicMock() - - -@pytest.fixture -def rerank_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> RerankModel: - return RerankModel(provider_schema=provider_schema, model_runtime=model_runtime) - - -def test_model_type_is_rerank_by_default(rerank_model: RerankModel) -> None: - assert rerank_model.model_type == ModelType.RERANK - - -def test_invoke_calls_runtime_and_passes_args(rerank_model: RerankModel, model_runtime: MagicMock) -> None: - expected = RerankResult(model="rerank", docs=[RerankDocument(index=0, text="a", score=0.5)]) - model_runtime.invoke_rerank.return_value = expected - - result = rerank_model.invoke( - model="rerank", - credentials={"k": "v"}, - query="q", - docs=["d1", "d2"], - score_threshold=0.2, - top_n=10, - ) - - assert result == expected - 
model_runtime.invoke_rerank.assert_called_once_with( - provider="test_provider", - model="rerank", - credentials={"k": "v"}, - query="q", - docs=["d1", "d2"], - score_threshold=0.2, - top_n=10, - ) - - -def test_invoke_transforms_and_raises_on_runtime_error(rerank_model: RerankModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_rerank.side_effect = Exception("runtime down") - - with pytest.raises(InvokeError, match="runtime down"): - rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"]) - - -def test_invoke_multimodal_calls_runtime_and_passes_args(rerank_model: RerankModel, model_runtime: MagicMock) -> None: - expected = RerankResult(model="mm", docs=[RerankDocument(index=0, text="x", score=0.9)]) - model_runtime.invoke_multimodal_rerank.return_value = expected - - query = {"type": "text", "text": "q"} - docs = [{"type": "text", "text": "d1"}] - result = rerank_model.invoke_multimodal_rerank( - model="mm", - credentials={"k": "v"}, - query=query, - docs=docs, - score_threshold=None, - top_n=None, - ) - - assert result == expected - model_runtime.invoke_multimodal_rerank.assert_called_once_with( - provider="test_provider", - model="mm", - credentials={"k": "v"}, - query=query, - docs=docs, - score_threshold=None, - top_n=None, - ) - - -def test_invoke_multimodal_transforms_and_raises_on_runtime_error( - rerank_model: RerankModel, model_runtime: MagicMock -) -> None: - model_runtime.invoke_multimodal_rerank.side_effect = Exception("multimodal runtime down") - - query = {"content": "q", "content_type": "text"} - docs = [{"content": "d1", "content_type": "text"}] - - with pytest.raises(InvokeError, match="multimodal runtime down"): - rerank_model.invoke_multimodal_rerank( - model="mm", - credentials={}, - query=query, - docs=docs, - ) diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_runtime_user_forwarding.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_runtime_user_forwarding.py deleted file mode 100644 index 98bb1eb1b87..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_runtime_user_forwarding.py +++ /dev/null @@ -1,170 +0,0 @@ -from decimal import Decimal -from io import BytesIO -from unittest.mock import MagicMock - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.entities.rerank_entities import RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel - - -def _provider_schema(model_type: ModelType) -> ProviderEntity: - return ProviderEntity( - provider="langgenius/openai/openai", - 
provider_name="openai", - label=I18nObject(en_US="OpenAI"), - supported_model_types=[model_type], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -def _embedding_usage() -> EmbeddingUsage: - return EmbeddingUsage( - tokens=1, - total_tokens=1, - unit_price=Decimal(0), - price_unit=Decimal(0), - total_price=Decimal(0), - currency="USD", - latency=0.0, - ) - - -def test_large_language_model_invokes_runtime_without_user_id() -> None: - runtime = MagicMock() - runtime.invoke_llm.return_value = LLMResult( - model="gpt-4o-mini", - prompt_messages=[], - message=AssistantPromptMessage(content="ok"), - usage=LLMUsage.empty_usage(), - ) - model = LargeLanguageModel(provider_schema=_provider_schema(ModelType.LLM), model_runtime=runtime) - - model.invoke( - model="gpt-4o-mini", - credentials={"api_key": "secret"}, - prompt_messages=[UserPromptMessage(content="hi")], - stream=False, - ) - - assert "user_id" not in runtime.invoke_llm.call_args.kwargs - - -def test_text_embedding_model_invokes_runtime_without_user_id_for_text_requests() -> None: - runtime = MagicMock() - runtime.invoke_text_embedding.return_value = EmbeddingResult( - model="text-embedding-3-small", - embeddings=[[0.1]], - usage=_embedding_usage(), - ) - model = TextEmbeddingModel(provider_schema=_provider_schema(ModelType.TEXT_EMBEDDING), model_runtime=runtime) - - model.invoke( - model="text-embedding-3-small", - credentials={"api_key": "secret"}, - texts=["hello"], - ) - - assert "user_id" not in runtime.invoke_text_embedding.call_args.kwargs - - -def test_text_embedding_model_invokes_runtime_without_user_id_for_multimodal_requests() -> None: - runtime = MagicMock() - runtime.invoke_multimodal_embedding.return_value = EmbeddingResult( - model="text-embedding-3-small", - embeddings=[[0.1]], - usage=_embedding_usage(), - ) - model = TextEmbeddingModel(provider_schema=_provider_schema(ModelType.TEXT_EMBEDDING), model_runtime=runtime) - - model.invoke( - model="text-embedding-3-small", - credentials={"api_key": "secret"}, - multimodel_documents=[{"content": "hello", "content_type": "text"}], - ) - - assert "user_id" not in runtime.invoke_multimodal_embedding.call_args.kwargs - - -def test_rerank_model_invokes_runtime_without_user_id_for_text_requests() -> None: - runtime = MagicMock() - runtime.invoke_rerank.return_value = RerankResult(model="rerank", docs=[]) - model = RerankModel(provider_schema=_provider_schema(ModelType.RERANK), model_runtime=runtime) - - model.invoke( - model="rerank", - credentials={"api_key": "secret"}, - query="q", - docs=["d1"], - ) - - assert "user_id" not in runtime.invoke_rerank.call_args.kwargs - - -def test_rerank_model_invokes_runtime_without_user_id_for_multimodal_requests() -> None: - runtime = MagicMock() - runtime.invoke_multimodal_rerank.return_value = RerankResult(model="rerank", docs=[]) - model = RerankModel(provider_schema=_provider_schema(ModelType.RERANK), model_runtime=runtime) - - model.invoke_multimodal_rerank( - model="rerank", - credentials={"api_key": "secret"}, - query={"content": "q", "content_type": "text"}, - docs=[{"content": "d1", "content_type": "text"}], - ) - - assert "user_id" not in runtime.invoke_multimodal_rerank.call_args.kwargs - - -def test_tts_model_invokes_runtime_without_user_id() -> None: - runtime = MagicMock() - runtime.invoke_tts.return_value = [b"chunk"] - model = TTSModel(provider_schema=_provider_schema(ModelType.TTS), model_runtime=runtime) - - list( - model.invoke( - model="tts-1", - credentials={"api_key": "secret"}, - content_text="hello", - 
voice="alloy", - ) - ) - - assert "user_id" not in runtime.invoke_tts.call_args.kwargs - - -def test_speech_to_text_model_invokes_runtime_without_user_id() -> None: - runtime = MagicMock() - runtime.invoke_speech_to_text.return_value = "transcript" - model = Speech2TextModel(provider_schema=_provider_schema(ModelType.SPEECH2TEXT), model_runtime=runtime) - - model.invoke( - model="whisper-1", - credentials={"api_key": "secret"}, - file=BytesIO(b"audio"), - ) - - assert "user_id" not in runtime.invoke_speech_to_text.call_args.kwargs - - -def test_moderation_model_invokes_runtime_without_user_id() -> None: - runtime = MagicMock() - runtime.invoke_moderation.return_value = True - model = ModerationModel(provider_schema=_provider_schema(ModelType.MODERATION), model_runtime=runtime) - - model.invoke( - model="omni-moderation-latest", - credentials={"api_key": "secret"}, - text="unsafe?", - ) - - assert "user_id" not in runtime.invoke_moderation.call_args.kwargs diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_speech2text_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_speech2text_model.py deleted file mode 100644 index b03923bbc22..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_speech2text_model.py +++ /dev/null @@ -1,56 +0,0 @@ -from io import BytesIO -from unittest.mock import MagicMock - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.errors.invoke import InvokeError -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel - - -@pytest.fixture -def provider_schema() -> ProviderEntity: - return ProviderEntity( - provider="test_provider", - label=I18nObject(en_US="test_provider"), - supported_model_types=[ModelType.SPEECH2TEXT], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -@pytest.fixture -def model_runtime() -> MagicMock: - return MagicMock() - - -@pytest.fixture -def speech2text_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> Speech2TextModel: - return Speech2TextModel(provider_schema=provider_schema, model_runtime=model_runtime) - - -def test_model_type(speech2text_model: Speech2TextModel) -> None: - assert speech2text_model.model_type == ModelType.SPEECH2TEXT - - -def test_invoke_success(speech2text_model: Speech2TextModel, model_runtime: MagicMock) -> None: - file = BytesIO(b"audio data") - model_runtime.invoke_speech_to_text.return_value = "transcribed text" - - result = speech2text_model.invoke(model="test_model", credentials={"api_key": "abc"}, file=file) - - assert result == "transcribed text" - model_runtime.invoke_speech_to_text.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - file=file, - ) - - -def test_invoke_exception(speech2text_model: Speech2TextModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_speech_to_text.side_effect = Exception("Test error") - - with pytest.raises(InvokeError, match="Test error"): - speech2text_model.invoke(model="test_model", credentials={"api_key": "abc"}, file=BytesIO(b"audio data")) diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_text_embedding_model.py 
b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_text_embedding_model.py deleted file mode 100644 index 64caf3a3157..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_text_embedding_model.py +++ /dev/null @@ -1,146 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult -from graphon.model_runtime.errors.invoke import InvokeError -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel - - -@pytest.fixture -def provider_schema() -> ProviderEntity: - return ProviderEntity( - provider="test_provider", - label=I18nObject(en_US="test_provider"), - supported_model_types=[ModelType.TEXT_EMBEDDING], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -@pytest.fixture -def model_runtime() -> MagicMock: - return MagicMock() - - -@pytest.fixture -def text_embedding_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> TextEmbeddingModel: - return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=model_runtime) - - -def test_model_type(text_embedding_model: TextEmbeddingModel) -> None: - assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING - - -def test_invoke_with_texts(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: - expected_result = MagicMock(spec=EmbeddingResult) - model_runtime.invoke_text_embedding.return_value = expected_result - - result = text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}, texts=["hello", "world"]) - - assert result == expected_result - model_runtime.invoke_text_embedding.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - texts=["hello", "world"], - input_type=EmbeddingInputType.DOCUMENT, - ) - - -def test_invoke_with_multimodal_documents(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: - expected_result = MagicMock(spec=EmbeddingResult) - model_runtime.invoke_multimodal_embedding.return_value = expected_result - - result = text_embedding_model.invoke( - model="test_model", - credentials={"api_key": "abc"}, - multimodel_documents=[{"type": "text", "text": "hello"}], - ) - - assert result == expected_result - model_runtime.invoke_multimodal_embedding.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - documents=[{"type": "text", "text": "hello"}], - input_type=EmbeddingInputType.DOCUMENT, - ) - - -def test_invoke_no_input(text_embedding_model: TextEmbeddingModel) -> None: - with pytest.raises(ValueError, match="No texts or files provided"): - text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}) - - -def test_invoke_prefers_texts_over_multimodal_documents( - text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock -) -> None: - expected_result = MagicMock(spec=EmbeddingResult) - model_runtime.invoke_text_embedding.return_value = expected_result - - result = text_embedding_model.invoke( - model="test_model", - credentials={"api_key": "abc"}, - texts=["hello"], - multimodel_documents=[{"type": "text", "text": 
"world"}], - ) - - assert result == expected_result - model_runtime.invoke_text_embedding.assert_called_once() - model_runtime.invoke_multimodal_embedding.assert_not_called() - - -def test_invoke_exception(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_text_embedding.side_effect = Exception("Test error") - - with pytest.raises(InvokeError, match="Test error"): - text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}, texts=["hello"]) - - -def test_get_num_tokens(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: - model_runtime.get_text_embedding_num_tokens.return_value = [1, 1] - - result = text_embedding_model.get_num_tokens( - model="test_model", credentials={"api_key": "abc"}, texts=["hello", "world"] - ) - - assert result == [1, 1] - model_runtime.get_text_embedding_num_tokens.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - texts=["hello", "world"], - ) - - -def test_get_context_size(text_embedding_model: TextEmbeddingModel) -> None: - mock_schema = MagicMock() - mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048} - - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 2048 - - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): - assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 1000 - - mock_schema.model_properties = {} - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 1000 - - -def test_get_max_chunks(text_embedding_model: TextEmbeddingModel) -> None: - mock_schema = MagicMock() - mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10} - - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 10 - - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): - assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 1 - - mock_schema.model_properties = {} - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 1 diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_tts_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_tts_model.py deleted file mode 100644 index d15efb69c30..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_tts_model.py +++ /dev/null @@ -1,83 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.errors.invoke import InvokeError -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel - - -@pytest.fixture -def provider_schema() -> ProviderEntity: - return ProviderEntity( - provider="test_provider", - label=I18nObject(en_US="test_provider"), - supported_model_types=[ModelType.TTS], - 
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -@pytest.fixture -def model_runtime() -> MagicMock: - return MagicMock() - - -@pytest.fixture -def tts_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> TTSModel: - return TTSModel(provider_schema=provider_schema, model_runtime=model_runtime) - - -def test_model_type(tts_model: TTSModel) -> None: - assert tts_model.model_type == ModelType.TTS - - -def test_invoke_success(tts_model: TTSModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_tts.return_value = [b"audio_chunk"] - - result = tts_model.invoke( - model="test_model", - credentials={"api_key": "abc"}, - content_text="Hello world", - voice="alloy", - ) - - assert list(result) == [b"audio_chunk"] - model_runtime.invoke_tts.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - content_text="Hello world", - voice="alloy", - ) - - -def test_invoke_exception(tts_model: TTSModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_tts.side_effect = Exception("Test error") - - with pytest.raises(InvokeError, match="Test error"): - tts_model.invoke( - model="test_model", - credentials={"api_key": "abc"}, - content_text="Hello world", - voice="alloy", - ) - - -def test_get_tts_model_voices(tts_model: TTSModel, model_runtime: MagicMock) -> None: - model_runtime.get_tts_model_voices.return_value = [{"name": "Voice1"}] - - result = tts_model.get_tts_model_voices( - model="test_model", - credentials={"api_key": "abc"}, - language="en-US", - ) - - assert result == [{"name": "Voice1"}] - model_runtime.get_tts_model_voices.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - language="en-US", - ) diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py deleted file mode 100644 index d4d3eeb18c3..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py +++ /dev/null @@ -1,96 +0,0 @@ -from unittest.mock import MagicMock, patch - -import graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer as gpt2_tokenizer_module -from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer - - -class TestGPT2Tokenizer: - def setup_method(self): - # Reset the global tokenizer before each test to ensure we test initialization - gpt2_tokenizer_module._tokenizer = None - - def test_get_encoder_tiktoken(self): - """ - Test that get_encoder successfully uses tiktoken when available. - """ - mock_encoding = MagicMock() - # Mock tiktoken to be sure it's used - with patch("tiktoken.get_encoding", return_value=mock_encoding) as mock_get_encoding: - encoder = GPT2Tokenizer.get_encoder() - assert encoder == mock_encoding - mock_get_encoding.assert_called_once_with("gpt2") - - # Verify singleton behavior within the same test - encoder2 = GPT2Tokenizer.get_encoder() - assert encoder2 is encoder - assert mock_get_encoding.call_count == 1 - - def test_get_encoder_tiktoken_fallback(self): - """ - Test that get_encoder falls back to transformers when tiktoken fails. 
- """ - # patch tiktoken.get_encoding to raise an exception - with patch("tiktoken.get_encoding", side_effect=Exception("Tiktoken failure")): - # patch transformers.GPT2Tokenizer - with patch("transformers.GPT2Tokenizer.from_pretrained") as mock_from_pretrained: - mock_transformer_tokenizer = MagicMock() - mock_from_pretrained.return_value = mock_transformer_tokenizer - - with patch( - "graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer.logger" - ) as mock_logger: - encoder = GPT2Tokenizer.get_encoder() - - assert encoder == mock_transformer_tokenizer - mock_from_pretrained.assert_called_once() - mock_logger.info.assert_called_once_with("Fallback to Transformers' GPT-2 tokenizer from tiktoken") - - def test_get_num_tokens(self): - """ - Test get_num_tokens returns the correct count. - """ - mock_encoder = MagicMock() - mock_encoder.encode.return_value = [1, 2, 3, 4, 5] - - with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder): - tokens_count = GPT2Tokenizer.get_num_tokens("test text") - assert tokens_count == 5 - mock_encoder.encode.assert_called_once_with("test text") - - def test_get_num_tokens_by_gpt2_direct(self): - """ - Test _get_num_tokens_by_gpt2 directly. - """ - mock_encoder = MagicMock() - mock_encoder.encode.return_value = [1, 2] - - with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder): - tokens_count = GPT2Tokenizer._get_num_tokens_by_gpt2("hello") - assert tokens_count == 2 - mock_encoder.encode.assert_called_once_with("hello") - - def test_get_encoder_already_initialized(self): - """ - Test that if _tokenizer is already set, it returns it immediately. - """ - mock_existing_tokenizer = MagicMock() - gpt2_tokenizer_module._tokenizer = mock_existing_tokenizer - - # Tiktoken should not be called if already initialized - with patch("tiktoken.get_encoding") as mock_get_encoding: - encoder = GPT2Tokenizer.get_encoder() - assert encoder == mock_existing_tokenizer - mock_get_encoding.assert_not_called() - - def test_get_encoder_thread_safety(self): - """ - Simple test to ensure the lock is used. 
- """ - mock_encoding = MagicMock() - with patch("tiktoken.get_encoding", return_value=mock_encoding): - # We patch the lock in the module - with patch("graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer._lock") as mock_lock: - encoder = GPT2Tokenizer.get_encoder() - assert encoder == mock_encoding - mock_lock.__enter__.assert_called_once() - mock_lock.__exit__.assert_called_once() diff --git a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_common_validator.py b/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_common_validator.py deleted file mode 100644 index 60ded4b90aa..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_common_validator.py +++ /dev/null @@ -1,201 +0,0 @@ -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.provider_entities import ( - CredentialFormSchema, - FormOption, - FormShowOnObject, - FormType, -) -from graphon.model_runtime.schema_validators.common_validator import CommonValidator - - -class TestCommonValidator: - def test_validate_credential_form_schema_required_missing(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True - ) - with pytest.raises(ValueError, match="Variable api_key is required"): - validator._validate_credential_form_schema(schema, {}) - - def test_validate_credential_form_schema_not_required_missing_with_default(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key"), - type=FormType.TEXT_INPUT, - required=False, - default="default_value", - ) - assert validator._validate_credential_form_schema(schema, {}) == "default_value" - - def test_validate_credential_form_schema_not_required_missing_no_default(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=False - ) - assert validator._validate_credential_form_schema(schema, {}) is None - - def test_validate_credential_form_schema_max_length_exceeded(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, max_length=5 - ) - with pytest.raises(ValueError, match="Variable api_key length should not be greater than 5"): - validator._validate_credential_form_schema(schema, {"api_key": "123456"}) - - def test_validate_credential_form_schema_not_string(self): - validator = CommonValidator() - schema = CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT) - with pytest.raises(ValueError, match="Variable api_key should be string"): - validator._validate_credential_form_schema(schema, {"api_key": 123}) - - def test_validate_credential_form_schema_select_invalid_option(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="mode", - label=I18nObject(en_US="Mode"), - type=FormType.SELECT, - options=[ - FormOption(label=I18nObject(en_US="Fast"), value="fast"), - FormOption(label=I18nObject(en_US="Slow"), value="slow"), - ], - ) - with pytest.raises(ValueError, match="Variable mode is not in options"): - validator._validate_credential_form_schema(schema, {"mode": "medium"}) - - def test_validate_credential_form_schema_select_valid_option(self): - validator = CommonValidator() - schema = 
CredentialFormSchema( - variable="mode", - label=I18nObject(en_US="Mode"), - type=FormType.SELECT, - options=[ - FormOption(label=I18nObject(en_US="Fast"), value="fast"), - FormOption(label=I18nObject(en_US="Slow"), value="slow"), - ], - ) - assert validator._validate_credential_form_schema(schema, {"mode": "fast"}) == "fast" - - def test_validate_credential_form_schema_switch_invalid(self): - validator = CommonValidator() - schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH) - with pytest.raises(ValueError, match="Variable enabled should be true or false"): - validator._validate_credential_form_schema(schema, {"enabled": "maybe"}) - - def test_validate_credential_form_schema_switch_valid(self): - validator = CommonValidator() - schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH) - assert validator._validate_credential_form_schema(schema, {"enabled": "true"}) is True - assert validator._validate_credential_form_schema(schema, {"enabled": "FALSE"}) is False - - def test_validate_and_filter_credential_form_schemas_with_show_on(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema( - variable="auth_type", - label=I18nObject(en_US="Auth Type"), - type=FormType.SELECT, - options=[ - FormOption(label=I18nObject(en_US="API Key"), value="api_key"), - FormOption(label=I18nObject(en_US="OAuth"), value="oauth"), - ], - ), - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="auth_type", value="api_key")], - ), - CredentialFormSchema( - variable="client_id", - label=I18nObject(en_US="Client ID"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="auth_type", value="oauth")], - ), - ] - - # Case 1: auth_type = api_key - credentials = {"auth_type": "api_key", "api_key": "my_secret"} - result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) - assert "auth_type" in result - assert "api_key" in result - assert "client_id" not in result - assert result["api_key"] == "my_secret" - - # Case 2: auth_type = oauth - credentials = {"auth_type": "oauth", "client_id": "my_client"} - result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) - # Note: 'auth_type' contains 'oauth'. 'result' contains keys that pass validation. - # Since 'oauth' is not an empty string, it is in result. 
- assert "auth_type" in result - assert "api_key" not in result - assert "client_id" in result - assert result["client_id"] == "my_client" - - def test_validate_and_filter_show_on_missing_variable(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="auth_type", value="api_key")], - ) - ] - # auth_type is missing in credentials, so api_key should be filtered out - result = validator._validate_and_filter_credential_form_schemas(schemas, {}) - assert result == {} - - def test_validate_and_filter_show_on_mismatch_value(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="auth_type", value="api_key")], - ) - ] - # auth_type is oauth, which doesn't match show_on - result = validator._validate_and_filter_credential_form_schemas(schemas, {"auth_type": "oauth"}) - assert result == {} - - def test_validate_and_filter_multiple_show_on(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema( - variable="target", - label=I18nObject(en_US="Target"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="v1", value="a"), FormShowOnObject(variable="v2", value="b")], - ) - ] - # Both match - assert "target" in validator._validate_and_filter_credential_form_schemas( - schemas, {"v1": "a", "v2": "b", "target": "val"} - ) - # One mismatch - assert "target" not in validator._validate_and_filter_credential_form_schemas( - schemas, {"v1": "a", "v2": "c", "target": "val"} - ) - # One missing - assert "target" not in validator._validate_and_filter_credential_form_schemas( - schemas, {"v1": "a", "target": "val"} - ) - - def test_validate_and_filter_skips_falsy_results(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH), - CredentialFormSchema( - variable="empty_str", label=I18nObject(en_US="Empty"), type=FormType.TEXT_INPUT, required=False - ), - ] - # Result of false switch is False. if result: is false. Not added. - # Result of empty string is "", if result: is false. Not added. 
- credentials = {"enabled": "false", "empty_str": ""} - result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) - assert "enabled" not in result - assert "empty_str" not in result diff --git a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_model_credential_schema_validator.py b/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_model_credential_schema_validator.py deleted file mode 100644 index 3932844b918..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_model_credential_schema_validator.py +++ /dev/null @@ -1,233 +0,0 @@ -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - CredentialFormSchema, - FieldModelSchema, - FormOption, - FormShowOnObject, - FormType, - ModelCredentialSchema, -) -from graphon.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator - - -def test_validate_and_filter_with_none_schema(): - validator = ModelCredentialSchemaValidator(ModelType.LLM, None) - with pytest.raises(ValueError, match="Model credential schema is None"): - validator.validate_and_filter({}) - - -def test_validate_and_filter_success(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="ๆจกๅž‹")), - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key", zh_Hans="API Key"), - type=FormType.SECRET_INPUT, - required=True, - ), - CredentialFormSchema( - variable="optional_field", - label=I18nObject(en_US="Optional", zh_Hans="ๅฏ้€‰"), - type=FormType.TEXT_INPUT, - required=False, - default="default_val", - ), - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - credentials = {"api_key": "sk-123456"} - result = validator.validate_and_filter(credentials) - - assert result["api_key"] == "sk-123456" - assert result["optional_field"] == "default_val" - assert credentials["__model_type"] == ModelType.LLM.value - - -def test_validate_and_filter_with_show_on(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="ๆจกๅž‹")), - credential_form_schemas=[ - CredentialFormSchema( - variable="mode", label=I18nObject(en_US="Mode", zh_Hans="ๆจกๅผ"), type=FormType.TEXT_INPUT, required=True - ), - CredentialFormSchema( - variable="conditional_field", - label=I18nObject(en_US="Conditional", zh_Hans="ๆกไปถ"), - type=FormType.TEXT_INPUT, - required=True, - show_on=[FormShowOnObject(variable="mode", value="advanced")], - ), - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - # mode is 'simple', conditional_field should be filtered out - credentials = {"mode": "simple", "conditional_field": "secret"} - result = validator.validate_and_filter(credentials) - assert "conditional_field" not in result - assert result["mode"] == "simple" - - # mode is 'advanced', conditional_field should be kept - credentials = {"mode": "advanced", "conditional_field": "secret"} - result = validator.validate_and_filter(credentials) - assert result["conditional_field"] == "secret" - assert result["mode"] == "advanced" - - # show_on variable missing in credentials - credentials = {"conditional_field": "secret"} # mode missing - with pytest.raises(ValueError, match="Variable mode is required"): # because mode is required in schema - 
validator.validate_and_filter(credentials) - - -def test_validate_and_filter_show_on_missing_trigger_var(): - # specifically test all_show_on_match = False when variable not in credentials - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="ๆจกๅž‹")), - credential_form_schemas=[ - CredentialFormSchema( - variable="optional_trigger", - label=I18nObject(en_US="Optional Trigger", zh_Hans="ๅฏ้€‰่งฆๅ‘"), - type=FormType.TEXT_INPUT, - required=False, - ), - CredentialFormSchema( - variable="conditional_field", - label=I18nObject(en_US="Conditional", zh_Hans="ๆกไปถ"), - type=FormType.TEXT_INPUT, - required=False, - show_on=[FormShowOnObject(variable="optional_trigger", value="active")], - ), - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - # optional_trigger missing, conditional_field should be skipped - result = validator.validate_and_filter({"conditional_field": "val"}) - assert "conditional_field" not in result - - -def test_common_validator_logic_required(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="ๆจกๅž‹")), - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key", zh_Hans="API Key"), - type=FormType.SECRET_INPUT, - required=True, - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - with pytest.raises(ValueError, match="Variable api_key is required"): - validator.validate_and_filter({}) - - with pytest.raises(ValueError, match="Variable api_key is required"): - validator.validate_and_filter({"api_key": ""}) - - -def test_common_validator_logic_max_length(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="ๆจกๅž‹")), - credential_form_schemas=[ - CredentialFormSchema( - variable="key", - label=I18nObject(en_US="Key", zh_Hans="Key"), - type=FormType.TEXT_INPUT, - required=True, - max_length=5, - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - with pytest.raises(ValueError, match="Variable key length should not be greater than 5"): - validator.validate_and_filter({"key": "123456"}) - - -def test_common_validator_logic_invalid_type(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="ๆจกๅž‹")), - credential_form_schemas=[ - CredentialFormSchema( - variable="key", label=I18nObject(en_US="Key", zh_Hans="Key"), type=FormType.TEXT_INPUT, required=True - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - with pytest.raises(ValueError, match="Variable key should be string"): - validator.validate_and_filter({"key": 123}) - - -def test_common_validator_logic_switch(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="ๆจกๅž‹")), - credential_form_schemas=[ - CredentialFormSchema( - variable="enabled", - label=I18nObject(en_US="Enabled", zh_Hans="ๅฏ็”จ"), - type=FormType.SWITCH, - required=True, - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - result = validator.validate_and_filter({"enabled": "true"}) - assert result["enabled"] is True - - result = validator.validate_and_filter({"enabled": "false"}) - assert "enabled" not in result - - with pytest.raises(ValueError, match="Variable enabled should be true or false"): - validator.validate_and_filter({"enabled": "not_a_bool"}) - - -def test_common_validator_logic_options(): - schema = 
ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="ๆจกๅž‹")), - credential_form_schemas=[ - CredentialFormSchema( - variable="choice", - label=I18nObject(en_US="Choice", zh_Hans="้€‰ๆ‹ฉ"), - type=FormType.SELECT, - required=True, - options=[ - FormOption(label=I18nObject(en_US="A", zh_Hans="A"), value="a"), - FormOption(label=I18nObject(en_US="B", zh_Hans="B"), value="b"), - ], - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - result = validator.validate_and_filter({"choice": "a"}) - assert result["choice"] == "a" - - with pytest.raises(ValueError, match="Variable choice is not in options"): - validator.validate_and_filter({"choice": "c"}) - - -def test_validate_and_filter_optional_no_default(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="ๆจกๅž‹")), - credential_form_schemas=[ - CredentialFormSchema( - variable="optional", - label=I18nObject(en_US="Optional", zh_Hans="ๅฏ้€‰"), - type=FormType.TEXT_INPUT, - required=False, - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - result = validator.validate_and_filter({}) - assert "optional" not in result diff --git a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_provider_credential_schema_validator.py b/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_provider_credential_schema_validator.py deleted file mode 100644 index f7a2a5b6235..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_provider_credential_schema_validator.py +++ /dev/null @@ -1,72 +0,0 @@ -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderCredentialSchema -from graphon.model_runtime.schema_validators.provider_credential_schema_validator import ( - ProviderCredentialSchemaValidator, -) - - -class TestProviderCredentialSchemaValidator: - def test_validate_and_filter_success(self): - # Setup schema - schema = ProviderCredentialSchema( - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True - ), - CredentialFormSchema( - variable="endpoint", - label=I18nObject(en_US="Endpoint"), - type=FormType.TEXT_INPUT, - required=False, - default="https://api.example.com", - ), - ] - ) - validator = ProviderCredentialSchemaValidator(schema) - - # Test valid credentials - credentials = {"api_key": "my-secret-key"} - result = validator.validate_and_filter(credentials) - - assert result == {"api_key": "my-secret-key", "endpoint": "https://api.example.com"} - - def test_validate_and_filter_missing_required(self): - # Setup schema - schema = ProviderCredentialSchema( - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True - ) - ] - ) - validator = ProviderCredentialSchemaValidator(schema) - - # Test missing required credentials - with pytest.raises(ValueError, match="Variable api_key is required"): - validator.validate_and_filter({}) - - def test_validate_and_filter_extra_fields_filtered(self): - # Setup schema - schema = ProviderCredentialSchema( - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True - ) - ] - ) - validator = 
ProviderCredentialSchemaValidator(schema) - - # Test credentials with extra fields - credentials = {"api_key": "my-secret-key", "extra_field": "should-be-filtered"} - result = validator.validate_and_filter(credentials) - - assert "api_key" in result - assert "extra_field" not in result - assert result == {"api_key": "my-secret-key"} - - def test_init(self): - schema = ProviderCredentialSchema(credential_form_schemas=[]) - validator = ProviderCredentialSchemaValidator(schema) - assert validator.provider_credential_schema == schema diff --git a/api/tests/unit_tests/graphon/model_runtime/utils/test_encoders.py b/api/tests/unit_tests/graphon/model_runtime/utils/test_encoders.py deleted file mode 100644 index 8edc143faeb..00000000000 --- a/api/tests/unit_tests/graphon/model_runtime/utils/test_encoders.py +++ /dev/null @@ -1,231 +0,0 @@ -import dataclasses -import datetime -from collections import deque -from decimal import Decimal -from enum import Enum -from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network -from pathlib import Path, PurePath -from re import compile -from typing import Any -from unittest.mock import MagicMock -from uuid import UUID - -import pytest -from pydantic import BaseModel, ConfigDict -from pydantic.networks import AnyUrl, NameEmail -from pydantic.types import SecretBytes, SecretStr -from pydantic_core import Url -from pydantic_extra_types.color import Color - -from graphon.model_runtime.utils.encoders import ( - _model_dump, - decimal_encoder, - generate_encoders_by_class_tuples, - isoformat, - jsonable_encoder, -) - - -class MockEnum(Enum): - A = "a" - B = "b" - - -class MockPydanticModel(BaseModel): - model_config = ConfigDict(populate_by_name=True) - name: str - age: int - - -@dataclasses.dataclass -class MockDataclass: - name: str - value: Any - - -class MockWithDict: - def __init__(self, data): - self.data = data - - def __iter__(self): - return iter(self.data.items()) - - -class MockWithVars: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - -class TestEncoders: - def test_model_dump(self): - model = MockPydanticModel(name="test", age=20) - result = _model_dump(model) - assert result == {"name": "test", "age": 20} - - def test_isoformat(self): - d = datetime.date(2023, 1, 1) - assert isoformat(d) == "2023-01-01" - t = datetime.time(12, 0, 0) - assert isoformat(t) == "12:00:00" - - def test_decimal_encoder(self): - assert decimal_encoder(Decimal("1.0")) == 1.0 - assert decimal_encoder(Decimal(1)) == 1 - assert decimal_encoder(Decimal("1.5")) == 1.5 - assert decimal_encoder(Decimal(0)) == 0 - assert decimal_encoder(Decimal(-1)) == -1 - - def test_generate_encoders_by_class_tuples(self): - type_map = {int: str, float: str, str: int} - result = generate_encoders_by_class_tuples(type_map) - assert result[str] == (int, float) - assert result[int] == (str,) - - def test_jsonable_encoder_basic_types(self): - assert jsonable_encoder("string") == "string" - assert jsonable_encoder(123) == 123 - assert jsonable_encoder(1.23) == 1.23 - assert jsonable_encoder(None) is None - - def test_jsonable_encoder_pydantic(self): - model = MockPydanticModel(name="test", age=20) - assert jsonable_encoder(model) == {"name": "test", "age": 20} - - def test_jsonable_encoder_pydantic_root(self): - # Manually create a mock that behaves like a model with __root__ - # because Pydantic v2 handles root differently, but the code checks for "__root__" - model = MagicMock(spec=BaseModel) - # _model_dump(obj, 
mode="json", ...) -> model.model_dump(mode="json", ...) - model.model_dump.return_value = {"__root__": [1, 2, 3]} - assert jsonable_encoder(model) == [1, 2, 3] - - def test_jsonable_encoder_dataclass(self): - obj = MockDataclass(name="test", value=1) - assert jsonable_encoder(obj) == {"name": "test", "value": 1} - # Test dataclass type (should not be treated as instance) - # It should fall back to vars() or dict() or at least not crash - with pytest.raises(ValueError): - jsonable_encoder(MockDataclass) - - def test_jsonable_encoder_enum(self): - assert jsonable_encoder(MockEnum.A) == "a" - - def test_jsonable_encoder_path(self): - assert jsonable_encoder(Path("/tmp/test")) == "/tmp/test" - assert jsonable_encoder(PurePath("/tmp/test")) == "/tmp/test" - - def test_jsonable_encoder_decimal(self): - # In jsonable_encoder, Decimal is formatted as string via format(obj, "f") - assert jsonable_encoder(Decimal("1.23")) == "1.23" - assert jsonable_encoder(Decimal("1.000")) == "1.000" - - def test_jsonable_encoder_dict(self): - d = {"a": 1, "b": [2, 3], "_private": "hidden"} - assert jsonable_encoder(d) == {"a": 1, "b": [2, 3], "_private": "hidden"} - assert jsonable_encoder(d, excluded_key_prefixes=("_",)) == {"a": 1, "b": [2, 3]} - - d_with_none = {"a": 1, "b": None} - assert jsonable_encoder(d_with_none, exclude_none=True) == {"a": 1} - assert jsonable_encoder(d_with_none, exclude_none=False) == {"a": 1, "b": None} - - def test_jsonable_encoder_collections(self): - assert jsonable_encoder([1, 2]) == [1, 2] - assert jsonable_encoder((1, 2)) == [1, 2] - assert jsonable_encoder({1, 2}) == [1, 2] - assert jsonable_encoder(frozenset([1, 2])) == [1, 2] - assert jsonable_encoder(deque([1, 2])) == [1, 2] - - def gen(): - yield 1 - yield 2 - - assert jsonable_encoder(gen()) == [1, 2] - - def test_jsonable_encoder_custom_encoder(self): - custom = {int: lambda x: str(x + 1)} - assert jsonable_encoder(1, custom_encoder=custom) == "2" - - # Test subclass matching for custom encoder - class SubInt(int): - pass - - assert jsonable_encoder(SubInt(1), custom_encoder=custom) == "2" - - def test_jsonable_encoder_special_types(self): - # These hit ENCODERS_BY_TYPE or encoders_by_class_tuples - assert jsonable_encoder(b"bytes") == "bytes" - assert jsonable_encoder(Color("red")) == "red" - - dt = datetime.datetime(2023, 1, 1, 12, 0, 0) - assert jsonable_encoder(dt) == dt.isoformat() - - date = datetime.date(2023, 1, 1) - assert jsonable_encoder(date) == date.isoformat() - - time = datetime.time(12, 0, 0) - assert jsonable_encoder(time) == time.isoformat() - - td = datetime.timedelta(seconds=60) - assert jsonable_encoder(td) == 60.0 - - assert jsonable_encoder(IPv4Address("127.0.0.1")) == "127.0.0.1" - assert jsonable_encoder(IPv4Interface("127.0.0.1/24")) == "127.0.0.1/24" - assert jsonable_encoder(IPv4Network("127.0.0.0/24")) == "127.0.0.0/24" - assert jsonable_encoder(IPv6Address("::1")) == "::1" - assert jsonable_encoder(IPv6Interface("::1/128")) == "::1/128" - assert jsonable_encoder(IPv6Network("::/128")) == "::/128" - - assert jsonable_encoder(NameEmail(name="test", email="test@example.com")) == "test " - - assert jsonable_encoder(compile("abc")) == "abc" - - # Secret types - # Check what they actually return in this environment - res_bytes = jsonable_encoder(SecretBytes(b"secret")) - assert "**********" in res_bytes - - res_str = jsonable_encoder(SecretStr("secret")) - assert res_str == "**********" - - u = UUID("12345678-1234-5678-1234-567812345678") - assert jsonable_encoder(u) == str(u) - - url = 
AnyUrl("https://example.com") - assert jsonable_encoder(url) == "https://example.com/" - - purl = Url("https://example.com") - assert jsonable_encoder(purl) == "https://example.com/" - - def test_jsonable_encoder_fallback(self): - # dict(obj) success - obj_dict = MockWithDict({"a": 1}) - assert jsonable_encoder(obj_dict) == {"a": 1} - - # vars(obj) success - obj_vars = MockWithVars(x=10, y=20) - assert jsonable_encoder(obj_vars) == {"x": 10, "y": 20} - - # error fallback - class ReallyUnserializable: - __slots__ = ["__weakref__"] # No __dict__ - - def __iter__(self): - raise TypeError("not iterable") - - with pytest.raises(ValueError) as exc: - jsonable_encoder(ReallyUnserializable()) - assert "not iterable" in str(exc.value) - - def test_jsonable_encoder_nested(self): - data = { - "model": MockPydanticModel(name="test", age=20), - "list": [Decimal("1.1"), {MockEnum.A: Path("/tmp")}], - "set": {1, 2}, - } - expected = { - "model": {"name": "test", "age": 20}, - "list": ["1.1", {"a": "/tmp"}], - "set": [1, 2], - } - assert jsonable_encoder(data) == expected diff --git a/api/tests/unit_tests/graphon/node_events/test_base.py b/api/tests/unit_tests/graphon/node_events/test_base.py deleted file mode 100644 index 4ff12702650..00000000000 --- a/api/tests/unit_tests/graphon/node_events/test_base.py +++ /dev/null @@ -1,19 +0,0 @@ -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.node_events.base import NodeRunResult - - -def test_node_run_result_accepts_trigger_info_metadata() -> None: - result = NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - metadata={ - WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: { - "provider_id": "provider-id", - "event_name": "event-name", - } - }, - ) - - assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == { - "provider_id": "provider-id", - "event_name": "event-name", - } diff --git a/api/tests/unit_tests/graphon/utils/test_json_in_md_parser.py b/api/tests/unit_tests/graphon/utils/test_json_in_md_parser.py deleted file mode 100644 index a8c86d288cf..00000000000 --- a/api/tests/unit_tests/graphon/utils/test_json_in_md_parser.py +++ /dev/null @@ -1,75 +0,0 @@ -import pytest - -from graphon.utils.json_in_md_parser import ( - OutputParserError, - parse_and_check_json_markdown, - parse_json_markdown, -) - - -def test_parse_json_markdown_extracts_fenced_json_object() -> None: - src = """ - ```json - {"a": 1, "b": "x"} - ``` - """ - - assert parse_json_markdown(src) == {"a": 1, "b": "x"} - - -def test_parse_json_markdown_extracts_raw_json_array() -> None: - assert parse_json_markdown('[{"a": 1}]') == {"a": 1} - - -def test_parse_json_markdown_raises_when_no_json_block_exists() -> None: - with pytest.raises(ValueError, match="could not find json block"): - parse_json_markdown("plain text only") - - -def test_parse_and_check_json_markdown_unwraps_single_dict_list() -> None: - parsed = parse_and_check_json_markdown( - """ - ```json - [{"present": 1, "other": 2}] - ``` - """, - ["present"], - ) - - assert parsed == {"present": 1, "other": 2} - - -def test_parse_and_check_json_markdown_rejects_invalid_json() -> None: - with pytest.raises(OutputParserError, match="got invalid json object"): - parse_and_check_json_markdown( - """ - ```json - {invalid json} - ``` - """, - [], - ) - - -def test_parse_and_check_json_markdown_rejects_invalid_return_shapes() -> None: - with pytest.raises(OutputParserError, match="got invalid return object"): - parse_and_check_json_markdown( - """ - ```json - [1, 2] - 
``` - """, - ["present"], - ) - - -def test_parse_and_check_json_markdown_requires_expected_keys() -> None: - with pytest.raises(OutputParserError, match="expected key `missing`"): - parse_and_check_json_markdown( - """ - ```json - {"present": 1} - ``` - """, - ["present", "missing"], - ) diff --git a/api/tests/unit_tests/libs/_human_input/support.py b/api/tests/unit_tests/libs/_human_input/support.py index e6cc23161e4..13577b7ca56 100644 --- a/api/tests/unit_tests/libs/_human_input/support.py +++ b/api/tests/unit_tests/libs/_human_input/support.py @@ -6,6 +6,7 @@ from typing import Any from graphon.nodes.human_input.entities import FormInput from graphon.nodes.human_input.enums import TimeoutUnit + from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/libs/_human_input/test_form_service.py b/api/tests/unit_tests/libs/_human_input/test_form_service.py index fa2c02020b4..f1ce1a2c1c9 100644 --- a/api/tests/unit_tests/libs/_human_input/test_form_service.py +++ b/api/tests/unit_tests/libs/_human_input/test_form_service.py @@ -5,7 +5,6 @@ Unit tests for FormService. from datetime import timedelta import pytest - from graphon.nodes.human_input.entities import ( FormInput, UserAction, @@ -14,6 +13,7 @@ from graphon.nodes.human_input.enums import ( FormInputType, TimeoutUnit, ) + from libs.datetime_utils import naive_utc_now from .support import ( diff --git a/api/tests/unit_tests/libs/_human_input/test_models.py b/api/tests/unit_tests/libs/_human_input/test_models.py index 866ee61b3eb..0babfbb3157 100644 --- a/api/tests/unit_tests/libs/_human_input/test_models.py +++ b/api/tests/unit_tests/libs/_human_input/test_models.py @@ -5,7 +5,6 @@ Unit tests for human input form models. from datetime import datetime, timedelta import pytest - from graphon.nodes.human_input.entities import ( FormInput, UserAction, @@ -14,6 +13,7 @@ from graphon.nodes.human_input.enums import ( FormInputType, TimeoutUnit, ) + from libs.datetime_utils import naive_utc_now from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py index bb3a6db1a1c..86163f15540 100644 --- a/api/tests/unit_tests/models/test_conversation_variable.py +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -1,7 +1,8 @@ from uuid import uuid4 -from factories import variable_factory from graphon.variables import SegmentType + +from factories import variable_factory from models import ConversationVariable diff --git a/api/tests/unit_tests/models/test_model.py b/api/tests/unit_tests/models/test_model.py index e21f0e4fbd9..a5909f60a80 100644 --- a/api/tests/unit_tests/models/test_model.py +++ b/api/tests/unit_tests/models/test_model.py @@ -2,9 +2,9 @@ import importlib import types import pytest +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod from core.workflow.file_reference import build_file_reference -from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod from models.model import Conversation, Message diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 550441539a1..e7c0479757b 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -3,14 +3,14 @@ import json from unittest import mock from uuid import uuid4 +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables import FloatVariable, 
IntegerVariable, SecretVariable, StringVariable +from graphon.variables.segments import IntegerSegment, Segment + from constants import HIDDEN_VALUE from core.helper import encrypter from core.workflow.file_reference import build_file_reference from factories.variable_factory import build_segment -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File -from graphon.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable -from graphon.variables.segments import IntegerSegment, Segment from models.workflow import ( Workflow, WorkflowDraftVariable, diff --git a/api/tests/unit_tests/models/test_workflow_models.py b/api/tests/unit_tests/models/test_workflow_models.py index eb9fef75878..507e1c8c3ac 100644 --- a/api/tests/unit_tests/models/test_workflow_models.py +++ b/api/tests/unit_tests/models/test_workflow_models.py @@ -13,12 +13,12 @@ from datetime import UTC, datetime from uuid import uuid4 import pytest - from graphon.enums import ( BuiltinNodeTypes, WorkflowExecutionStatus, WorkflowNodeExecutionStatus, ) + from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import ( Workflow, diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index ccc9c938159..10850970d82 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -9,11 +9,6 @@ from decimal import Decimal from unittest.mock import MagicMock, PropertyMock import pytest -from pytest_mock import MockerFixture -from sqlalchemy.orm import Session, sessionmaker - -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.repositories.factory import OrderConfig from graphon.entities import ( WorkflowNodeExecution, ) @@ -23,6 +18,11 @@ from graphon.enums import ( WorkflowNodeExecutionStatus, ) from graphon.model_runtime.utils.encoders import jsonable_encoder +from pytest_mock import MockerFixture +from sqlalchemy.orm import Session, sessionmaker + +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from models.account import Account, Tenant from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py index e8c094b75d8..2322be9e80d 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py @@ -6,13 +6,13 @@ from datetime import datetime from typing import Any from unittest.mock import MagicMock, Mock +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionModel, 
WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/services/dataset_service_test_helpers.py b/api/tests/unit_tests/services/dataset_service_test_helpers.py index c95b60fad03..ef73bc0e01b 100644 --- a/api/tests/unit_tests/services/dataset_service_test_helpers.py +++ b/api/tests/unit_tests/services/dataset_service_test_helpers.py @@ -10,6 +10,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, create_autospec, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from werkzeug.exceptions import Forbidden, NotFound from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError @@ -17,7 +18,6 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod from enums.cloud_plan import CloudPlan -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from models import Account, TenantAccountRole from models.dataset import ( ChildChunk, diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 3358c8b44d2..7c36e9d9602 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -109,10 +109,10 @@ This test suite follows a comprehensive testing strategy that covers: from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelType from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( diff --git a/api/tests/unit_tests/services/test_app_dsl_service.py b/api/tests/unit_tests/services/test_app_dsl_service.py index afea8ec92a7..179518a5fad 100644 --- a/api/tests/unit_tests/services/test_app_dsl_service.py +++ b/api/tests/unit_tests/services/test_app_dsl_service.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock import pytest import yaml +from graphon.enums import BuiltinNodeTypes from core.trigger.constants import ( TRIGGER_PLUGIN_NODE_TYPE, TRIGGER_SCHEDULE_NODE_TYPE, TRIGGER_WEBHOOK_NODE_TYPE, ) -from graphon.enums import BuiltinNodeTypes from models import Account, AppMode from models.model import IconType from services import app_dsl_service diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index da932396000..3df7d500cf2 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -1,10 +1,10 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy.orm import Session from core.plugin.entities.plugin_daemon import CredentialType -from graphon.model_runtime.entities.provider_entities import FormType from models.account import Account from models.model import EndUser from models.oauth import DatasourceProvider diff --git 
a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index 55af5648219..9be475d043f 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -3,18 +3,18 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest - -import services.human_input_service as human_input_service_module -from core.repositories.human_input_repository import ( - HumanInputFormRecord, - HumanInputFormSubmissionRepository, -) from graphon.nodes.human_input.entities import ( FormDefinition, FormInput, UserAction, ) from graphon.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus + +import services.human_input_service as human_input_service_module +from core.repositories.human_input_repository import ( + HumanInputFormRecord, + HumanInputFormSubmissionRepository, +) from libs.datetime_utils import naive_utc_now from models.human_input import RecipientType from services.human_input_service import ( diff --git a/api/tests/unit_tests/services/test_model_load_balancing_service.py b/api/tests/unit_tests/services/test_model_load_balancing_service.py index 1e898ada11c..b43e79dff50 100644 --- a/api/tests/unit_tests/services/test_model_load_balancing_service.py +++ b/api/tests/unit_tests/services/test_model_load_balancing_service.py @@ -6,9 +6,6 @@ from typing import Any, cast from unittest.mock import MagicMock import pytest -from pytest_mock import MockerFixture - -from constants import HIDDEN_VALUE from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -18,6 +15,9 @@ from graphon.model_runtime.entities.provider_entities import ( ModelCredentialSchema, ProviderCredentialSchema, ) +from pytest_mock import MockerFixture + +from constants import HIDDEN_VALUE from models.provider import LoadBalancingModelConfig from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index 97f3bd6f013..1bd979b9ec2 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -1,11 +1,11 @@ import types import pytest - -from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ConfigurateMethod + +from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration from models.provider import ProviderType from services.model_provider_service import ModelProviderService diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index 2fe61617855..9c231352256 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -16,9 +16,7 @@ from typing import Any from uuid import uuid4 import pytest - -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File +from 
graphon.file import File, FileTransferMethod, FileType from graphon.variables.segments import ( ArrayFileSegment, ArrayNumberSegment, @@ -30,6 +28,7 @@ from graphon.variables.segments import ( ObjectSegment, StringSegment, ) + from services.variable_truncator import ( DummyVariableTruncator, MaxDepthExceededError, diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index 239cc83518e..a62c9f45556 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -13,10 +13,10 @@ from datetime import datetime from unittest.mock import MagicMock, create_autospec, patch import pytest +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker -from graphon.enums import WorkflowExecutionStatus from models.workflow import WorkflowPause from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index da606c8329e..cd71981bcf1 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -15,7 +15,6 @@ from typing import Any, cast from unittest.mock import ANY, MagicMock, patch import pytest - from graphon.entities import WorkflowNodeExecution from graphon.enums import ( BuiltinNodeTypes, @@ -29,6 +28,7 @@ from graphon.model_runtime.entities.model_entities import ModelType from graphon.node_events import NodeRunResult from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from graphon.variables.input_entities import VariableEntityType + from libs.datetime_utils import naive_utc_now from models.human_input import RecipientType from models.model import App, AppMode diff --git a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py index 2db83576b0c..8525672da8e 100644 --- a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py +++ b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py @@ -4,13 +4,12 @@ import json from unittest.mock import Mock, patch import pytest +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import ObjectSegment, StringSegment +from graphon.variables.types import SegmentType from sqlalchemy import Engine from core.workflow.file_reference import build_file_reference -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File -from graphon.variables.segments import ObjectSegment, StringSegment -from graphon.variables.types import SegmentType from models.model import UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile from services.workflow_draft_variable_service import DraftVarLoader diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 6200c9f8596..e7e72793a32 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ 
b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -4,6 +4,10 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from sqlalchemy import Engine from sqlalchemy.orm import Session @@ -13,11 +17,6 @@ from core.workflow.variable_prefixes import ( ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) -from graphon.enums import BuiltinNodeTypes -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType from libs.uuid_utils import uuidv7 from models.account import Account from models.enums import DraftVariableType diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py index ce66b78b64d..077a7c27a2b 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -8,13 +8,13 @@ from datetime import UTC, datetime from threading import Event import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index d7192994b2d..98d057e41fe 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -3,6 +3,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from sqlalchemy.orm import sessionmaker from core.workflow.human_input_compat import ( @@ -12,9 +15,6 @@ from core.workflow.human_input_compat import ( ExternalRecipient, MemberRecipient, ) -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.human_input.entities import HumanInputNodeData from services import workflow_service as workflow_service_module from services.workflow_service import WorkflowService diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 6b04a1bc09c..b9d097350b9 100644 --- 
a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -3,11 +3,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest - from graphon.entities.graph_config import NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes from graphon.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction from graphon.nodes.human_input.enums import FormInputType + from models.model import App from models.workflow import Workflow from services import workflow_service as workflow_service_module diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py index 591da56f494..7119217e94e 100644 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -5,8 +5,8 @@ from types import SimpleNamespace from typing import Any import pytest - from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus + from tasks import human_input_timeout_tasks as task_module diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py index f31bf800468..68359ba078d 100644 --- a/api/tests/unit_tests/tools/test_mcp_tool.py +++ b/api/tests/unit_tests/tools/test_mcp_tool.py @@ -3,6 +3,7 @@ from decimal import Decimal from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage from core.mcp.types import ( AudioContent, @@ -17,7 +18,6 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.mcp_tool.tool import MCPTool -from graphon.model_runtime.entities.llm_entities import LLMUsage def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool: diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index c166a946d92..ffa6833524d 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -2,9 +2,6 @@ from decimal import Decimal from unittest.mock import MagicMock, patch import pytest - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, @@ -21,6 +18,9 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output + def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage: """Create a mock LLMUsage with all required fields""" diff --git a/api/tests/workflow_test_utils.py b/api/tests/workflow_test_utils.py index a29df0bb6b4..d33ac2c7108 100644 --- a/api/tests/workflow_test_utils.py +++ b/api/tests/workflow_test_utils.py @@ -1,12 +1,13 @@ from collections.abc import Mapping from typing import Any -from 
core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context -from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool -from graphon.entities.graph_init_params import GraphInitParams +from graphon.entities import GraphInitParams from graphon.runtime import VariablePool from graphon.variables.variables import Variable +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool + def build_test_run_context( *, diff --git a/api/uv.lock b/api/uv.lock index ed2b76ac3c0..e0bd0de84d0 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1489,12 +1489,12 @@ dependencies = [ { name = "google-auth-httplib2" }, { name = "google-cloud-aiplatform" }, { name = "googleapis-common-protos" }, + { name = "graphon" }, { name = "gunicorn" }, { name = "httpx", extra = ["socks"] }, { name = "httpx-sse" }, { name = "jieba" }, { name = "json-repair" }, - { name = "jsonschema" }, { name = "langfuse" }, { name = "langsmith" }, { name = "litellm" }, @@ -1526,7 +1526,6 @@ dependencies = [ { name = "psycopg2-binary" }, { name = "pycryptodome" }, { name = "pydantic" }, - { name = "pydantic-extra-types" }, { name = "pydantic-settings" }, { name = "pyjwt" }, { name = "pypandoc" }, @@ -1547,7 +1546,6 @@ dependencies = [ { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, { name = "weave" }, { name = "weaviate-client" }, - { name = "webvtt-py" }, { name = "yarl" }, ] @@ -1590,7 +1588,6 @@ dev = [ { name = "types-greenlet" }, { name = "types-html5lib" }, { name = "types-jmespath" }, - { name = "types-jsonschema" }, { name = "types-markdown" }, { name = "types-oauthlib" }, { name = "types-objgraph" }, @@ -1692,12 +1689,12 @@ requires-dist = [ { name = "google-auth-httplib2", specifier = "==0.3.0" }, { name = "google-cloud-aiplatform", specifier = ">=1.123.0" }, { name = "googleapis-common-protos", specifier = ">=1.65.0" }, + { name = "graphon", specifier = ">=0.1.2" }, { name = "gunicorn", specifier = "~=25.1.0" }, { name = "httpx", extras = ["socks"], specifier = "~=0.28.0" }, { name = "httpx-sse", specifier = "~=0.4.0" }, { name = "jieba", specifier = "==0.42.1" }, { name = "json-repair", specifier = ">=0.55.1" }, - { name = "jsonschema", specifier = ">=4.25.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.7.16" }, { name = "litellm", specifier = "==1.82.6" }, @@ -1729,7 +1726,6 @@ requires-dist = [ { name = "psycopg2-binary", specifier = "~=2.9.6" }, { name = "pycryptodome", specifier = "==3.23.0" }, { name = "pydantic", specifier = "~=2.12.5" }, - { name = "pydantic-extra-types", specifier = "~=2.11.0" }, { name = "pydantic-settings", specifier = "~=2.13.1" }, { name = "pyjwt", specifier = "~=2.12.0" }, { name = "pypandoc", specifier = "~=1.13" }, @@ -1750,7 +1746,6 @@ requires-dist = [ { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.21.5" }, { name = "weave", specifier = ">=0.52.16" }, { name = "weaviate-client", specifier = "==4.20.4" }, - { name = "webvtt-py", specifier = "~=0.5.1" }, { name = "yarl", specifier = "~=1.23.0" }, ] @@ -1793,7 +1788,6 @@ dev = [ { name = "types-greenlet", specifier = "~=3.3.0" }, { name = "types-html5lib", specifier = "~=1.1.11" }, { name = "types-jmespath", specifier = ">=1.0.2.20240106" }, - { name = "types-jsonschema", specifier = "~=4.26.0" }, { name = "types-markdown", specifier = 
"~=3.10.2" }, { name = "types-oauthlib", specifier = "~=3.3.0" }, { name = "types-objgraph", specifier = "~=3.6.0" }, @@ -2652,6 +2646,34 @@ requests = [ { name = "requests-toolbelt" }, ] +[[package]] +name = "graphon" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "charset-normalizer" }, + { name = "httpx" }, + { name = "json-repair" }, + { name = "jsonschema" }, + { name = "orjson" }, + { name = "pandas", extra = ["excel"] }, + { name = "pydantic" }, + { name = "pydantic-extra-types" }, + { name = "pypandoc" }, + { name = "pypdfium2" }, + { name = "python-docx" }, + { name = "pyyaml" }, + { name = "tiktoken" }, + { name = "transformers" }, + { name = "typing-extensions" }, + { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, + { name = "webvtt-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/fc/0a5342a1c29bc367c2254c170ef130a84a60d8cd1c9cc84a7a85e96c1042/graphon-0.1.2.tar.gz", hash = "sha256:a2210629f93258ad2e7cbe85b5d4c6826814f6c679aa2a23ca100511363b9240", size = 214744, upload-time = "2026-03-27T20:09:53.802Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/46/65b5e366ec2d7017b6d6448e2635b3772d86840a6f7297277471b1bfbfbd/graphon-0.1.2-py3-none-any.whl", hash = "sha256:79f0c7796de7b8642d070730bb8bdaf1c68ccdfcecac38e0b2282e0543f0a6db", size = 314398, upload-time = "2026-03-27T20:09:52.524Z" }, +] + [[package]] name = "graphql-core" version = "3.2.7" @@ -6850,18 +6872,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/91/915c4a6e6e9bd2bca3ec0c21c1771b175c59e204b85e57f3f572370fe753/types_jmespath-1.1.0.20260124-py3-none-any.whl", hash = "sha256:ec387666d446b15624215aa9cbd2867ffd885b6c74246d357c65e830c7a138b3", size = 11509, upload-time = "2026-01-24T03:18:45.536Z" }, ] -[[package]] -name = "types-jsonschema" -version = "4.26.0.20260202" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "referencing" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a1/07/68f63e715eb327ed2f5292e29e8be99785db0f72c7664d2c63bd4dbdc29d/types_jsonschema-4.26.0.20260202.tar.gz", hash = "sha256:29831baa4308865a9aec547a61797a06fc152b0dac8dddd531e002f32265cb07", size = 16168, upload-time = "2026-02-02T04:11:22.585Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/06/962d4f364f779d7389cd31a1bb581907b057f52f0ace2c119a8dd8409db6/types_jsonschema-4.26.0.20260202-py3-none-any.whl", hash = "sha256:41c95343abc4de9264e333a55e95dfb4d401e463856d0164eec9cb182e8746da", size = 15914, upload-time = "2026-02-02T04:11:21.61Z" }, -] - [[package]] name = "types-markdown" version = "3.10.2.20260211" From ec0f20de03eb74fbeabdd291d4f20259aab76208 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Fri, 27 Mar 2026 23:29:38 +0100 Subject: [PATCH 04/14] refactor: use EnumText for prompt_type and customize_token_strategy (#34204) --- .../create_site_record_when_app_created.py | 3 ++- api/models/enums.py | 16 ++++++++++++++++ api/models/model.py | 17 ++++++++++++----- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 5e7caf8cbed..84be592b1a9 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -1,5 +1,6 @@ from events.app_event import app_was_created from 
extensions.ext_database import db +from models.enums import CustomizeTokenStrategy from models.model import Site @@ -16,7 +17,7 @@ def handle(sender, **kwargs): icon=app.icon, icon_background=app.icon_background, default_language=account.interface_language, - customize_token_strategy="not_allow", + customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW, code=Site.generate_code(16), created_by=app.created_by, updated_by=app.updated_by, diff --git a/api/models/enums.py b/api/models/enums.py index cdec7b2f122..bf2e927f002 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -158,6 +158,15 @@ class FeedbackFromSource(StrEnum): ADMIN = "admin" +class CustomizeTokenStrategy(StrEnum): + """Site token customization strategy""" + + MUST = "must" + ALLOW = "allow" + NOT_ALLOW = "not_allow" + UUID = "uuid" + + class FeedbackRating(StrEnum): """MessageFeedback rating""" @@ -314,6 +323,13 @@ class MessageChainType(StrEnum): SYSTEM = "system" +class PromptType(StrEnum): + """Prompt configuration type""" + + SIMPLE = "simple" + ADVANCED = "advanced" + + class ProviderQuotaType(StrEnum): PAID = "paid" """hosted paid quota""" diff --git a/api/models/model.py b/api/models/model.py index b03cb7711fd..066d2acdce0 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -40,12 +40,14 @@ from .enums import ( ConversationFromSource, ConversationStatus, CreatorUserRole, + CustomizeTokenStrategy, FeedbackFromSource, FeedbackRating, InvokeFrom, MessageChainType, MessageFileBelongsTo, MessageStatus, + PromptType, ProviderQuotaType, TagType, ) @@ -649,8 +651,11 @@ class AppModelConfig(TypeBase): agent_mode: Mapped[str | None] = mapped_column(LongText, default=None) sensitive_word_avoidance: Mapped[str | None] = mapped_column(LongText, default=None) retriever_resource: Mapped[str | None] = mapped_column(LongText, default=None) - prompt_type: Mapped[str] = mapped_column( - String(255), nullable=False, server_default=sa.text("'simple'"), default="simple" + prompt_type: Mapped[PromptType] = mapped_column( + EnumText(PromptType, length=255), + nullable=False, + server_default=sa.text("'simple'"), + default=PromptType.SIMPLE, ) chat_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None) completion_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None) @@ -802,7 +807,7 @@ class AppModelConfig(TypeBase): "dataset_query_variable": self.dataset_query_variable, "pre_prompt": self.pre_prompt, "agent_mode": self.agent_mode_dict, - "prompt_type": self.prompt_type, + "prompt_type": self.prompt_type.value if isinstance(self.prompt_type, PromptType) else self.prompt_type, "chat_prompt_config": self.chat_prompt_config_dict, "completion_prompt_config": self.completion_prompt_config_dict, "dataset_configs": self.dataset_configs_dict, @@ -846,7 +851,7 @@ class AppModelConfig(TypeBase): self.retriever_resource = ( json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None ) - self.prompt_type = model_config.get("prompt_type", "simple") + self.prompt_type = PromptType(model_config.get("prompt_type", "simple")) self.chat_prompt_config = ( json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None ) @@ -2084,7 +2089,9 @@ class Site(Base): use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", LongText, default="") customize_domain = mapped_column(String(255)) - 
customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False) + customize_token_strategy: Mapped[CustomizeTokenStrategy] = mapped_column( + EnumText(CustomizeTokenStrategy, length=255), nullable=False + ) prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) status: Mapped[AppStatus] = mapped_column( EnumText(AppStatus, length=255), nullable=False, server_default=sa.text("'normal'"), default=AppStatus.NORMAL From 08e81459758f6d3614abda123b70632e1e31cf49 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 28 Mar 2026 07:53:01 +0900 Subject: [PATCH 05/14] chore(deps): bump cryptography from 44.0.3 to 46.0.6 in /api (#34210) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- api/uv.lock | 74 ++++++++++++++++++++++++++++------------------------- 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/api/uv.lock b/api/uv.lock index e0bd0de84d0..747f18e791a 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -204,7 +204,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/9a/7d/b22cb9a0d4f396ee0 [[package]] name = "alibabacloud-tea-openapi" -version = "0.4.3" +version = "0.4.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alibabacloud-credentials" }, @@ -213,9 +213,9 @@ dependencies = [ { name = "cryptography" }, { name = "darabonba-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/91/4f/b5288eea8f4d4b032c9a8f2cd1d926d5017977d10b874956f31e5343f299/alibabacloud_tea_openapi-0.4.3.tar.gz", hash = "sha256:12aef036ed993637b6f141abbd1de9d6199d5516f4a901588bb65d6a3768d41b", size = 21864, upload-time = "2026-01-15T07:55:16.744Z" } +sdist = { url = "https://files.pythonhosted.org/packages/30/93/138bcdc8fc596add73e37cf2073798f285284d1240bda9ee02f9384fc6be/alibabacloud_tea_openapi-0.4.4.tar.gz", hash = "sha256:1b0917bc03cd49417da64945e92731716d53e2eb8707b235f54e45b7473221ce", size = 21960, upload-time = "2026-03-26T10:16:16.792Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a5/37/48ee5468ecad19c6d44cf3b9629d77078e836ee3ec760f0366247f307b7c/alibabacloud_tea_openapi-0.4.3-py3-none-any.whl", hash = "sha256:d0b3a373b760ef6278b25fc128c73284301e07888977bf97519e7636d47bdf0a", size = 26159, upload-time = "2026-01-15T07:55:15.72Z" }, + { url = "https://files.pythonhosted.org/packages/f5/5a/6bfc4506438c1809c486f66217ad11eab78157192b3d5707b4e2f4212f6c/alibabacloud_tea_openapi-0.4.4-py3-none-any.whl", hash = "sha256:cea6bc1fe35b0319a8752cb99eb0ecb0dab7ca1a71b99c12970ba0867410995f", size = 26236, upload-time = "2026-03-26T10:16:15.861Z" }, ] [[package]] @@ -1308,43 +1308,47 @@ wheels = [ [[package]] name = "cryptography" -version = "44.0.3" +version = "46.0.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/53/d6/1411ab4d6108ab167d06254c5be517681f1e331f90edf1379895bcb87020/cryptography-44.0.3.tar.gz", hash = "sha256:fe19d8bc5536a91a24a8133328880a41831b6c5df54599a8417b62fe015d3053", size = 711096, upload-time = "2025-05-02T19:36:04.667Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a4/ba/04b1bd4218cbc58dc90ce967106d51582371b898690f3ae0402876cc4f34/cryptography-46.0.6.tar.gz", hash = "sha256:27550628a518c5c6c903d84f637fbecf287f6cb9ced3804838a1295dc1fd0759", size = 750542, 
upload-time = "2026-03-25T23:34:53.396Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/08/53/c776d80e9d26441bb3868457909b4e74dd9ccabd182e10b2b0ae7a07e265/cryptography-44.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:962bc30480a08d133e631e8dfd4783ab71cc9e33d5d7c1e192f0b7c06397bb88", size = 6670281, upload-time = "2025-05-02T19:34:50.665Z" }, - { url = "https://files.pythonhosted.org/packages/6a/06/af2cf8d56ef87c77319e9086601bef621bedf40f6f59069e1b6d1ec498c5/cryptography-44.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ffc61e8f3bf5b60346d89cd3d37231019c17a081208dfbbd6e1605ba03fa137", size = 3959305, upload-time = "2025-05-02T19:34:53.042Z" }, - { url = "https://files.pythonhosted.org/packages/ae/01/80de3bec64627207d030f47bf3536889efee8913cd363e78ca9a09b13c8e/cryptography-44.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58968d331425a6f9eedcee087f77fd3c927c88f55368f43ff7e0a19891f2642c", size = 4171040, upload-time = "2025-05-02T19:34:54.675Z" }, - { url = "https://files.pythonhosted.org/packages/bd/48/bb16b7541d207a19d9ae8b541c70037a05e473ddc72ccb1386524d4f023c/cryptography-44.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:e28d62e59a4dbd1d22e747f57d4f00c459af22181f0b2f787ea83f5a876d7c76", size = 3963411, upload-time = "2025-05-02T19:34:56.61Z" }, - { url = "https://files.pythonhosted.org/packages/42/b2/7d31f2af5591d217d71d37d044ef5412945a8a8e98d5a2a8ae4fd9cd4489/cryptography-44.0.3-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:af653022a0c25ef2e3ffb2c673a50e5a0d02fecc41608f4954176f1933b12359", size = 3689263, upload-time = "2025-05-02T19:34:58.591Z" }, - { url = "https://files.pythonhosted.org/packages/25/50/c0dfb9d87ae88ccc01aad8eb93e23cfbcea6a6a106a9b63a7b14c1f93c75/cryptography-44.0.3-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:157f1f3b8d941c2bd8f3ffee0af9b049c9665c39d3da9db2dc338feca5e98a43", size = 4196198, upload-time = "2025-05-02T19:35:00.988Z" }, - { url = "https://files.pythonhosted.org/packages/66/c9/55c6b8794a74da652690c898cb43906310a3e4e4f6ee0b5f8b3b3e70c441/cryptography-44.0.3-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:c6cd67722619e4d55fdb42ead64ed8843d64638e9c07f4011163e46bc512cf01", size = 3966502, upload-time = "2025-05-02T19:35:03.091Z" }, - { url = "https://files.pythonhosted.org/packages/b6/f7/7cb5488c682ca59a02a32ec5f975074084db4c983f849d47b7b67cc8697a/cryptography-44.0.3-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:b424563394c369a804ecbee9b06dfb34997f19d00b3518e39f83a5642618397d", size = 4196173, upload-time = "2025-05-02T19:35:05.018Z" }, - { url = "https://files.pythonhosted.org/packages/d2/0b/2f789a8403ae089b0b121f8f54f4a3e5228df756e2146efdf4a09a3d5083/cryptography-44.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c91fc8e8fd78af553f98bc7f2a1d8db977334e4eea302a4bfd75b9461c2d8904", size = 4087713, upload-time = "2025-05-02T19:35:07.187Z" }, - { url = "https://files.pythonhosted.org/packages/1d/aa/330c13655f1af398fc154089295cf259252f0ba5df93b4bc9d9c7d7f843e/cryptography-44.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:25cd194c39fa5a0aa4169125ee27d1172097857b27109a45fadc59653ec06f44", size = 4299064, upload-time = "2025-05-02T19:35:08.879Z" }, - { url = "https://files.pythonhosted.org/packages/10/a8/8c540a421b44fd267a7d58a1fd5f072a552d72204a3f08194f98889de76d/cryptography-44.0.3-cp37-abi3-win32.whl", hash = "sha256:3be3f649d91cb182c3a6bd336de8b61a0a71965bd13d1a04a0e15b39c3d5809d", size = 
2773887, upload-time = "2025-05-02T19:35:10.41Z" }, - { url = "https://files.pythonhosted.org/packages/b9/0d/c4b1657c39ead18d76bbd122da86bd95bdc4095413460d09544000a17d56/cryptography-44.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:3883076d5c4cc56dbef0b898a74eb6992fdac29a7b9013870b34efe4ddb39a0d", size = 3209737, upload-time = "2025-05-02T19:35:12.12Z" }, - { url = "https://files.pythonhosted.org/packages/34/a3/ad08e0bcc34ad436013458d7528e83ac29910943cea42ad7dd4141a27bbb/cryptography-44.0.3-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:5639c2b16764c6f76eedf722dbad9a0914960d3489c0cc38694ddf9464f1bb2f", size = 6673501, upload-time = "2025-05-02T19:35:13.775Z" }, - { url = "https://files.pythonhosted.org/packages/b1/f0/7491d44bba8d28b464a5bc8cc709f25a51e3eac54c0a4444cf2473a57c37/cryptography-44.0.3-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3ffef566ac88f75967d7abd852ed5f182da252d23fac11b4766da3957766759", size = 3960307, upload-time = "2025-05-02T19:35:15.917Z" }, - { url = "https://files.pythonhosted.org/packages/f7/c8/e5c5d0e1364d3346a5747cdcd7ecbb23ca87e6dea4f942a44e88be349f06/cryptography-44.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:192ed30fac1728f7587c6f4613c29c584abdc565d7417c13904708db10206645", size = 4170876, upload-time = "2025-05-02T19:35:18.138Z" }, - { url = "https://files.pythonhosted.org/packages/73/96/025cb26fc351d8c7d3a1c44e20cf9a01e9f7cf740353c9c7a17072e4b264/cryptography-44.0.3-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:7d5fe7195c27c32a64955740b949070f21cba664604291c298518d2e255931d2", size = 3964127, upload-time = "2025-05-02T19:35:19.864Z" }, - { url = "https://files.pythonhosted.org/packages/01/44/eb6522db7d9f84e8833ba3bf63313f8e257729cf3a8917379473fcfd6601/cryptography-44.0.3-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3f07943aa4d7dad689e3bb1638ddc4944cc5e0921e3c227486daae0e31a05e54", size = 3689164, upload-time = "2025-05-02T19:35:21.449Z" }, - { url = "https://files.pythonhosted.org/packages/68/fb/d61a4defd0d6cee20b1b8a1ea8f5e25007e26aeb413ca53835f0cae2bcd1/cryptography-44.0.3-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:cb90f60e03d563ca2445099edf605c16ed1d5b15182d21831f58460c48bffb93", size = 4198081, upload-time = "2025-05-02T19:35:23.187Z" }, - { url = "https://files.pythonhosted.org/packages/1b/50/457f6911d36432a8811c3ab8bd5a6090e8d18ce655c22820994913dd06ea/cryptography-44.0.3-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:ab0b005721cc0039e885ac3503825661bd9810b15d4f374e473f8c89b7d5460c", size = 3967716, upload-time = "2025-05-02T19:35:25.426Z" }, - { url = "https://files.pythonhosted.org/packages/35/6e/dca39d553075980ccb631955c47b93d87d27f3596da8d48b1ae81463d915/cryptography-44.0.3-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3bb0847e6363c037df8f6ede57d88eaf3410ca2267fb12275370a76f85786a6f", size = 4197398, upload-time = "2025-05-02T19:35:27.678Z" }, - { url = "https://files.pythonhosted.org/packages/9b/9d/d1f2fe681eabc682067c66a74addd46c887ebacf39038ba01f8860338d3d/cryptography-44.0.3-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:b0cc66c74c797e1db750aaa842ad5b8b78e14805a9b5d1348dc603612d3e3ff5", size = 4087900, upload-time = "2025-05-02T19:35:29.312Z" }, - { url = "https://files.pythonhosted.org/packages/c4/f5/3599e48c5464580b73b236aafb20973b953cd2e7b44c7c2533de1d888446/cryptography-44.0.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6866df152b581f9429020320e5eb9794c8780e90f7ccb021940d7f50ee00ae0b", size = 
4301067, upload-time = "2025-05-02T19:35:31.547Z" }, - { url = "https://files.pythonhosted.org/packages/a7/6c/d2c48c8137eb39d0c193274db5c04a75dab20d2f7c3f81a7dcc3a8897701/cryptography-44.0.3-cp39-abi3-win32.whl", hash = "sha256:c138abae3a12a94c75c10499f1cbae81294a6f983b3af066390adee73f433028", size = 2775467, upload-time = "2025-05-02T19:35:33.805Z" }, - { url = "https://files.pythonhosted.org/packages/c9/ad/51f212198681ea7b0deaaf8846ee10af99fba4e894f67b353524eab2bbe5/cryptography-44.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:5d186f32e52e66994dce4f766884bcb9c68b8da62d61d9d215bfe5fb56d21334", size = 3210375, upload-time = "2025-05-02T19:35:35.369Z" }, - { url = "https://files.pythonhosted.org/packages/8d/4b/c11ad0b6c061902de5223892d680e89c06c7c4d606305eb8de56c5427ae6/cryptography-44.0.3-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:896530bc9107b226f265effa7ef3f21270f18a2026bc09fed1ebd7b66ddf6375", size = 3390230, upload-time = "2025-05-02T19:35:49.062Z" }, - { url = "https://files.pythonhosted.org/packages/58/11/0a6bf45d53b9b2290ea3cec30e78b78e6ca29dc101e2e296872a0ffe1335/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:9b4d4a5dbee05a2c390bf212e78b99434efec37b17a4bff42f50285c5c8c9647", size = 3895216, upload-time = "2025-05-02T19:35:51.351Z" }, - { url = "https://files.pythonhosted.org/packages/0a/27/b28cdeb7270e957f0077a2c2bfad1b38f72f1f6d699679f97b816ca33642/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:02f55fb4f8b79c1221b0961488eaae21015b69b210e18c386b69de182ebb1259", size = 4115044, upload-time = "2025-05-02T19:35:53.044Z" }, - { url = "https://files.pythonhosted.org/packages/35/b0/ec4082d3793f03cb248881fecefc26015813199b88f33e3e990a43f79835/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:dd3db61b8fe5be220eee484a17233287d0be6932d056cf5738225b9c05ef4fff", size = 3898034, upload-time = "2025-05-02T19:35:54.72Z" }, - { url = "https://files.pythonhosted.org/packages/0b/7f/adf62e0b8e8d04d50c9a91282a57628c00c54d4ae75e2b02a223bd1f2613/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:978631ec51a6bbc0b7e58f23b68a8ce9e5f09721940933e9c217068388789fe5", size = 4114449, upload-time = "2025-05-02T19:35:57.139Z" }, - { url = "https://files.pythonhosted.org/packages/87/62/d69eb4a8ee231f4bf733a92caf9da13f1c81a44e874b1d4080c25ecbb723/cryptography-44.0.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:5d20cc348cca3a8aa7312f42ab953a56e15323800ca3ab0706b8cd452a3a056c", size = 3134369, upload-time = "2025-05-02T19:35:58.907Z" }, + { url = "https://files.pythonhosted.org/packages/47/23/9285e15e3bc57325b0a72e592921983a701efc1ee8f91c06c5f0235d86d9/cryptography-46.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:64235194bad039a10bb6d2d930ab3323baaec67e2ce36215fd0952fad0930ca8", size = 7176401, upload-time = "2026-03-25T23:33:22.096Z" }, + { url = "https://files.pythonhosted.org/packages/60/f8/e61f8f13950ab6195b31913b42d39f0f9afc7d93f76710f299b5ec286ae6/cryptography-46.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:26031f1e5ca62fcb9d1fcb34b2b60b390d1aacaa15dc8b895a9ed00968b97b30", size = 4275275, upload-time = "2026-03-25T23:33:23.844Z" }, + { url = "https://files.pythonhosted.org/packages/19/69/732a736d12c2631e140be2348b4ad3d226302df63ef64d30dfdb8db7ad1c/cryptography-46.0.6-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9a693028b9cbe51b5a1136232ee8f2bc242e4e19d456ded3fa7c86e43c713b4a", 
size = 4425320, upload-time = "2026-03-25T23:33:25.703Z" }, + { url = "https://files.pythonhosted.org/packages/d4/12/123be7292674abf76b21ac1fc0e1af50661f0e5b8f0ec8285faac18eb99e/cryptography-46.0.6-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:67177e8a9f421aa2d3a170c3e56eca4e0128883cf52a071a7cbf53297f18b175", size = 4278082, upload-time = "2026-03-25T23:33:27.423Z" }, + { url = "https://files.pythonhosted.org/packages/5b/ba/d5e27f8d68c24951b0a484924a84c7cdaed7502bac9f18601cd357f8b1d2/cryptography-46.0.6-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:d9528b535a6c4f8ff37847144b8986a9a143585f0540fbcb1a98115b543aa463", size = 4926514, upload-time = "2026-03-25T23:33:29.206Z" }, + { url = "https://files.pythonhosted.org/packages/34/71/1ea5a7352ae516d5512d17babe7e1b87d9db5150b21f794b1377eac1edc0/cryptography-46.0.6-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:22259338084d6ae497a19bae5d4c66b7ca1387d3264d1c2c0e72d9e9b6a77b97", size = 4457766, upload-time = "2026-03-25T23:33:30.834Z" }, + { url = "https://files.pythonhosted.org/packages/01/59/562be1e653accee4fdad92c7a2e88fced26b3fdfce144047519bbebc299e/cryptography-46.0.6-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:760997a4b950ff00d418398ad73fbc91aa2894b5c1db7ccb45b4f68b42a63b3c", size = 3986535, upload-time = "2026-03-25T23:33:33.02Z" }, + { url = "https://files.pythonhosted.org/packages/d6/8b/b1ebfeb788bf4624d36e45ed2662b8bd43a05ff62157093c1539c1288a18/cryptography-46.0.6-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:3dfa6567f2e9e4c5dceb8ccb5a708158a2a871052fa75c8b78cb0977063f1507", size = 4277618, upload-time = "2026-03-25T23:33:34.567Z" }, + { url = "https://files.pythonhosted.org/packages/dd/52/a005f8eabdb28df57c20f84c44d397a755782d6ff6d455f05baa2785bd91/cryptography-46.0.6-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:cdcd3edcbc5d55757e5f5f3d330dd00007ae463a7e7aa5bf132d1f22a4b62b19", size = 4890802, upload-time = "2026-03-25T23:33:37.034Z" }, + { url = "https://files.pythonhosted.org/packages/ec/4d/8e7d7245c79c617d08724e2efa397737715ca0ec830ecb3c91e547302555/cryptography-46.0.6-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:d4e4aadb7fc1f88687f47ca20bb7227981b03afaae69287029da08096853b738", size = 4457425, upload-time = "2026-03-25T23:33:38.904Z" }, + { url = "https://files.pythonhosted.org/packages/1d/5c/f6c3596a1430cec6f949085f0e1a970638d76f81c3ea56d93d564d04c340/cryptography-46.0.6-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2b417edbe8877cda9022dde3a008e2deb50be9c407eef034aeeb3a8b11d9db3c", size = 4405530, upload-time = "2026-03-25T23:33:40.842Z" }, + { url = "https://files.pythonhosted.org/packages/7e/c9/9f9cea13ee2dbde070424e0c4f621c091a91ffcc504ffea5e74f0e1daeff/cryptography-46.0.6-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:380343e0653b1c9d7e1f55b52aaa2dbb2fdf2730088d48c43ca1c7c0abb7cc2f", size = 4667896, upload-time = "2026-03-25T23:33:42.781Z" }, + { url = "https://files.pythonhosted.org/packages/ad/b5/1895bc0821226f129bc74d00eccfc6a5969e2028f8617c09790bf89c185e/cryptography-46.0.6-cp311-abi3-win32.whl", hash = "sha256:bcb87663e1f7b075e48c3be3ecb5f0b46c8fc50b50a97cf264e7f60242dca3f2", size = 3026348, upload-time = "2026-03-25T23:33:45.021Z" }, + { url = "https://files.pythonhosted.org/packages/c3/f8/c9bcbf0d3e6ad288b9d9aa0b1dee04b063d19e8c4f871855a03ab3a297ab/cryptography-46.0.6-cp311-abi3-win_amd64.whl", hash = "sha256:6739d56300662c468fddb0e5e291f9b4d084bead381667b9e654c7dd81705124", size = 3483896, upload-time = "2026-03-25T23:33:46.649Z" }, + { url = 
"https://files.pythonhosted.org/packages/c4/cc/f330e982852403da79008552de9906804568ae9230da8432f7496ce02b71/cryptography-46.0.6-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:12cae594e9473bca1a7aceb90536060643128bb274fcea0fc459ab90f7d1ae7a", size = 7162776, upload-time = "2026-03-25T23:34:13.308Z" }, + { url = "https://files.pythonhosted.org/packages/49/b3/dc27efd8dcc4bff583b3f01d4a3943cd8b5821777a58b3a6a5f054d61b79/cryptography-46.0.6-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:639301950939d844a9e1c4464d7e07f902fe9a7f6b215bb0d4f28584729935d8", size = 4270529, upload-time = "2026-03-25T23:34:15.019Z" }, + { url = "https://files.pythonhosted.org/packages/e6/05/e8d0e6eb4f0d83365b3cb0e00eb3c484f7348db0266652ccd84632a3d58d/cryptography-46.0.6-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ed3775295fb91f70b4027aeba878d79b3e55c0b3e97eaa4de71f8f23a9f2eb77", size = 4414827, upload-time = "2026-03-25T23:34:16.604Z" }, + { url = "https://files.pythonhosted.org/packages/2f/97/daba0f5d2dc6d855e2dcb70733c812558a7977a55dd4a6722756628c44d1/cryptography-46.0.6-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:8927ccfbe967c7df312ade694f987e7e9e22b2425976ddbf28271d7e58845290", size = 4271265, upload-time = "2026-03-25T23:34:18.586Z" }, + { url = "https://files.pythonhosted.org/packages/89/06/fe1fce39a37ac452e58d04b43b0855261dac320a2ebf8f5260dd55b201a9/cryptography-46.0.6-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:b12c6b1e1651e42ab5de8b1e00dc3b6354fdfd778e7fa60541ddacc27cd21410", size = 4916800, upload-time = "2026-03-25T23:34:20.561Z" }, + { url = "https://files.pythonhosted.org/packages/ff/8a/b14f3101fe9c3592603339eb5d94046c3ce5f7fc76d6512a2d40efd9724e/cryptography-46.0.6-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:063b67749f338ca9c5a0b7fe438a52c25f9526b851e24e6c9310e7195aad3b4d", size = 4448771, upload-time = "2026-03-25T23:34:22.406Z" }, + { url = "https://files.pythonhosted.org/packages/01/b3/0796998056a66d1973fd52ee89dc1bb3b6581960a91ad4ac705f182d398f/cryptography-46.0.6-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:02fad249cb0e090b574e30b276a3da6a149e04ee2f049725b1f69e7b8351ec70", size = 3978333, upload-time = "2026-03-25T23:34:24.281Z" }, + { url = "https://files.pythonhosted.org/packages/c5/3d/db200af5a4ffd08918cd55c08399dc6c9c50b0bc72c00a3246e099d3a849/cryptography-46.0.6-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:7e6142674f2a9291463e5e150090b95a8519b2fb6e6aaec8917dd8d094ce750d", size = 4271069, upload-time = "2026-03-25T23:34:25.895Z" }, + { url = "https://files.pythonhosted.org/packages/d7/18/61acfd5b414309d74ee838be321c636fe71815436f53c9f0334bf19064fa/cryptography-46.0.6-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:456b3215172aeefb9284550b162801d62f5f264a081049a3e94307fe20792cfa", size = 4878358, upload-time = "2026-03-25T23:34:27.67Z" }, + { url = "https://files.pythonhosted.org/packages/8b/65/5bf43286d566f8171917cae23ac6add941654ccf085d739195a4eacf1674/cryptography-46.0.6-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:341359d6c9e68834e204ceaf25936dffeafea3829ab80e9503860dcc4f4dac58", size = 4448061, upload-time = "2026-03-25T23:34:29.375Z" }, + { url = "https://files.pythonhosted.org/packages/e0/25/7e49c0fa7205cf3597e525d156a6bce5b5c9de1fd7e8cb01120e459f205a/cryptography-46.0.6-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9a9c42a2723999a710445bc0d974e345c32adfd8d2fac6d8a251fa829ad31cfb", size = 4399103, upload-time = "2026-03-25T23:34:32.036Z" }, + { url = 
"https://files.pythonhosted.org/packages/44/46/466269e833f1c4718d6cd496ffe20c56c9c8d013486ff66b4f69c302a68d/cryptography-46.0.6-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6617f67b1606dfd9fe4dbfa354a9508d4a6d37afe30306fe6c101b7ce3274b72", size = 4659255, upload-time = "2026-03-25T23:34:33.679Z" }, + { url = "https://files.pythonhosted.org/packages/0a/09/ddc5f630cc32287d2c953fc5d32705e63ec73e37308e5120955316f53827/cryptography-46.0.6-cp38-abi3-win32.whl", hash = "sha256:7f6690b6c55e9c5332c0b59b9c8a3fb232ebf059094c17f9019a51e9827df91c", size = 3010660, upload-time = "2026-03-25T23:34:35.418Z" }, + { url = "https://files.pythonhosted.org/packages/1b/82/ca4893968aeb2709aacfb57a30dec6fa2ab25b10fa9f064b8882ce33f599/cryptography-46.0.6-cp38-abi3-win_amd64.whl", hash = "sha256:79e865c642cfc5c0b3eb12af83c35c5aeff4fa5c672dc28c43721c2c9fdd2f0f", size = 3471160, upload-time = "2026-03-25T23:34:37.191Z" }, + { url = "https://files.pythonhosted.org/packages/2e/84/7ccff00ced5bac74b775ce0beb7d1be4e8637536b522b5df9b73ada42da2/cryptography-46.0.6-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:2ea0f37e9a9cf0df2952893ad145fd9627d326a59daec9b0802480fa3bcd2ead", size = 3475444, upload-time = "2026-03-25T23:34:38.944Z" }, + { url = "https://files.pythonhosted.org/packages/bc/1f/4c926f50df7749f000f20eede0c896769509895e2648db5da0ed55db711d/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a3e84d5ec9ba01f8fd03802b2147ba77f0c8f2617b2aff254cedd551844209c8", size = 4218227, upload-time = "2026-03-25T23:34:40.871Z" }, + { url = "https://files.pythonhosted.org/packages/c6/65/707be3ffbd5f786028665c3223e86e11c4cda86023adbc56bd72b1b6bab5/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:12f0fa16cc247b13c43d56d7b35287ff1569b5b1f4c5e87e92cc4fcc00cd10c0", size = 4381399, upload-time = "2026-03-25T23:34:42.609Z" }, + { url = "https://files.pythonhosted.org/packages/f3/6d/73557ed0ef7d73d04d9aba745d2c8e95218213687ee5e76b7d236a5030fc/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:50575a76e2951fe7dbd1f56d181f8c5ceeeb075e9ff88e7ad997d2f42af06e7b", size = 4217595, upload-time = "2026-03-25T23:34:44.205Z" }, + { url = "https://files.pythonhosted.org/packages/9e/c5/e1594c4eec66a567c3ac4400008108a415808be2ce13dcb9a9045c92f1a0/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:90e5f0a7b3be5f40c3a0a0eafb32c681d8d2c181fc2a1bdabe9b3f611d9f6b1a", size = 4380912, upload-time = "2026-03-25T23:34:46.328Z" }, + { url = "https://files.pythonhosted.org/packages/1a/89/843b53614b47f97fe1abc13f9a86efa5ec9e275292c457af1d4a60dc80e0/cryptography-46.0.6-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:6728c49e3b2c180ef26f8e9f0a883a2c585638db64cf265b49c9ba10652d430e", size = 3409955, upload-time = "2026-03-25T23:34:48.465Z" }, ] [[package]] From 865ee473ce8f4ae6df9fe7dbbf7884d7c1520ab0 Mon Sep 17 00:00:00 2001 From: YBoy Date: Sat, 28 Mar 2026 00:55:11 +0200 Subject: [PATCH 06/14] test: migrate messages clean service retention tests to testcontainers (#34207) --- .../services/test_messages_clean_service.py | 67 +++- .../test_messages_clean_service.py | 311 ------------------ 2 files changed, 66 insertions(+), 312 deletions(-) delete mode 100644 api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py 
b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py
index 9528257963e..2340dd2a03d 100644
--- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import datetime
 import json
 import uuid
 from decimal import Decimal
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
 from faker import Faker
@@ -1169,3 +1171,66 @@
 
         # Verify all messages were deleted
         assert db_session_with_containers.query(Message).where(Message.id.in_(msg_ids)).count() == 0
+
+    def test_from_time_range_validation(self):
+        """Test that from_time_range raises ValueError for invalid inputs."""
+        policy = MagicMock(spec=BillingDisabledPolicy)
+        now = datetime.datetime.now()
+
+        with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
+            MessagesCleanService.from_time_range(policy, now, now)
+
+        with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
+            MessagesCleanService.from_time_range(policy, now - datetime.timedelta(days=1), now, batch_size=0)
+
+    def test_from_time_range_success(self):
+        """Test that from_time_range creates a service with correct parameters."""
+        policy = MagicMock(spec=BillingDisabledPolicy)
+        start = datetime.datetime(2024, 1, 1)
+        end = datetime.datetime(2024, 2, 1)
+
+        service = MessagesCleanService.from_time_range(policy, start, end)
+        assert service._start_from == start
+        assert service._end_before == end
+
+    def test_from_days_validation(self):
+        """Test that from_days raises ValueError for invalid inputs."""
+        policy = MagicMock(spec=BillingDisabledPolicy)
+
+        with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"):
+            MessagesCleanService.from_days(policy, days=-1)
+
+        with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
+            MessagesCleanService.from_days(policy, days=30, batch_size=0)
+
+    def test_from_days_success(self):
+        """Test that from_days creates a service with correct parameters."""
+        policy = MagicMock(spec=BillingDisabledPolicy)
+
+        with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now:
+            fixed_now = datetime.datetime(2024, 6, 1)
+            mock_now.return_value = fixed_now
+
+            service = MessagesCleanService.from_days(policy, days=10)
+            assert service._start_from is None
+            assert service._end_before == fixed_now - datetime.timedelta(days=10)
+
+    def test_batch_delete_message_relations_empty(self, db_session_with_containers: Session):
+        """Test that batch_delete_message_relations with an empty message id list is a no-op."""
+        # An empty id list should return without issuing any deletes
+        MessagesCleanService._batch_delete_message_relations(db_session_with_containers, [])
+        # No exception means success; the empty list short-circuits before any SQL runs
+
+    def test_run_calls_clean_messages(self):
+        """Test that run() delegates to _clean_messages_by_time_range."""
+        policy = MagicMock(spec=BillingDisabledPolicy)
+        service = MessagesCleanService(
+            policy=policy,
+            end_before=datetime.datetime.now(),
+            batch_size=10,
+        )
+        with patch.object(service, "_clean_messages_by_time_range") as mock_clean:
+            mock_clean.return_value = {"total_deleted": 5}
+            result = service.run()
+            assert result == {"total_deleted": 5}
+            mock_clean.assert_called_once()
diff --git a/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py
b/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py deleted file mode 100644 index f9d901fca24..00000000000 --- a/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py +++ /dev/null @@ -1,311 +0,0 @@ -import datetime -from unittest.mock import MagicMock, patch - -import pytest - -from services.retention.conversation.messages_clean_policy import ( - BillingDisabledPolicy, -) -from services.retention.conversation.messages_clean_service import MessagesCleanService - - -class TestMessagesCleanService: - @pytest.fixture(autouse=True) - def mock_db_engine(self): - with patch("services.retention.conversation.messages_clean_service.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db.engine - - @pytest.fixture - def mock_db_session(self, mock_db_engine): - with patch("services.retention.conversation.messages_clean_service.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - yield mock_session - - @pytest.fixture - def mock_policy(self): - policy = MagicMock(spec=BillingDisabledPolicy) - return policy - - def test_run_calls_clean_messages(self, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - with patch.object(service, "_clean_messages_by_time_range") as mock_clean: - mock_clean.return_value = {"total_deleted": 5} - result = service.run() - assert result == {"total_deleted": 5} - mock_clean.assert_called_once() - - def test_clean_messages_by_time_range_basic(self, mock_db_session, mock_policy): - # Arrange - end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) - service = MessagesCleanService( - policy=mock_policy, - end_before=end_before, - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - MagicMock( - rowcount=1 - ), # delete relations (this is wrong, relations delete doesn't use rowcount here, but execute) - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete messages - MagicMock(all=lambda: []), # next batch empty - ] - - # Reset side_effect to be more robust - # The service calls session.execute for: - # 1. Fetch messages - # 2. Fetch apps - # 3. Batch delete relations (8 calls if IDs exist) - # 4. 
Delete messages - - mock_returns = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # fetch messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # fetch apps - ] - # 8 deletes for relations - mock_returns.extend([MagicMock() for _ in range(8)]) - # 1 delete for messages - mock_returns.append(MagicMock(rowcount=1)) - # Final fetch messages (empty) - mock_returns.append(MagicMock(all=lambda: [])) - - mock_db_session.execute.side_effect = mock_returns - mock_policy.filter_message_ids.return_value = ["msg1"] - - # Act - with patch("services.retention.conversation.messages_clean_service.time.sleep"): - stats = service.run() - - # Assert - assert stats["total_messages"] == 1 - assert stats["total_deleted"] == 1 - assert stats["batches"] == 2 - - def test_clean_messages_by_time_range_with_start_from(self, mock_db_session, mock_policy): - start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) - end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) - service = MessagesCleanService( - policy=mock_policy, - start_from=start_from, - end_before=end_before, - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: []), # No messages - ] - - stats = service.run() - assert stats["total_messages"] == 0 - - def test_clean_messages_by_time_range_with_cursor(self, mock_db_session, mock_policy): - # Test pagination with cursor - end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) - service = MessagesCleanService( - policy=mock_policy, - end_before=end_before, - batch_size=1, - ) - - msg1_time = datetime.datetime(2024, 1, 1, 10, 0, 0) - msg2_time = datetime.datetime(2024, 1, 1, 11, 0, 0) - - mock_returns = [] - # Batch 1 - mock_returns.append(MagicMock(all=lambda: [("msg1", "app1", msg1_time)])) - mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) - mock_returns.extend([MagicMock() for _ in range(8)]) # relations - mock_returns.append(MagicMock(rowcount=1)) # messages - - # Batch 2 - mock_returns.append(MagicMock(all=lambda: [("msg2", "app1", msg2_time)])) - mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) - mock_returns.extend([MagicMock() for _ in range(8)]) # relations - mock_returns.append(MagicMock(rowcount=1)) # messages - - # Batch 3 - mock_returns.append(MagicMock(all=lambda: [])) - - mock_db_session.execute.side_effect = mock_returns - mock_policy.filter_message_ids.return_value = ["msg1"] # Simplified - - with patch("services.retention.conversation.messages_clean_service.time.sleep"): - stats = service.run() - - assert stats["batches"] == 3 - assert stats["total_messages"] == 2 - - def test_clean_messages_by_time_range_dry_run(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - dry_run=True, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - MagicMock(all=lambda: []), # next batch empty - ] - mock_policy.filter_message_ids.return_value = ["msg1"] - - with patch("services.retention.conversation.messages_clean_service.random.sample") as mock_sample: - mock_sample.return_value = ["msg1"] - stats = service.run() - assert stats["filtered_messages"] == 1 - assert stats["total_deleted"] == 0 # Dry run - mock_sample.assert_called() - - def test_clean_messages_by_time_range_no_apps_found(self, 
mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: []), # apps NOT found - MagicMock(all=lambda: []), # next batch empty - ] - - stats = service.run() - assert stats["total_messages"] == 1 - assert stats["total_deleted"] == 0 - - def test_clean_messages_by_time_range_no_app_ids(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: []), # next batch empty - ] - - # We need to successfully execute line 228 and 229, then return empty at 251. - # line 228: raw_messages = list(session.execute(msg_stmt).all()) - # line 251: app_ids = list({msg.app_id for msg in messages}) - - calls = [] - - def list_side_effect(arg): - calls.append(arg) - if len(calls) == 2: # This is the second call to list() in the loop - return [] - return list(arg) - - with patch("services.retention.conversation.messages_clean_service.list", side_effect=list_side_effect): - stats = service.run() - assert stats["batches"] == 2 - assert stats["total_messages"] == 1 - - def test_from_time_range_validation(self, mock_policy): - now = datetime.datetime.now() - # Test start_from >= end_before - with pytest.raises(ValueError, match="start_from .* must be less than end_before"): - MessagesCleanService.from_time_range(mock_policy, now, now) - - # Test batch_size <= 0 - with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): - MessagesCleanService.from_time_range(mock_policy, now - datetime.timedelta(days=1), now, batch_size=0) - - def test_from_time_range_success(self, mock_policy): - start = datetime.datetime(2024, 1, 1) - end = datetime.datetime(2024, 2, 1) - # Mock logger to avoid actual logging if needed, though it's fine - service = MessagesCleanService.from_time_range(mock_policy, start, end) - assert service._start_from == start - assert service._end_before == end - - def test_from_days_validation(self, mock_policy): - # Test days < 0 - with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): - MessagesCleanService.from_days(mock_policy, days=-1) - - # Test batch_size <= 0 - with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): - MessagesCleanService.from_days(mock_policy, days=30, batch_size=0) - - def test_from_days_success(self, mock_policy): - with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: - fixed_now = datetime.datetime(2024, 6, 1) - mock_now.return_value = fixed_now - - service = MessagesCleanService.from_days(mock_policy, days=10) - assert service._start_from is None - assert service._end_before == fixed_now - datetime.timedelta(days=10) - - def test_clean_messages_by_time_range_no_messages_to_delete(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - MagicMock(all=lambda: []), # next batch empty - ] - 
mock_policy.filter_message_ids.return_value = []  # Policy says NO
-
-        stats = service.run()
-        assert stats["total_messages"] == 1
-        assert stats["filtered_messages"] == 0
-        assert stats["total_deleted"] == 0
-
-    def test_batch_delete_message_relations_empty(self, mock_db_session):
-        MessagesCleanService._batch_delete_message_relations(mock_db_session, [])
-        mock_db_session.execute.assert_not_called()
-
-    def test_batch_delete_message_relations_with_ids(self, mock_db_session):
-        MessagesCleanService._batch_delete_message_relations(mock_db_session, ["msg1", "msg2"])
-        assert mock_db_session.execute.call_count == 8  # 8 tables to clean up
-
-    def test_clean_messages_interval_from_env(self, mock_db_session, mock_policy):
-        service = MessagesCleanService(
-            policy=mock_policy,
-            end_before=datetime.datetime.now(),
-            batch_size=10,
-        )
-
-        mock_returns = [
-            MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]),  # messages
-            MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]),  # apps
-        ]
-        mock_returns.extend([MagicMock() for _ in range(8)])  # relations
-        mock_returns.append(MagicMock(rowcount=1))  # messages
-        mock_returns.append(MagicMock(all=lambda: []))  # next batch empty
-
-        mock_db_session.execute.side_effect = mock_returns
-        mock_policy.filter_message_ids.return_value = ["msg1"]
-
-        with patch(
-            "services.retention.conversation.messages_clean_service.dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL",
-            500,
-        ):
-            with patch("services.retention.conversation.messages_clean_service.time.sleep") as mock_sleep:
-                with patch("services.retention.conversation.messages_clean_service.random.uniform") as mock_uniform:
-                    mock_uniform.return_value = 300.0
-                    service.run()
-                    mock_uniform.assert_called_with(0, 500)
-                    mock_sleep.assert_called_with(0.3)

From c5eae67ac92e466717367a09f9cb7853a16ab8ce Mon Sep 17 00:00:00 2001
From: Maa-Lee | odeili
Date: Sat, 28 Mar 2026 00:01:05 +0000
Subject: [PATCH 07/14] refactor: use select for API key auth lookups (#34146)

Co-authored-by: Asuka Minato
---
 api/services/auth/api_key_auth_service.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py
index 56aaf407eeb..3282dcfb113 100644
--- a/api/services/auth/api_key_auth_service.py
+++ b/api/services/auth/api_key_auth_service.py
@@ -35,15 +35,13 @@ class ApiKeyAuthService:
 
     @staticmethod
     def get_auth_credentials(tenant_id: str, category: str, provider: str):
-        data_source_api_key_bindings = (
-            db.session.query(DataSourceApiKeyAuthBinding)
-            .where(
+        data_source_api_key_bindings = db.session.scalar(
+            select(DataSourceApiKeyAuthBinding).where(
                 DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
                 DataSourceApiKeyAuthBinding.category == category,
                 DataSourceApiKeyAuthBinding.provider == provider,
                 DataSourceApiKeyAuthBinding.disabled.is_(False),
             )
-            .first()
         )
         if not data_source_api_key_bindings:
            return None
@@ -54,10 +52,11 @@ class ApiKeyAuthService:
 
     @staticmethod
     def delete_provider_auth(tenant_id: str, binding_id: str):
-        data_source_api_key_binding = (
-            db.session.query(DataSourceApiKeyAuthBinding)
-            .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
-            .first()
+        data_source_api_key_binding = db.session.scalar(
+            select(DataSourceApiKeyAuthBinding).where(
+                DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
+                DataSourceApiKeyAuthBinding.id == binding_id,
+            )
         )
         if data_source_api_key_binding:
             db.session.delete(data_source_api_key_binding)

From 5851b42af38e83d4aa398355fe5d941e87e36b28 Mon Sep 17 00:00:00 2001
From: YBoy
Date: Sat, 28 Mar 2026 09:48:48 +0200
Subject: [PATCH 08/14] test: migrate metadata service tests to testcontainers (#34220)

---
 .../services/test_metadata_service.py | 558 ------------------
 1 file changed, 558 deletions(-)
 delete mode 100644 api/tests/unit_tests/services/test_metadata_service.py

diff --git a/api/tests/unit_tests/services/test_metadata_service.py b/api/tests/unit_tests/services/test_metadata_service.py
deleted file mode 100644
index bbdc16d4f87..00000000000
--- a/api/tests/unit_tests/services/test_metadata_service.py
+++ /dev/null
@@ -1,558 +0,0 @@
-from __future__ import annotations
-
-from dataclasses import dataclass
-from datetime import UTC, datetime
-from types import SimpleNamespace
-from typing import Any, cast
-from unittest.mock import MagicMock
-
-import pytest
-from pytest_mock import MockerFixture
-
-from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
-from models.dataset import Dataset
-from services.entities.knowledge_entities.knowledge_entities import (
-    DocumentMetadataOperation,
-    MetadataArgs,
-    MetadataDetail,
-    MetadataOperationData,
-)
-from services.metadata_service import MetadataService
-
-
-@dataclass
-class _DocumentStub:
-    id: str
-    name: str
-    uploader: str
-    upload_date: datetime
-    last_update_date: datetime
-    data_source_type: str
-    doc_metadata: dict[str, object] | None
-
-
-@pytest.fixture
-def mock_db(mocker: MockerFixture) -> MagicMock:
-    mocked_db = mocker.patch("services.metadata_service.db")
-    mocked_db.session = MagicMock()
-    return mocked_db
-
-
-@pytest.fixture
-def mock_redis_client(mocker: MockerFixture) -> MagicMock:
-    return mocker.patch("services.metadata_service.redis_client")
-
-
-@pytest.fixture
-def mock_current_account(mocker: MockerFixture) -> MagicMock:
-    mock_user = SimpleNamespace(id="user-1")
-    return mocker.patch("services.metadata_service.current_account_with_tenant", return_value=(mock_user, "tenant-1"))
-
-
-def _build_document(document_id: str, doc_metadata: dict[str, object] | None = None) -> _DocumentStub:
-    now = datetime(2025, 1, 1, 10, 30, tzinfo=UTC)
-    return _DocumentStub(
-        id=document_id,
-        name=f"doc-{document_id}",
-        uploader="qa@example.com",
-        upload_date=now,
-        last_update_date=now,
-        data_source_type="upload_file",
-        doc_metadata=doc_metadata,
-    )
-
-
-def _dataset(**kwargs: Any) -> Dataset:
-    return cast(Dataset, SimpleNamespace(**kwargs))
-
-
-def test_create_metadata_should_raise_value_error_when_name_exceeds_limit() -> None:
-    # Arrange
-    metadata_args = MetadataArgs(type="string", name="x" * 256)
-
-    # Act + Assert
-    with pytest.raises(ValueError, match="cannot exceed 255"):
-        MetadataService.create_metadata("dataset-1", metadata_args)
-
-
-def test_create_metadata_should_raise_value_error_when_metadata_name_already_exists(
-    mock_db: MagicMock,
-    mock_current_account: MagicMock,
-) -> None:
-    # Arrange
-    metadata_args = MetadataArgs(type="string", name="priority")
-    mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
-
-    # Act + Assert
-    with pytest.raises(ValueError, match="already exists"):
-        MetadataService.create_metadata("dataset-1", metadata_args)
-
-    # Assert
-    mock_current_account.assert_called_once()
-
-
-def test_create_metadata_should_raise_value_error_when_name_collides_with_builtin(
-    mock_db: MagicMock, mock_current_account: MagicMock
-) -> None:
-    # Arrange
-    metadata_args
= MetadataArgs(type="string", name=BuiltInField.document_name) - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="Built-in fields"): - MetadataService.create_metadata("dataset-1", metadata_args) - - -def test_create_metadata_should_persist_metadata_when_input_is_valid( - mock_db: MagicMock, mock_current_account: MagicMock -) -> None: - # Arrange - metadata_args = MetadataArgs(type="number", name="score") - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Act - result = MetadataService.create_metadata("dataset-1", metadata_args) - - # Assert - assert result.tenant_id == "tenant-1" - assert result.dataset_id == "dataset-1" - assert result.type == "number" - assert result.name == "score" - assert result.created_by == "user-1" - mock_db.session.add.assert_called_once_with(result) - mock_db.session.commit.assert_called_once() - mock_current_account.assert_called_once() - - -def test_update_metadata_name_should_raise_value_error_when_name_exceeds_limit() -> None: - # Arrange - too_long_name = "x" * 256 - - # Act + Assert - with pytest.raises(ValueError, match="cannot exceed 255"): - MetadataService.update_metadata_name("dataset-1", "metadata-1", too_long_name) - - -def test_update_metadata_name_should_raise_value_error_when_duplicate_name_exists( - mock_db: MagicMock, mock_current_account: MagicMock -) -> None: - # Arrange - mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() - - # Act + Assert - with pytest.raises(ValueError, match="already exists"): - MetadataService.update_metadata_name("dataset-1", "metadata-1", "duplicate") - - # Assert - mock_current_account.assert_called_once() - - -def test_update_metadata_name_should_raise_value_error_when_name_collides_with_builtin( - mock_db: MagicMock, - mock_current_account: MagicMock, -) -> None: - # Arrange - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="Built-in fields"): - MetadataService.update_metadata_name("dataset-1", "metadata-1", BuiltInField.source) - - # Assert - mock_current_account.assert_called_once() - - -def test_update_metadata_name_should_update_bound_documents_and_return_metadata( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mock_current_account: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - fixed_now = datetime(2025, 2, 1, 0, 0, tzinfo=UTC) - mocker.patch("services.metadata_service.naive_utc_now", return_value=fixed_now) - - metadata = SimpleNamespace(id="metadata-1", name="old_name", updated_by=None, updated_at=None) - bindings = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")] - query_duplicate = MagicMock() - query_duplicate.filter_by.return_value.first.return_value = None - query_metadata = MagicMock() - query_metadata.filter_by.return_value.first.return_value = metadata - query_bindings = MagicMock() - query_bindings.filter_by.return_value.all.return_value = bindings - mock_db.session.query.side_effect = [query_duplicate, query_metadata, query_bindings] - - doc_1 = _build_document("1", {"old_name": "value", "other": "keep"}) - doc_2 = _build_document("2", None) - mock_get_documents = mocker.patch("services.metadata_service.DocumentService.get_document_by_ids") - mock_get_documents.return_value = [doc_1, doc_2] - - # Act - result = 
MetadataService.update_metadata_name("dataset-1", "metadata-1", "new_name") - - # Assert - assert result is metadata - assert metadata.name == "new_name" - assert metadata.updated_by == "user-1" - assert metadata.updated_at == fixed_now - assert doc_1.doc_metadata == {"other": "keep", "new_name": "value"} - assert doc_2.doc_metadata == {"new_name": None} - mock_get_documents.assert_called_once_with(["doc-1", "doc-2"]) - mock_db.session.commit.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - mock_current_account.assert_called_once() - - -def test_update_metadata_name_should_return_none_when_metadata_does_not_exist( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mock_current_account: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - mock_logger = mocker.patch("services.metadata_service.logger") - - query_duplicate = MagicMock() - query_duplicate.filter_by.return_value.first.return_value = None - query_metadata = MagicMock() - query_metadata.filter_by.return_value.first.return_value = None - mock_db.session.query.side_effect = [query_duplicate, query_metadata] - - # Act - result = MetadataService.update_metadata_name("dataset-1", "missing-id", "new_name") - - # Assert - assert result is None - mock_logger.exception.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - mock_current_account.assert_called_once() - - -def test_delete_metadata_should_remove_metadata_and_related_document_fields( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - metadata = SimpleNamespace(id="metadata-1", name="obsolete") - bindings = [SimpleNamespace(document_id="doc-1")] - query_metadata = MagicMock() - query_metadata.filter_by.return_value.first.return_value = metadata - query_bindings = MagicMock() - query_bindings.filter_by.return_value.all.return_value = bindings - mock_db.session.query.side_effect = [query_metadata, query_bindings] - - document = _build_document("1", {"obsolete": "legacy", "remaining": "value"}) - mocker.patch("services.metadata_service.DocumentService.get_document_by_ids", return_value=[document]) - - # Act - result = MetadataService.delete_metadata("dataset-1", "metadata-1") - - # Assert - assert result is metadata - assert document.doc_metadata == {"remaining": "value"} - mock_db.session.delete.assert_called_once_with(metadata) - mock_db.session.commit.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - - -def test_delete_metadata_should_return_none_when_metadata_is_missing( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - mock_logger = mocker.patch("services.metadata_service.logger") - - # Act - result = MetadataService.delete_metadata("dataset-1", "missing-id") - - # Assert - assert result is None - mock_logger.exception.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - - -def test_get_built_in_fields_should_return_all_expected_fields() -> None: - # Arrange - expected_names = { - BuiltInField.document_name, - BuiltInField.uploader, - BuiltInField.upload_date, - BuiltInField.last_update_date, - BuiltInField.source, - } 
- - # Act - result = MetadataService.get_built_in_fields() - - # Assert - assert {item["name"] for item in result} == expected_names - assert [item["type"] for item in result] == ["string", "string", "time", "time", "string"] - - -def test_enable_built_in_field_should_return_immediately_when_already_enabled( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - dataset = _dataset(id="dataset-1", built_in_field_enabled=True) - get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id") - - # Act - MetadataService.enable_built_in_field(dataset) - - # Assert - get_docs.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_enable_built_in_field_should_populate_documents_and_enable_flag( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=False) - doc_1 = _build_document("1", {"custom": "value"}) - doc_2 = _build_document("2", None) - mocker.patch( - "services.metadata_service.DocumentService.get_working_documents_by_dataset_id", - return_value=[doc_1, doc_2], - ) - - # Act - MetadataService.enable_built_in_field(dataset) - - # Assert - assert dataset.built_in_field_enabled is True - assert doc_1.doc_metadata is not None - assert doc_1.doc_metadata[BuiltInField.document_name] == "doc-1" - assert doc_1.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file - assert doc_2.doc_metadata is not None - assert doc_2.doc_metadata[BuiltInField.uploader] == "qa@example.com" - mock_db.session.commit.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - - -def test_disable_built_in_field_should_return_immediately_when_already_disabled( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - dataset = _dataset(id="dataset-1", built_in_field_enabled=False) - get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id") - - # Act - MetadataService.disable_built_in_field(dataset) - - # Assert - get_docs.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_disable_built_in_field_should_remove_builtin_keys_and_disable_flag( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=True) - document = _build_document( - "1", - { - BuiltInField.document_name: "doc", - BuiltInField.uploader: "user", - BuiltInField.upload_date: 1.0, - BuiltInField.last_update_date: 2.0, - BuiltInField.source: MetadataDataSource.upload_file, - "custom": "keep", - }, - ) - mocker.patch( - "services.metadata_service.DocumentService.get_working_documents_by_dataset_id", - return_value=[document], - ) - - # Act - MetadataService.disable_built_in_field(dataset) - - # Assert - assert dataset.built_in_field_enabled is False - assert document.doc_metadata == {"custom": "keep"} - mock_db.session.commit.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - - -def test_update_documents_metadata_should_replace_metadata_and_create_bindings_on_full_update( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mock_current_account: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = 
_dataset(id="dataset-1", built_in_field_enabled=False) - document = _build_document("1", {"legacy": "value"}) - mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document) - delete_chain = mock_db.session.query.return_value.filter_by.return_value - delete_chain.delete.return_value = 1 - operation = DocumentMetadataOperation( - document_id="1", - metadata_list=[MetadataDetail(id="meta-1", name="priority", value="high")], - partial_update=False, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Act - MetadataService.update_documents_metadata(dataset, metadata_args) - - # Assert - assert document.doc_metadata == {"priority": "high"} - delete_chain.delete.assert_called_once() - assert mock_db.session.commit.call_count == 1 - mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1") - mock_current_account.assert_called_once() - - -def test_update_documents_metadata_should_skip_existing_binding_and_preserve_existing_fields_on_partial_update( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mock_current_account: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=True) - document = _build_document("1", {"existing": "value"}) - mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document) - mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() - operation = DocumentMetadataOperation( - document_id="1", - metadata_list=[MetadataDetail(id="meta-1", name="new_key", value="new_value")], - partial_update=True, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Act - MetadataService.update_documents_metadata(dataset, metadata_args) - - # Assert - assert document.doc_metadata is not None - assert document.doc_metadata["existing"] == "value" - assert document.doc_metadata["new_key"] == "new_value" - assert document.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file - assert mock_db.session.commit.call_count == 1 - assert mock_db.session.add.call_count == 1 - mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1") - mock_current_account.assert_called_once() - - -def test_update_documents_metadata_should_raise_and_rollback_when_document_not_found( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=False) - mocker.patch("services.metadata_service.DocumentService.get_document", return_value=None) - operation = DocumentMetadataOperation(document_id="404", metadata_list=[], partial_update=True) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Act + Assert - with pytest.raises(ValueError, match="Document not found"): - MetadataService.update_documents_metadata(dataset, metadata_args) - - # Assert - mock_db.session.rollback.assert_called_once() - mock_redis_client.delete.assert_called_once_with("document_metadata_lock_404") - - -@pytest.mark.parametrize( - ("dataset_id", "document_id", "expected_key"), - [ - ("dataset-1", None, "dataset_metadata_lock_dataset-1"), - (None, "doc-1", "document_metadata_lock_doc-1"), - ], -) -def test_knowledge_base_metadata_lock_check_should_set_lock_when_not_already_locked( - dataset_id: str | None, - document_id: str | None, - expected_key: str, - mock_redis_client: MagicMock, 
-) -> None:
-    # Arrange
-    mock_redis_client.get.return_value = None
-
-    # Act
-    MetadataService.knowledge_base_metadata_lock_check(dataset_id, document_id)
-
-    # Assert
-    mock_redis_client.set.assert_called_once_with(expected_key, 1, ex=3600)
-
-
-def test_knowledge_base_metadata_lock_check_should_raise_when_dataset_lock_exists(
-    mock_redis_client: MagicMock,
-) -> None:
-    # Arrange
-    mock_redis_client.get.return_value = 1
-
-    # Act + Assert
-    with pytest.raises(ValueError, match="knowledge base metadata operation is running"):
-        MetadataService.knowledge_base_metadata_lock_check("dataset-1", None)
-
-
-def test_knowledge_base_metadata_lock_check_should_raise_when_document_lock_exists(
-    mock_redis_client: MagicMock,
-) -> None:
-    # Arrange
-    mock_redis_client.get.return_value = 1
-
-    # Act + Assert
-    with pytest.raises(ValueError, match="document metadata operation is running"):
-        MetadataService.knowledge_base_metadata_lock_check(None, "doc-1")
-
-
-def test_get_dataset_metadatas_should_exclude_builtin_and_include_binding_counts(mock_db: MagicMock) -> None:
-    # Arrange
-    dataset = _dataset(
-        id="dataset-1",
-        built_in_field_enabled=True,
-        doc_metadata=[
-            {"id": "meta-1", "name": "priority", "type": "string"},
-            {"id": "built-in", "name": "ignored", "type": "string"},
-            {"id": "meta-2", "name": "score", "type": "number"},
-        ],
-    )
-    count_chain = mock_db.session.query.return_value.filter_by.return_value
-    count_chain.count.side_effect = [3, 1]
-
-    # Act
-    result = MetadataService.get_dataset_metadatas(dataset)
-
-    # Assert
-    assert result["built_in_field_enabled"] is True
-    assert result["doc_metadata"] == [
-        {"id": "meta-1", "name": "priority", "type": "string", "count": 3},
-        {"id": "meta-2", "name": "score", "type": "number", "count": 1},
-    ]
-
-
-def test_get_dataset_metadatas_should_return_empty_list_when_no_metadata(mock_db: MagicMock) -> None:
-    # Arrange
-    dataset = _dataset(id="dataset-1", built_in_field_enabled=False, doc_metadata=None)
-
-    # Act
-    result = MetadataService.get_dataset_metadatas(dataset)
-
-    # Assert
-    assert result == {"doc_metadata": [], "built_in_field_enabled": False}
-    mock_db.session.query.assert_not_called()

From 3409c519e2e2cbf0607a342fe3ac46b3a1fe4aa8 Mon Sep 17 00:00:00 2001
From: YBoy
Date: Sat, 28 Mar 2026 09:49:27 +0200
Subject: [PATCH 09/14] test: migrate tag service tests to testcontainers (#34219)

---
 .../unit_tests/services/test_tag_service.py | 1336 -----------------
 1 file changed, 1336 deletions(-)
 delete mode 100644 api/tests/unit_tests/services/test_tag_service.py

diff --git a/api/tests/unit_tests/services/test_tag_service.py b/api/tests/unit_tests/services/test_tag_service.py
deleted file mode 100644
index b09463b1bc4..00000000000
--- a/api/tests/unit_tests/services/test_tag_service.py
+++ /dev/null
@@ -1,1336 +0,0 @@
-"""
-Comprehensive unit tests for TagService.
-
-This test suite provides complete coverage of tag management operations in Dify,
-following TDD principles with the Arrange-Act-Assert pattern.
-
-The TagService is responsible for managing tags that can be associated with
-datasets (knowledge bases) and applications. Tags enable users to organize,
-categorize, and filter their content effectively.
-
-## Test Coverage
-
-### 1. Tag Retrieval (TestTagServiceRetrieval)
-Tests tag listing and filtering:
-- Get tags with binding counts
-- Filter tags by keyword (case-insensitive)
-- Get tags by target ID (apps/datasets)
-- Get tags by tag name
-- Get target IDs by tag IDs
-- Empty results handling
-
-### 2. 
Tag CRUD Operations (TestTagServiceCRUD) -Tests tag creation, update, and deletion: -- Create new tags -- Prevent duplicate tag names -- Update tag names -- Update with duplicate name validation -- Delete tags and cascade delete bindings -- Get tag binding counts -- NotFound error handling - -### 3. Tag Binding Operations (TestTagServiceBindings) -Tests tag-to-resource associations: -- Save tag bindings (apps/datasets) -- Prevent duplicate bindings (idempotent) -- Delete tag bindings -- Check target exists validation -- Batch binding operations - -## Testing Approach - -- **Mocking Strategy**: All external dependencies (database, current_user) are mocked - for fast, isolated unit tests -- **Factory Pattern**: TagServiceTestDataFactory provides consistent test data -- **Fixtures**: Mock objects are configured per test method -- **Assertions**: Each test verifies return values and side effects - (database operations, method calls) - -## Key Concepts - -**Tag Types:** -- knowledge: Tags for datasets/knowledge bases -- app: Tags for applications - -**Tag Bindings:** -- Many-to-many relationship between tags and resources -- Each binding links a tag to a specific app or dataset -- Bindings are tenant-scoped for multi-tenancy - -**Validation:** -- Tag names must be unique within tenant and type -- Target resources must exist before binding -- Cascade deletion of bindings when tag is deleted -""" - - -# ============================================================================ -# IMPORTS -# ============================================================================ - -from datetime import UTC, datetime -from unittest.mock import MagicMock, Mock, create_autospec, patch - -import pytest -from werkzeug.exceptions import NotFound - -from models.dataset import Dataset -from models.enums import TagType -from models.model import App, Tag, TagBinding -from services.tag_service import TagService - -# ============================================================================ -# TEST DATA FACTORY -# ============================================================================ - - -class TagServiceTestDataFactory: - """ - Factory for creating test data and mock objects. - - Provides reusable methods to create consistent mock objects for testing - tag-related operations. This factory ensures all test data follows the - same structure and reduces code duplication across tests. - - The factory pattern is used here to: - - Ensure consistent test data creation - - Reduce boilerplate code in individual tests - - Make tests more maintainable and readable - - Centralize mock object configuration - """ - - @staticmethod - def create_tag_mock( - tag_id: str = "tag-123", - name: str = "Test Tag", - tag_type: TagType = TagType.APP, - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """ - Create a mock Tag object. - - This method creates a mock Tag instance with all required attributes - set to sensible defaults. Additional attributes can be passed via - kwargs to customize the mock for specific test scenarios. - - Args: - tag_id: Unique identifier for the tag - name: Tag name (e.g., "Frontend", "Backend", "Data Science") - tag_type: Type of tag ('app' or 'knowledge') - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - (e.g., created_by, created_at, etc.) - - Returns: - Mock Tag object with specified attributes - - Example: - >>> tag = factory.create_tag_mock( - ... tag_id="tag-456", - ... name="Machine Learning", - ... tag_type="knowledge" - ... 
) - """ - # Create a mock that matches the Tag model interface - tag = create_autospec(Tag, instance=True) - - # Set core attributes - tag.id = tag_id - tag.name = name - tag.type = tag_type - tag.tenant_id = tenant_id - - # Set default optional attributes - tag.created_by = kwargs.pop("created_by", "user-123") - tag.created_at = kwargs.pop("created_at", datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC)) - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(tag, key, value) - - return tag - - @staticmethod - def create_tag_binding_mock( - binding_id: str = "binding-123", - tag_id: str = "tag-123", - target_id: str = "target-123", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """ - Create a mock TagBinding object. - - TagBindings represent the many-to-many relationship between tags - and resources (datasets or apps). This method creates a mock - binding with the necessary attributes. - - Args: - binding_id: Unique identifier for the binding - tag_id: Associated tag identifier - target_id: Associated target (app/dataset) identifier - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - (e.g., created_by, etc.) - - Returns: - Mock TagBinding object with specified attributes - - Example: - >>> binding = factory.create_tag_binding_mock( - ... tag_id="tag-456", - ... target_id="dataset-789", - ... tenant_id="tenant-123" - ... ) - """ - # Create a mock that matches the TagBinding model interface - binding = create_autospec(TagBinding, instance=True) - - # Set core attributes - binding.id = binding_id - binding.tag_id = tag_id - binding.target_id = target_id - binding.tenant_id = tenant_id - - # Set default optional attributes - binding.created_by = kwargs.pop("created_by", "user-123") - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(binding, key, value) - - return binding - - @staticmethod - def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock: - """ - Create a mock App object. - - This method creates a mock App instance for testing tag bindings - to applications. Apps are one of the two target types that tags - can be bound to (the other being datasets/knowledge bases). - - Args: - app_id: Unique identifier for the app - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - - Returns: - Mock App object with specified attributes - - Example: - >>> app = factory.create_app_mock( - ... app_id="app-456", - ... name="My Chat App" - ... ) - """ - # Create a mock that matches the App model interface - app = create_autospec(App, instance=True) - - # Set core attributes - app.id = app_id - app.tenant_id = tenant_id - app.name = kwargs.get("name", "Test App") - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(app, key, value) - - return app - - @staticmethod - def create_dataset_mock(dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", **kwargs) -> Mock: - """ - Create a mock Dataset object. - - This method creates a mock Dataset instance for testing tag bindings - to knowledge bases. Datasets (knowledge bases) are one of the two - target types that tags can be bound to (the other being apps). 
- - Args: - dataset_id: Unique identifier for the dataset - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - - Returns: - Mock Dataset object with specified attributes - - Example: - >>> dataset = factory.create_dataset_mock( - ... dataset_id="dataset-456", - ... name="My Knowledge Base" - ... ) - """ - # Create a mock that matches the Dataset model interface - dataset = create_autospec(Dataset, instance=True) - - # Set core attributes - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.name = kwargs.pop("name", "Test Dataset") - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(dataset, key, value) - - return dataset - - -# ============================================================================ -# PYTEST FIXTURES -# ============================================================================ - - -@pytest.fixture -def factory(): - """ - Provide the test data factory to all tests. - - This fixture makes the TagServiceTestDataFactory available to all test - methods, allowing them to create consistent mock objects easily. - - Returns: - TagServiceTestDataFactory class - """ - return TagServiceTestDataFactory - - -# ============================================================================ -# TAG RETRIEVAL TESTS -# ============================================================================ - - -class TestTagServiceRetrieval: - """ - Test tag retrieval operations. - - This test class covers all methods related to retrieving and querying - tags from the system. These operations are read-only and do not modify - the database state. - - Methods tested: - - get_tags: Retrieve tags with optional keyword filtering - - get_target_ids_by_tag_ids: Get target IDs (datasets/apps) by tag IDs - - get_tag_by_tag_name: Find tags by exact name match - - get_tags_by_target_id: Get all tags bound to a specific target - """ - - @patch("services.tag_service.db.session") - def test_get_tags_with_binding_counts(self, mock_db_session, factory): - """ - Test retrieving tags with their binding counts. - - This test verifies that the get_tags method correctly retrieves - a list of tags along with the count of how many resources - (datasets/apps) are bound to each tag. 
- - The method should: - - Query tags filtered by type and tenant - - Include binding counts via a LEFT OUTER JOIN - - Return results ordered by creation date (newest first) - - Expected behavior: - - Returns a list of tuples containing (id, type, name, binding_count) - - Each tag includes its binding count - - Results are ordered by creation date descending - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - - # Mock query results: tuples of (tag_id, type, name, binding_count) - # This simulates the SQL query result with aggregated binding counts - mock_results = [ - ("tag-1", "app", "Frontend", 5), # Frontend tag with 5 bindings - ("tag-2", "app", "Backend", 3), # Backend tag with 3 bindings - ("tag-3", "app", "API", 0), # API tag with no bindings - ] - - # Configure mock database session and query chain - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.outerjoin.return_value = mock_query # LEFT OUTER JOIN with TagBinding - mock_query.where.return_value = mock_query # WHERE clause for filtering - mock_query.group_by.return_value = mock_query # GROUP BY for aggregation - mock_query.order_by.return_value = mock_query # ORDER BY for sorting - mock_query.all.return_value = mock_results # Final result - - # Act - # Execute the method under test - results = TagService.get_tags(tag_type=tag_type, current_tenant_id=tenant_id) - - # Assert - # Verify the results match expectations - assert len(results) == 3, "Should return 3 tags" - - # Verify each tag's data structure - assert results[0] == ("tag-1", "app", "Frontend", 5), "First tag should match" - assert results[1] == ("tag-2", "app", "Backend", 3), "Second tag should match" - assert results[2] == ("tag-3", "app", "API", 0), "Third tag should match" - - # Verify database query was called - mock_db_session.query.assert_called_once() - - @patch("services.tag_service.db.session") - def test_get_tags_with_keyword_filter(self, mock_db_session, factory): - """ - Test retrieving tags filtered by keyword (case-insensitive). - - This test verifies that the get_tags method correctly filters tags - by keyword when a keyword parameter is provided. The filtering - should be case-insensitive and support partial matches. 
- - The method should: - - Apply an additional WHERE clause when keyword is provided - - Use ILIKE for case-insensitive pattern matching - - Support partial matches (e.g., "data" matches "Database" and "Data Science") - - Expected behavior: - - Returns only tags whose names contain the keyword - - Matching is case-insensitive - - Partial matches are supported - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "knowledge" - keyword = "data" - - # Mock query results filtered by keyword - mock_results = [ - ("tag-1", "knowledge", "Database", 2), - ("tag-2", "knowledge", "Data Science", 4), - ] - - # Configure mock database session and query chain - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.group_by.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = mock_results - - # Act - # Execute the method with keyword filter - results = TagService.get_tags(tag_type=tag_type, current_tenant_id=tenant_id, keyword=keyword) - - # Assert - # Verify filtered results - assert len(results) == 2, "Should return 2 matching tags" - - # Verify keyword filter was applied - # The where() method should be called at least twice: - # 1. Initial WHERE clause for type and tenant - # 2. Additional WHERE clause for keyword filtering - assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause" - - @patch("services.tag_service.db.session") - def test_get_target_ids_by_tag_ids(self, mock_db_session, factory): - """ - Test retrieving target IDs by tag IDs. - - This test verifies that the get_target_ids_by_tag_ids method correctly - retrieves all target IDs (dataset/app IDs) that are bound to the - specified tags. This is useful for filtering datasets or apps by tags. 
- - The method should: - - First validate and filter tags by type and tenant - - Then find all bindings for those tags - - Return the target IDs from those bindings - - Expected behavior: - - Returns a list of target IDs (strings) - - Only includes targets bound to valid tags - - Respects tenant and type filtering - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - tag_ids = ["tag-1", "tag-2"] - - # Create mock tag objects - tags = [ - factory.create_tag_mock(tag_id="tag-1", tenant_id=tenant_id, tag_type=tag_type), - factory.create_tag_mock(tag_id="tag-2", tenant_id=tenant_id, tag_type=tag_type), - ] - - # Mock target IDs that are bound to these tags - target_ids = ["app-1", "app-2", "app-3"] - - # Mock tag query (first scalars call) - mock_scalars_tags = MagicMock() - mock_scalars_tags.all.return_value = tags - - # Mock binding query (second scalars call) - mock_scalars_bindings = MagicMock() - mock_scalars_bindings.all.return_value = target_ids - - # Configure side_effect to return different mocks for each scalars() call - mock_db_session.scalars.side_effect = [mock_scalars_tags, mock_scalars_bindings] - - # Act - # Execute the method under test - results = TagService.get_target_ids_by_tag_ids(tag_type=tag_type, current_tenant_id=tenant_id, tag_ids=tag_ids) - - # Assert - # Verify results match expected target IDs - assert results == target_ids, "Should return all target IDs bound to tags" - - # Verify both queries were executed - assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query" - - @patch("services.tag_service.db.session") - def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory): - """ - Test that empty tag_ids returns empty list. - - This test verifies the edge case handling when an empty list of - tag IDs is provided. The method should return early without - executing any database queries. - - Expected behavior: - - Returns empty list immediately - - Does not execute any database queries - - Handles empty input gracefully - """ - # Arrange - # Set up test parameters with empty tag IDs - tenant_id = "tenant-123" - tag_type = "app" - - # Act - # Execute the method with empty tag IDs list - results = TagService.get_target_ids_by_tag_ids(tag_type=tag_type, current_tenant_id=tenant_id, tag_ids=[]) - - # Assert - # Verify empty result and no database queries - assert results == [], "Should return empty list for empty input" - mock_db_session.scalars.assert_not_called(), "Should not query database for empty input" - - @patch("services.tag_service.db.session") - def test_get_tag_by_tag_name(self, mock_db_session, factory): - """ - Test retrieving tags by name. - - This test verifies that the get_tag_by_tag_name method correctly - finds tags by their exact name. This is used for duplicate name - checking and tag lookup operations. 
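
The two `scalars()` calls configured via `side_effect` above correspond to a two-step lookup, sketched below under the same assumed model names; the empty-input guard is why no query runs for an empty list:

    from sqlalchemy import select

    def get_target_ids_sketch(session, Tag, TagBinding, tag_type, tenant_id, tag_ids):
        if not tag_ids:
            return []  # early return: no database round-trip at all
        # Step 1: validate the requested tags against type and tenant.
        tags = session.scalars(
            select(Tag).where(
                Tag.type == tag_type,
                Tag.tenant_id == tenant_id,
                Tag.id.in_(tag_ids),
            )
        ).all()
        if not tags:
            return []
        # Step 2: collect the target ids bound to the surviving tags.
        return session.scalars(
            select(TagBinding.target_id).where(TagBinding.tag_id.in_([t.id for t in tags]))
        ).all()
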
- - The method should: - - Perform exact name matching (case-sensitive) - - Filter by type and tenant - - Return a list of matching tags (usually 0 or 1) - - Expected behavior: - - Returns list of tags with matching name - - Respects type and tenant filtering - - Returns empty list if no matches found - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - tag_name = "Production" - - # Create mock tag with matching name - tags = [factory.create_tag_mock(name=tag_name, tag_type=tag_type, tenant_id=tenant_id)] - - # Configure mock database session - mock_scalars = MagicMock() - mock_scalars.all.return_value = tags - mock_db_session.scalars.return_value = mock_scalars - - # Act - # Execute the method under test - results = TagService.get_tag_by_tag_name(tag_type=tag_type, current_tenant_id=tenant_id, tag_name=tag_name) - - # Assert - # Verify tag was found - assert len(results) == 1, "Should find exactly one tag" - assert results[0].name == tag_name, "Tag name should match" - - @patch("services.tag_service.db.session") - def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory): - """ - Test that missing tag_type or tag_name returns empty list. - - This test verifies the input validation for the get_tag_by_tag_name - method. When either tag_type or tag_name is empty or missing, - the method should return early without querying the database. - - Expected behavior: - - Returns empty list for empty tag_type - - Returns empty list for empty tag_name - - Does not execute database queries for invalid input - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - - # Act & Assert - # Test with empty tag_type - assert TagService.get_tag_by_tag_name("", tenant_id, "name") == [], "Should return empty for empty type" - - # Test with empty tag_name - assert TagService.get_tag_by_tag_name("app", tenant_id, "") == [], "Should return empty for empty name" - - # Verify no database queries were executed - mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input" - - @patch("services.tag_service.db.session") - def test_get_tags_by_target_id(self, mock_db_session, factory): - """ - Test retrieving tags associated with a specific target. - - This test verifies that the get_tags_by_target_id method correctly - retrieves all tags that are bound to a specific target (dataset or app). - This is useful for displaying tags associated with a resource. 
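
For reference, the exact-match lookup plus its input validation likely reduces to something like this sketch (names assumed from the tests):

    from sqlalchemy import select

    def get_tag_by_tag_name_sketch(session, Tag, tag_type, tenant_id, tag_name):
        if not tag_type or not tag_name:
            return []  # validated up front: no query for missing params
        return session.scalars(
            select(Tag).where(
                Tag.type == tag_type,
                Tag.tenant_id == tenant_id,
                Tag.name == tag_name,  # exact, case-sensitive equality
            )
        ).all()
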
- - The method should: - - Join Tag and TagBinding tables - - Filter by target_id, tenant, and type - - Return all tags bound to the target - - Expected behavior: - - Returns list of Tag objects bound to the target - - Respects tenant and type filtering - - Returns empty list if no tags are bound - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - target_id = "app-123" - - # Create mock tags that are bound to the target - tags = [ - factory.create_tag_mock(tag_id="tag-1", name="Frontend"), - factory.create_tag_mock(tag_id="tag-2", name="Production"), - ] - - # Configure mock database session and query chain - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.join.return_value = mock_query # JOIN with TagBinding - mock_query.where.return_value = mock_query # WHERE clause for filtering - mock_query.all.return_value = tags # Final result - - # Act - # Execute the method under test - results = TagService.get_tags_by_target_id(tag_type=tag_type, current_tenant_id=tenant_id, target_id=target_id) - - # Assert - # Verify tags were retrieved - assert len(results) == 2, "Should return 2 tags bound to target" - - # Verify tag names - assert results[0].name == "Frontend", "First tag name should match" - assert results[1].name == "Production", "Second tag name should match" - - -# ============================================================================ -# TAG CRUD OPERATIONS TESTS -# ============================================================================ - - -class TestTagServiceCRUD: - """ - Test tag CRUD operations. - - This test class covers all Create, Read, Update, and Delete operations - for tags. These operations modify the database state and require proper - transaction handling and validation. - - Methods tested: - - save_tags: Create new tags - - update_tags: Update existing tag names - - delete_tag: Delete tags and cascade delete bindings - - get_tag_binding_count: Get count of bindings for a tag - """ - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session") - @patch("services.tag_service.uuid.uuid4", autospec=True) - def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): - """ - Test creating a new tag. - - This test verifies that the save_tags method correctly creates a new - tag in the database with all required attributes. The method should - validate uniqueness, generate a UUID, and persist the tag. 
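
The join-based lookup mocked just above (query -> join -> where -> all) plausibly looks like the following; an inner join suffices because only tags that actually have a binding to the target should come back:

    def get_tags_by_target_id_sketch(session, Tag, TagBinding, tag_type, tenant_id, target_id):
        return (
            session.query(Tag)
            .join(TagBinding, TagBinding.tag_id == Tag.id)  # inner join: bound tags only
            .where(
                TagBinding.target_id == target_id,
                Tag.tenant_id == tenant_id,
                Tag.type == tag_type,
            )
            .all()
        )
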
- - The method should: - - Check for duplicate tag names (via get_tag_by_tag_name) - - Generate a unique UUID for the tag ID - - Set user and tenant information from current_user - - Persist the tag to the database - - Commit the transaction - - Expected behavior: - - Creates tag with correct attributes - - Assigns UUID to tag ID - - Sets created_by from current_user - - Sets tenant_id from current_user - - Commits to database - """ - # Arrange - # Configure mock current_user - mock_current_user.id = "user-123" - mock_current_user.current_tenant_id = "tenant-123" - - # Mock UUID generation - mock_uuid.return_value = "new-tag-id" - - # Mock no existing tag (duplicate check passes) - mock_get_tag_by_name.return_value = [] - - # Prepare tag creation arguments - args = {"name": "New Tag", "type": "app"} - - # Act - # Execute the method under test - result = TagService.save_tags(args) - - # Assert - # Verify tag was added to database session - mock_db_session.add.assert_called_once(), "Should add tag to session" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - # Verify tag attributes - added_tag = mock_db_session.add.call_args[0][0] - assert added_tag.name == "New Tag", "Tag name should match" - assert added_tag.type == TagType.APP, "Tag type should match" - assert added_tag.created_by == "user-123", "Created by should match current user" - assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - def test_save_tags_raises_error_for_duplicate_name(self, mock_get_tag_by_name, mock_current_user, factory): - """ - Test that creating a tag with duplicate name raises ValueError. - - This test verifies that the save_tags method correctly prevents - duplicate tag names within the same tenant and type. Tag names - must be unique per tenant and type combination. - - Expected behavior: - - Raises ValueError when duplicate name is detected - - Error message indicates "Tag name already exists" - - Does not create the tag - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Mock existing tag with same name (duplicate detected) - existing_tag = factory.create_tag_mock(name="Existing Tag") - mock_get_tag_by_name.return_value = [existing_tag] - - # Prepare tag creation arguments with duplicate name - args = {"name": "Existing Tag", "type": "app"} - - # Act & Assert - # Verify ValueError is raised for duplicate name - with pytest.raises(ValueError, match="Tag name already exists"): - TagService.save_tags(args) - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session") - def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): - """ - Test updating a tag name. - - This test verifies that the update_tags method correctly updates - an existing tag's name while preserving other attributes. The method - should validate uniqueness of the new name and ensure the tag exists. 
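
Pieced together from the patched collaborators (uuid.uuid4, get_tag_by_tag_name, current_user, db.session), save_tags plausibly behaves like the sketch below; the Tag constructor fields are inferred from the assertions, not copied from the model:

    import uuid

    def save_tags_sketch(session, Tag, current_user, get_tag_by_tag_name, args):
        # Duplicate check is scoped per tenant and type via the lookup helper.
        if get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
            raise ValueError("Tag name already exists")
        tag = Tag(
            id=str(uuid.uuid4()),
            name=args["name"],
            type=args["type"],
            created_by=current_user.id,
            tenant_id=current_user.current_tenant_id,
        )
        session.add(tag)
        session.commit()
        return tag
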
- - The method should: - - Check for duplicate tag names (excluding the current tag) - - Find the tag by ID - - Update the tag name - - Commit the transaction - - Expected behavior: - - Updates tag name successfully - - Preserves other tag attributes - - Commits to database - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Mock no duplicate name (update check passes) - mock_get_tag_by_name.return_value = [] - - # Create mock tag to be updated - tag = factory.create_tag_mock(tag_id="tag-123", name="Old Name") - - # Configure mock database session to return the tag - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = tag - - # Prepare update arguments - args = {"name": "New Name", "type": "app"} - - # Act - # Execute the method under test - result = TagService.update_tags(args, tag_id="tag-123") - - # Assert - # Verify tag name was updated - assert tag.name == "New Name", "Tag name should be updated" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session") - def test_update_tags_raises_error_for_duplicate_name( - self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory - ): - """ - Test that updating to a duplicate name raises ValueError. - - This test verifies that the update_tags method correctly prevents - updating a tag to a name that already exists for another tag - within the same tenant and type. - - Expected behavior: - - Raises ValueError when duplicate name is detected - - Error message indicates "Tag name already exists" - - Does not update the tag - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Mock existing tag with the duplicate name - existing_tag = factory.create_tag_mock(name="Duplicate Name") - mock_get_tag_by_name.return_value = [existing_tag] - - # Prepare update arguments with duplicate name - args = {"name": "Duplicate Name", "type": "app"} - - # Act & Assert - # Verify ValueError is raised for duplicate name - with pytest.raises(ValueError, match="Tag name already exists"): - TagService.update_tags(args, tag_id="tag-123") - - @patch("services.tag_service.db.session") - def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory): - """ - Test that updating a non-existent tag raises NotFound. - - This test verifies that the update_tags method correctly handles - the case when attempting to update a tag that does not exist. - This prevents silent failures and provides clear error feedback. 
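
update_tags, as exercised above and in the NotFound test that follows, presumably combines the same duplicate-name check with an existence check; a sketch (the NotFound import path is an assumption):

    from werkzeug.exceptions import NotFound  # assumed exception class

    def update_tags_sketch(session, Tag, current_user, get_tag_by_tag_name, args, tag_id):
        if get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
            raise ValueError("Tag name already exists")
        tag = session.query(Tag).where(Tag.id == tag_id).first()
        if tag is None:
            raise NotFound("Tag not found")
        tag.name = args["name"]  # only the name changes; other attributes are preserved
        session.commit()
        return tag
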
- - Expected behavior: - - Raises NotFound exception - - Error message indicates "Tag not found" - - Does not attempt to update or commit - """ - # Arrange - # Configure mock database session to return None (tag not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock duplicate check and current_user - with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[], autospec=True): - with patch("services.tag_service.current_user", autospec=True) as mock_user: - mock_user.current_tenant_id = "tenant-123" - args = {"name": "New Name", "type": "app"} - - # Act & Assert - # Verify NotFound is raised for non-existent tag - with pytest.raises(NotFound, match="Tag not found"): - TagService.update_tags(args, tag_id="nonexistent") - - @patch("services.tag_service.db.session") - def test_get_tag_binding_count(self, mock_db_session, factory): - """ - Test getting the count of bindings for a tag. - - This test verifies that the get_tag_binding_count method correctly - counts how many resources (datasets/apps) are bound to a specific tag. - This is useful for displaying tag usage statistics. - - The method should: - - Query TagBinding table filtered by tag_id - - Return the count of matching bindings - - Expected behavior: - - Returns integer count of bindings - - Returns 0 for tags with no bindings - """ - # Arrange - # Set up test parameters - tag_id = "tag-123" - expected_count = 5 - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.count.return_value = expected_count - - # Act - # Execute the method under test - result = TagService.get_tag_binding_count(tag_id) - - # Assert - # Verify count matches expectation - assert result == expected_count, "Binding count should match" - - @patch("services.tag_service.db.session") - def test_delete_tag(self, mock_db_session, factory): - """ - Test deleting a tag and its bindings. - - This test verifies that the delete_tag method correctly deletes - a tag along with all its associated bindings (cascade delete). - This ensures data integrity and prevents orphaned bindings. 
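
The binding count mocked above (query -> where -> count) is a one-liner in spirit:

    def get_tag_binding_count_sketch(session, TagBinding, tag_id):
        # COUNT over bindings for one tag; returns 0 when the tag is unused.
        return session.query(TagBinding).where(TagBinding.tag_id == tag_id).count()
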
- - The method should: - - Find the tag by ID - - Delete the tag - - Find all bindings for the tag - - Delete all bindings (cascade delete) - - Commit the transaction - - Expected behavior: - - Deletes tag from database - - Deletes all associated bindings - - Commits transaction - """ - # Arrange - # Set up test parameters - tag_id = "tag-123" - - # Create mock tag to be deleted - tag = factory.create_tag_mock(tag_id=tag_id) - - # Create mock bindings that will be cascade deleted - bindings = [factory.create_tag_binding_mock(binding_id=f"binding-{i}", tag_id=tag_id) for i in range(3)] - - # Configure mock database session for tag query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = tag - - # Configure mock database session for bindings query - mock_scalars = MagicMock() - mock_scalars.all.return_value = bindings - mock_db_session.scalars.return_value = mock_scalars - - # Act - # Execute the method under test - TagService.delete_tag(tag_id) - - # Assert - # Verify tag and bindings were deleted - mock_db_session.delete.assert_called(), "Should call delete method" - - # Verify delete was called 4 times (1 tag + 3 bindings) - assert mock_db_session.delete.call_count == 4, "Should delete tag and all bindings" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.db.session") - def test_delete_tag_raises_not_found(self, mock_db_session, factory): - """ - Test that deleting a non-existent tag raises NotFound. - - This test verifies that the delete_tag method correctly handles - the case when attempting to delete a tag that does not exist. - This prevents silent failures and provides clear error feedback. - - Expected behavior: - - Raises NotFound exception - - Error message indicates "Tag not found" - - Does not attempt to delete or commit - """ - # Arrange - # Configure mock database session to return None (tag not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act & Assert - # Verify NotFound is raised for non-existent tag - with pytest.raises(NotFound, match="Tag not found"): - TagService.delete_tag("nonexistent") - - -# ============================================================================ -# TAG BINDING OPERATIONS TESTS -# ============================================================================ - - -class TestTagServiceBindings: - """ - Test tag binding operations. - - This test class covers all operations related to binding tags to - resources (datasets and apps). Tag bindings create the many-to-many - relationship between tags and resources. - - Methods tested: - - save_tag_binding: Create bindings between tags and targets - - delete_tag_binding: Remove bindings between tags and targets - - check_target_exists: Validate target (dataset/app) existence - """ - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session") - def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory): - """ - Test creating tag bindings. - - This test verifies that the save_tag_binding method correctly - creates bindings between tags and a target resource (dataset or app). 
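
The cascade behaviour asserted above (four delete() calls for one tag with three bindings) suggests a per-row delete loop rather than a single bulk DELETE, roughly:

    from sqlalchemy import select
    from werkzeug.exceptions import NotFound  # assumed exception class

    def delete_tag_sketch(session, Tag, TagBinding, tag_id):
        tag = session.query(Tag).where(Tag.id == tag_id).first()
        if tag is None:
            raise NotFound("Tag not found")
        session.delete(tag)
        bindings = session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
        for binding in bindings:
            session.delete(binding)  # 1 tag + N bindings => N + 1 delete() calls
        session.commit()
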
- The method supports batch binding of multiple tags to a single target. - - The method should: - - Validate target exists (via check_target_exists) - - Check for existing bindings to avoid duplicates - - Create new bindings for tags that aren't already bound - - Commit the transaction - - Expected behavior: - - Validates target exists - - Creates bindings for each tag in tag_ids - - Skips tags that are already bound (idempotent) - - Commits transaction - """ - # Arrange - # Configure mock current_user - mock_current_user.id = "user-123" - mock_current_user.current_tenant_id = "tenant-123" - - # Configure mock database session (no existing bindings) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # No existing bindings - - # Prepare binding arguments (batch binding) - args = {"type": "app", "target_id": "app-123", "tag_ids": ["tag-1", "tag-2"]} - - # Act - # Execute the method under test - TagService.save_tag_binding(args) - - # Assert - # Verify target existence was checked - mock_check_target.assert_called_once_with("app", "app-123"), "Should validate target exists" - - # Verify bindings were created (2 bindings for 2 tags) - assert mock_db_session.add.call_count == 2, "Should create 2 bindings" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session") - def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory): - """ - Test that saving duplicate bindings is idempotent. - - This test verifies that the save_tag_binding method correctly handles - the case when attempting to create a binding that already exists. - The method should skip existing bindings and not create duplicates, - making the operation idempotent. - - Expected behavior: - - Checks for existing bindings - - Skips tags that are already bound - - Does not create duplicate bindings - - Still commits transaction - """ - # Arrange - # Configure mock current_user - mock_current_user.id = "user-123" - mock_current_user.current_tenant_id = "tenant-123" - - # Mock existing binding (duplicate detected) - existing_binding = factory.create_tag_binding_mock() - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = existing_binding # Binding already exists - - # Prepare binding arguments - args = {"type": "app", "target_id": "app-123", "tag_ids": ["tag-1"]} - - # Act - # Execute the method under test - TagService.save_tag_binding(args) - - # Assert - # Verify no new binding was added (idempotent) - mock_db_session.add.assert_not_called(), "Should not create duplicate binding" - - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session") - def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory): - """ - Test deleting a tag binding. - - This test verifies that the delete_tag_binding method correctly - removes a binding between a tag and a target resource. This - operation should be safe even if the binding doesn't exist. 
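
Both binding tests above pin down the idempotency contract: one existence probe per tag id, and add() only when the probe misses. A sketch under the same assumed names:

    def save_tag_binding_sketch(session, TagBinding, current_user, check_target_exists, args):
        check_target_exists(args["type"], args["target_id"])
        for tag_id in args["tag_ids"]:
            already_bound = (
                session.query(TagBinding)
                .where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
                .first()
            )
            if already_bound:
                continue  # skip duplicates: re-running the call is a no-op
            session.add(
                TagBinding(
                    tag_id=tag_id,
                    target_id=args["target_id"],
                    tenant_id=current_user.current_tenant_id,
                    created_by=current_user.id,
                )
            )
        session.commit()
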
- - The method should: - - Validate target exists (via check_target_exists) - - Find the binding by tag_id and target_id - - Delete the binding if it exists - - Commit the transaction - - Expected behavior: - - Validates target exists - - Deletes the binding - - Commits transaction - """ - # Arrange - # Create mock binding to be deleted - binding = factory.create_tag_binding_mock() - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = binding - - # Prepare delete arguments - args = {"type": "app", "target_id": "app-123", "tag_id": "tag-1"} - - # Act - # Execute the method under test - TagService.delete_tag_binding(args) - - # Assert - # Verify target existence was checked - mock_check_target.assert_called_once_with("app", "app-123"), "Should validate target exists" - - # Verify binding was deleted - mock_db_session.delete.assert_called_once_with(binding), "Should delete the binding" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session") - def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory): - """ - Test that deleting a non-existent binding is a no-op. - - This test verifies that the delete_tag_binding method correctly - handles the case when attempting to delete a binding that doesn't - exist. The method should not raise an error and should not commit - if there's nothing to delete. - - Expected behavior: - - Validates target exists - - Does not raise error for non-existent binding - - Does not call delete or commit if binding doesn't exist - """ - # Arrange - # Configure mock database session (binding not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # Binding doesn't exist - - # Prepare delete arguments - args = {"type": "app", "target_id": "app-123", "tag_id": "tag-1"} - - # Act - # Execute the method under test - TagService.delete_tag_binding(args) - - # Assert - # Verify no delete operation was attempted - mock_db_session.delete.assert_not_called(), "Should not delete if binding doesn't exist" - - # Verify no commit was made (nothing changed) - mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session") - def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory): - """ - Test validating that a dataset target exists. - - This test verifies that the check_target_exists method correctly - validates the existence of a dataset (knowledge base) when the - target type is "knowledge". This validation ensures bindings - are only created for valid resources. 
- - The method should: - - Query Dataset table filtered by tenant and ID - - Raise NotFound if dataset doesn't exist - - Return normally if dataset exists - - Expected behavior: - - No exception raised when dataset exists - - Database query is executed - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Create mock dataset - dataset = factory.create_dataset_mock() - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = dataset # Dataset exists - - # Act - # Execute the method under test - TagService.check_target_exists("knowledge", "dataset-123") - - # Assert - # Verify no exception was raised and query was executed - mock_db_session.query.assert_called_once(), "Should query database for dataset" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session") - def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory): - """ - Test validating that an app target exists. - - This test verifies that the check_target_exists method correctly - validates the existence of an application when the target type is - "app". This validation ensures bindings are only created for valid - resources. - - The method should: - - Query App table filtered by tenant and ID - - Raise NotFound if app doesn't exist - - Return normally if app exists - - Expected behavior: - - No exception raised when app exists - - Database query is executed - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Create mock app - app = factory.create_app_mock() - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = app # App exists - - # Act - # Execute the method under test - TagService.check_target_exists("app", "app-123") - - # Assert - # Verify no exception was raised and query was executed - mock_db_session.query.assert_called_once(), "Should query database for app" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session") - def test_check_target_exists_raises_not_found_for_missing_dataset( - self, mock_db_session, mock_current_user, factory - ): - """ - Test that missing dataset raises NotFound. - - This test verifies that the check_target_exists method correctly - raises a NotFound exception when attempting to validate a dataset - that doesn't exist. This prevents creating bindings for invalid - resources. 
- - Expected behavior: - - Raises NotFound exception - - Error message indicates "Dataset not found" - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Configure mock database session (dataset not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # Dataset doesn't exist - - # Act & Assert - # Verify NotFound is raised for non-existent dataset - with pytest.raises(NotFound, match="Dataset not found"): - TagService.check_target_exists("knowledge", "nonexistent") - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session") - def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory): - """ - Test that missing app raises NotFound. - - This test verifies that the check_target_exists method correctly - raises a NotFound exception when attempting to validate an app - that doesn't exist. This prevents creating bindings for invalid - resources. - - Expected behavior: - - Raises NotFound exception - - Error message indicates "App not found" - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Configure mock database session (app not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # App doesn't exist - - # Act & Assert - # Verify NotFound is raised for non-existent app - with pytest.raises(NotFound, match="App not found"): - TagService.check_target_exists("app", "nonexistent") - - def test_check_target_exists_raises_not_found_for_invalid_type(self, factory): - """ - Test that invalid binding type raises NotFound. - - This test verifies that the check_target_exists method correctly - raises a NotFound exception when an invalid target type is provided. - Only "knowledge" (for datasets) and "app" are valid target types. 
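
The check_target_exists tests in this group fix the dispatch precisely, error messages included; only the model and session names in this sketch are assumptions:

    from werkzeug.exceptions import NotFound  # assumed exception class

    def check_target_exists_sketch(session, Dataset, App, current_user, target_type, target_id):
        tenant_id = current_user.current_tenant_id
        if target_type == "knowledge":
            if not session.query(Dataset).where(
                Dataset.tenant_id == tenant_id, Dataset.id == target_id
            ).first():
                raise NotFound("Dataset not found")
        elif target_type == "app":
            if not session.query(App).where(
                App.tenant_id == tenant_id, App.id == target_id
            ).first():
                raise NotFound("App not found")
        else:
            raise NotFound("Invalid binding type")
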
- - Expected behavior: - - Raises NotFound exception - - Error message indicates "Invalid binding type" - """ - # Act & Assert - # Verify NotFound is raised for invalid target type - with pytest.raises(NotFound, match="Invalid binding type"): - TagService.check_target_exists("invalid_type", "target-123") From 7cc81e9a43f2711f1d014b7db346cd88c7102339 Mon Sep 17 00:00:00 2001 From: YBoy Date: Sat, 28 Mar 2026 09:50:26 +0200 Subject: [PATCH 10/14] test: migrate workspace service tests to testcontainers (#34218) --- .../services/test_workspace_service.py | 284 ++++++++- .../services/test_workspace_service.py | 576 ------------------ 2 files changed, 283 insertions(+), 577 deletions(-) delete mode 100644 api/tests/unit_tests/services/test_workspace_service.py diff --git a/api/tests/test_containers_integration_tests/services/test_workspace_service.py b/api/tests/test_containers_integration_tests/services/test_workspace_service.py index 92dec24c7d8..4e89d906f16 100644 --- a/api/tests/test_containers_integration_tests/services/test_workspace_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workspace_service.py @@ -1,4 +1,6 @@ -from unittest.mock import patch +from __future__ import annotations + +from unittest.mock import MagicMock, patch import pytest from faker import Faker @@ -534,3 +536,283 @@ class TestWorkspaceService: # Verify database state db_session_with_containers.refresh(tenant) assert tenant.id is not None + + def test_get_tenant_info_should_raise_assertion_when_join_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """TenantAccountJoin must exist; missing join should raise AssertionError.""" + fake = Faker() + account = Account(email=fake.email(), name=fake.name(), interface_language="en-US", status="active") + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant(name=fake.company(), status="normal", plan="basic") + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + # No TenantAccountJoin created + with patch("services.workspace_service.current_user", account): + with pytest.raises(AssertionError, match="TenantAccountJoin not found"): + WorkspaceService.get_tenant_info(tenant) + + def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """replace_webapp_logo should be None when custom_config_dict does not have the key.""" + import json + + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + tenant.custom_config = json.dumps({}) + db_session_with_containers.commit() + + mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True + mock_external_service_dependencies["tenant_service"].has_roles.return_value = True + + with patch("services.workspace_service.current_user", account): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["custom_config"]["replace_webapp_logo"] is None + + def test_get_tenant_info_should_use_files_url_for_logo_url( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """The logo URL should use dify_config.FILES_URL as the base.""" + import json + + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + 
tenant.custom_config = json.dumps({"replace_webapp_logo": True}) + db_session_with_containers.commit() + + custom_base = "https://cdn.mycompany.io" + mock_external_service_dependencies["dify_config"].FILES_URL = custom_base + mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True + mock_external_service_dependencies["tenant_service"].has_roles.return_value = True + + with patch("services.workspace_service.current_user", account): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base) + + def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "SELF_HOSTED" + mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = False + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + with patch("services.workspace_service.current_user", account): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert "next_credit_reset_date" not in result + assert "trial_credits" not in result + assert "trial_credits_used" not in result + + def test_get_tenant_info_cloud_credit_reset_date( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """next_credit_reset_date should be present in CLOUD edition.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=None), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["next_credit_reset_date"] == "2025-02-01" + + def test_get_tenant_info_cloud_paid_pool_not_full( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """trial_credits come from paid pool when plan is not sandbox and pool is not full.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=1000, quota_used=200) + + with ( + patch("services.workspace_service.current_user", account), + 
patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=paid_pool), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 1000 + assert result["trial_credits_used"] == 200 + + def test_get_tenant_info_cloud_paid_pool_unlimited( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """quota_limit == -1 means unlimited; service should use paid pool.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=-1, quota_used=999) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, None]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == -1 + assert result["trial_credits_used"] == 999 + + def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_full( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When paid pool is exhausted, switch to trial pool.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=500, quota_used=500) + trial_pool = MagicMock(quota_limit=100, quota_used=10) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 100 + assert result["trial_credits_used"] == 10 + + def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When paid_pool is None, fall back to trial pool.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + trial_pool = MagicMock(quota_limit=50, quota_used=5) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", 
side_effect=[None, trial_pool]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 50 + assert result["trial_credits_used"] == 5 + + def test_get_tenant_info_cloud_sandbox_uses_trial_pool( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When plan is SANDBOX, skip paid pool and use trial pool.""" + from enums.cloud_plan import CloudPlan + + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = CloudPlan.SANDBOX + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=1000, quota_used=0) + trial_pool = MagicMock(quota_limit=200, quota_used=20) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 200 + assert result["trial_credits_used"] == 20 + + def test_get_tenant_info_cloud_both_pools_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When both paid and trial pools are absent, trial_credits should not be set.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[None, None]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert "trial_credits" not in result + assert "trial_credits_used" not in result diff --git a/api/tests/unit_tests/services/test_workspace_service.py b/api/tests/unit_tests/services/test_workspace_service.py deleted file mode 100644 index 9bfd7eb2c5b..00000000000 --- a/api/tests/unit_tests/services/test_workspace_service.py +++ /dev/null @@ -1,576 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from models.account import Tenant - -# --------------------------------------------------------------------------- -# Constants used throughout the tests -# --------------------------------------------------------------------------- - -TENANT_ID = "tenant-abc" -ACCOUNT_ID = "account-xyz" -FILES_BASE_URL = "https://files.example.com" - -DB_PATH = "services.workspace_service.db" -FEATURE_SERVICE_PATH = "services.workspace_service.FeatureService.get_features" -TENANT_SERVICE_PATH = "services.workspace_service.TenantService.has_roles" 
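
Taken together, the pool tests above pin down a selection rule for CLOUD credits: the paid pool is fetched first and used only when the plan is not sandbox and the pool is not exhausted (quota_limit == -1 meaning unlimited); otherwise a second fetch falls back to the trial pool, and when both are absent the trial_credits fields are omitted. A sketch of that rule follows — the get_pool signature here is hypothetical; only the call order (paid first, then trial) is implied by the side_effect lists:

    def resolve_credit_pool(get_pool, tenant_id, plan, sandbox_plan="sandbox"):
        # First fetch: the paid pool (the pool_type keyword is hypothetical).
        paid_pool = get_pool(tenant_id, pool_type="paid")
        if (
            plan != sandbox_plan
            and paid_pool is not None
            and (paid_pool.quota_limit == -1 or paid_pool.quota_used < paid_pool.quota_limit)
        ):
            return paid_pool  # -1 means unlimited; otherwise must not be exhausted
        # Second fetch: fall back to the trial pool.
        return get_pool(tenant_id, pool_type="trial")  # None => omit the credit fields
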
-DIFY_CONFIG_PATH = "services.workspace_service.dify_config" -CURRENT_USER_PATH = "services.workspace_service.current_user" -CREDIT_POOL_SERVICE_PATH = "services.credit_pool_service.CreditPoolService.get_pool" - - -# --------------------------------------------------------------------------- -# Helpers / factories -# --------------------------------------------------------------------------- - - -def _make_tenant( - tenant_id: str = TENANT_ID, - name: str = "My Workspace", - plan: str = "sandbox", - status: str = "active", - custom_config: dict | None = None, -) -> Tenant: - """Create a minimal Tenant-like namespace.""" - return cast( - Tenant, - SimpleNamespace( - id=tenant_id, - name=name, - plan=plan, - status=status, - created_at="2024-01-01T00:00:00Z", - custom_config_dict=custom_config or {}, - ), - ) - - -def _make_feature( - can_replace_logo: bool = False, - next_credit_reset_date: str | None = None, - billing_plan: str = "sandbox", -) -> MagicMock: - """Create a feature namespace matching what FeatureService.get_features returns.""" - feature = MagicMock() - feature.can_replace_logo = can_replace_logo - feature.next_credit_reset_date = next_credit_reset_date - feature.billing.subscription.plan = billing_plan - return feature - - -def _make_pool(quota_limit: int, quota_used: int) -> MagicMock: - pool = MagicMock() - pool.quota_limit = quota_limit - pool.quota_used = quota_used - return pool - - -def _make_tenant_account_join(role: str = "normal") -> SimpleNamespace: - return SimpleNamespace(role=role) - - -def _tenant_info(result: object) -> dict[str, Any] | None: - return cast(dict[str, Any] | None, result) - - -# --------------------------------------------------------------------------- -# Shared fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def mock_current_user() -> SimpleNamespace: - """Return a lightweight current_user stand-in.""" - return SimpleNamespace(id=ACCOUNT_ID) - - -@pytest.fixture -def basic_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict: - """ - Patch the common external boundaries used by WorkspaceService.get_tenant_info. - - Returns a dict of named mocks so individual tests can customise them. - """ - mocker.patch(CURRENT_USER_PATH, mock_current_user) - - mock_db_session = mocker.patch(f"{DB_PATH}.session") - mock_query_chain = MagicMock() - mock_db_session.query.return_value = mock_query_chain - mock_query_chain.where.return_value = mock_query_chain - mock_query_chain.first.return_value = _make_tenant_account_join(role="owner") - - mock_feature = mocker.patch(FEATURE_SERVICE_PATH, return_value=_make_feature()) - mock_has_roles = mocker.patch(TENANT_SERVICE_PATH, return_value=False) - mock_config = mocker.patch(DIFY_CONFIG_PATH) - mock_config.EDITION = "SELF_HOSTED" - mock_config.FILES_URL = FILES_BASE_URL - - return { - "db_session": mock_db_session, - "query_chain": mock_query_chain, - "get_features": mock_feature, - "has_roles": mock_has_roles, - "config": mock_config, - } - - -# --------------------------------------------------------------------------- -# 1. 
None Tenant Handling -# --------------------------------------------------------------------------- - - -def test_get_tenant_info_should_return_none_when_tenant_is_none() -> None: - """get_tenant_info should short-circuit and return None for a falsy tenant.""" - from services.workspace_service import WorkspaceService - - # Arrange - tenant = None - - # Act - result = WorkspaceService.get_tenant_info(cast(Tenant, tenant)) - - # Assert - assert result is None - - -def test_get_tenant_info_should_return_none_when_tenant_is_falsy() -> None: - """get_tenant_info treats any falsy value as absent (e.g. empty string, 0).""" - from services.workspace_service import WorkspaceService - - # Arrange / Act / Assert - assert WorkspaceService.get_tenant_info("") is None # type: ignore[arg-type] - - -# --------------------------------------------------------------------------- -# 2. Basic Tenant Info โ€” happy path -# --------------------------------------------------------------------------- - - -def test_get_tenant_info_should_return_base_fields( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """get_tenant_info should always return the six base scalar fields.""" - from services.workspace_service import WorkspaceService - - # Arrange - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["id"] == TENANT_ID - assert result["name"] == "My Workspace" - assert result["plan"] == "sandbox" - assert result["status"] == "active" - assert result["created_at"] == "2024-01-01T00:00:00Z" - assert result["trial_end_reason"] is None - - -def test_get_tenant_info_should_populate_role_from_tenant_account_join( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """The 'role' field should be taken from TenantAccountJoin, not the default.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["query_chain"].first.return_value = _make_tenant_account_join(role="admin") - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["role"] == "admin" - - -def test_get_tenant_info_should_raise_assertion_when_tenant_account_join_missing( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """ - The service asserts that TenantAccountJoin exists. - Missing join should raise AssertionError. - """ - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["query_chain"].first.return_value = None - tenant = _make_tenant() - - # Act + Assert - with pytest.raises(AssertionError, match="TenantAccountJoin not found"): - WorkspaceService.get_tenant_info(tenant) - - -# --------------------------------------------------------------------------- -# 3. 
Logo Customisation -# --------------------------------------------------------------------------- - - -def test_get_tenant_info_should_include_custom_config_when_logo_allowed_and_admin( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """custom_config block should appear for OWNER/ADMIN when can_replace_logo is True.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) - basic_mocks["has_roles"].return_value = True - tenant = _make_tenant( - custom_config={ - "replace_webapp_logo": True, - "remove_webapp_brand": True, - } - ) - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert "custom_config" in result - assert result["custom_config"]["remove_webapp_brand"] is True - expected_logo_url = f"{FILES_BASE_URL}/files/workspaces/{TENANT_ID}/webapp-logo" - assert result["custom_config"]["replace_webapp_logo"] == expected_logo_url - - -def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """replace_webapp_logo should be None when custom_config_dict does not have the key.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) - basic_mocks["has_roles"].return_value = True - tenant = _make_tenant(custom_config={}) # no replace_webapp_logo key - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["custom_config"]["replace_webapp_logo"] is None - - -def test_get_tenant_info_should_not_include_custom_config_when_logo_not_allowed( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """custom_config should be absent when can_replace_logo is False.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=False) - basic_mocks["has_roles"].return_value = True - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert "custom_config" not in result - - -def test_get_tenant_info_should_not_include_custom_config_when_user_not_admin( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """custom_config block is gated on OWNER or ADMIN role.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) - basic_mocks["has_roles"].return_value = False # regular member - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert "custom_config" not in result - - -def test_get_tenant_info_should_use_files_url_for_logo_url( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """The logo URL should use dify_config.FILES_URL as the base.""" - from services.workspace_service import WorkspaceService - - # Arrange - custom_base = "https://cdn.mycompany.io" - basic_mocks["config"].FILES_URL = custom_base - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) - basic_mocks["has_roles"].return_value = True - tenant = _make_tenant(custom_config={"replace_webapp_logo": True}) - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert 
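
The cloud-edition tests in the next group drive two sequential get_pool calls through MagicMock's side_effect, which hands out one list item per call; a tiny runnable reminder of that mechanic:

    from unittest.mock import MagicMock

    get_pool = MagicMock(side_effect=[None, "trial-pool"])
    assert get_pool("tenant-abc", "paid") is None           # 1st call -> 1st item
    assert get_pool("tenant-abc", "trial") == "trial-pool"  # 2nd call -> 2nd item
    # A further call would raise StopIteration: the iterator is exhausted.
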
- assert result is not None - assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base) - - -# --------------------------------------------------------------------------- -# 4. Cloud-Edition Credit Features -# --------------------------------------------------------------------------- - -CLOUD_BILLING_PLAN_NON_SANDBOX = "professional" # any plan that is not SANDBOX - - -@pytest.fixture -def cloud_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict: - """Patches for CLOUD edition tests, billing plan = professional by default.""" - mocker.patch(CURRENT_USER_PATH, mock_current_user) - - mock_db_session = mocker.patch(f"{DB_PATH}.session") - mock_query_chain = MagicMock() - mock_db_session.query.return_value = mock_query_chain - mock_query_chain.where.return_value = mock_query_chain - mock_query_chain.first.return_value = _make_tenant_account_join(role="owner") - - mock_feature = mocker.patch( - FEATURE_SERVICE_PATH, - return_value=_make_feature( - can_replace_logo=False, - next_credit_reset_date="2025-02-01", - billing_plan=CLOUD_BILLING_PLAN_NON_SANDBOX, - ), - ) - mocker.patch(TENANT_SERVICE_PATH, return_value=False) - mock_config = mocker.patch(DIFY_CONFIG_PATH) - mock_config.EDITION = "CLOUD" - mock_config.FILES_URL = FILES_BASE_URL - - return { - "db_session": mock_db_session, - "query_chain": mock_query_chain, - "get_features": mock_feature, - "config": mock_config, - } - - -def test_get_tenant_info_should_add_next_credit_reset_date_in_cloud_edition( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """next_credit_reset_date should be present in CLOUD edition.""" - from services.workspace_service import WorkspaceService - - # Arrange - mocker.patch( - CREDIT_POOL_SERVICE_PATH, - side_effect=[None, None], # both paid and trial pools absent - ) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["next_credit_reset_date"] == "2025-02-01" - - -def test_get_tenant_info_should_use_paid_pool_when_plan_is_not_sandbox_and_pool_not_full( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """trial_credits/trial_credits_used come from the paid pool when conditions are met.""" - from services.workspace_service import WorkspaceService - - # Arrange - paid_pool = _make_pool(quota_limit=1000, quota_used=200) - mocker.patch(CREDIT_POOL_SERVICE_PATH, return_value=paid_pool) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == 1000 - assert result["trial_credits_used"] == 200 - - -def test_get_tenant_info_should_use_paid_pool_when_quota_limit_is_infinite( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """quota_limit == -1 means unlimited; service should still use the paid pool.""" - from services.workspace_service import WorkspaceService - - # Arrange - paid_pool = _make_pool(quota_limit=-1, quota_used=999) - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, None]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == -1 - assert result["trial_credits_used"] == 999 - - -def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_full( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """When paid pool is exhausted (used >= limit), switch to trial 
pool.""" - from services.workspace_service import WorkspaceService - - # Arrange - paid_pool = _make_pool(quota_limit=500, quota_used=500) # exactly full - trial_pool = _make_pool(quota_limit=100, quota_used=10) - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == 100 - assert result["trial_credits_used"] == 10 - - -def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_none( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """When paid_pool is None, fall back to trial pool.""" - from services.workspace_service import WorkspaceService - - # Arrange - trial_pool = _make_pool(quota_limit=50, quota_used=5) - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, trial_pool]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == 50 - assert result["trial_credits_used"] == 5 - - -def test_get_tenant_info_should_fall_back_to_trial_pool_for_sandbox_plan( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """ - When the subscription plan IS SANDBOX, the paid pool branch is skipped - entirely and we fall back to the trial pool. - """ - from enums.cloud_plan import CloudPlan - from services.workspace_service import WorkspaceService - - # Arrange โ€” override billing plan to SANDBOX - cloud_mocks["get_features"].return_value = _make_feature( - next_credit_reset_date="2025-02-01", - billing_plan=CloudPlan.SANDBOX, - ) - paid_pool = _make_pool(quota_limit=1000, quota_used=0) - trial_pool = _make_pool(quota_limit=200, quota_used=20) - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == 200 - assert result["trial_credits_used"] == 20 - - -def test_get_tenant_info_should_omit_trial_credits_when_both_pools_are_none( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """When both paid and trial pools are absent, trial_credits should not be set.""" - from services.workspace_service import WorkspaceService - - # Arrange - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert "trial_credits" not in result - assert "trial_credits_used" not in result - - -# --------------------------------------------------------------------------- -# 5. 
From 364d7ebc406bd46cad701f67dab1de403451f8df Mon Sep 17 00:00:00 2001
From: Renzo <170978465+RenzoMXD@users.noreply.github.com>
Date: Sat, 28 Mar 2026 11:14:43 +0100
Subject: [PATCH 11/14] refactor: core/tools, agent, callback_handler, encrypter, llm_generator, plugin, inner_api (#34205)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato
---
 api/controllers/inner_api/app/dsl.py | 5 +--
 api/core/agent/base_agent_runner.py | 13 +++++---
 .../index_tool_callback_handler.py | 20 ++++++------
 api/core/helper/encrypter.py | 2 +-
 api/core/llm_generator/llm_generator.py | 5 +--
 api/core/plugin/backwards_invocation/app.py | 2 +-
 api/core/tools/tool_label_manager.py | 4 +--
 api/core/tools/tool_manager.py | 30 ++++++++---------
 .../dataset_multi_retriever_tool.py | 2 +-
 .../dataset_retriever_tool.py | 2 +-
 .../controllers/inner_api/app/test_dsl.py | 14 ++++----
 .../core/agent/test_base_agent_runner.py | 2 +-
 .../test_index_tool_callback_handler.py | 12 ++-----
 .../unit_tests/core/helper/test_encrypter.py | 32 +++++++++----------
 .../core/llm_generator/test_llm_generator.py | 32 +++++++++----------
 .../plugin/test_backwards_invocation_app.py | 12 ++-----
 .../core/tools/test_tool_label_manager.py | 4 +--
 .../core/tools/test_tool_manager.py | 20 ++++-------
 .../core/tools/utils/test_misc_utils_extra.py | 4 +--
 19 files changed, 99 insertions(+), 118 deletions(-)

diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py
index 56730cf37a7..3b673d6e1d3 100644
--- a/api/controllers/inner_api/app/dsl.py
+++ b/api/controllers/inner_api/app/dsl.py
@@ -8,6 +8,7 @@ Go admin-api caller.
from flask import request from flask_restx import Resource from pydantic import BaseModel, Field +from sqlalchemy import select from sqlalchemy.orm import Session from controllers.common.schema import register_schema_model @@ -87,7 +88,7 @@ class EnterpriseAppDSLExport(Resource): """Export an app's DSL as YAML.""" include_secret = request.args.get("include_secret", "false").lower() == "true" - app_model = db.session.query(App).filter_by(id=app_id).first() + app_model = db.session.get(App, app_id) if not app_model: return {"message": "app not found"}, 404 @@ -104,7 +105,7 @@ def _get_active_account(email: str) -> Account | None: Workspace membership is already validated by the Go admin-api caller. """ - account = db.session.query(Account).filter_by(email=email).first() + account = db.session.scalar(select(Account).where(Account.email == email).limit(1)) if account is None or account.status != AccountStatus.ACTIVE: return None return account diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index ff8f40407fa..06c746990d2 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -18,7 +18,7 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.model_runtime.entities.model_entities import ModelFeature from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from sqlalchemy import select +from sqlalchemy import func, select from core.agent.entities import AgentEntity, AgentToolEntity from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -104,11 +104,14 @@ class BaseAgentRunner(AppRunner): ) # get how many agent thoughts have been created self.agent_thought_count = ( - db.session.query(MessageAgentThought) - .where( - MessageAgentThought.message_id == self.message.id, + db.session.scalar( + select(func.count()) + .select_from(MessageAgentThought) + .where( + MessageAgentThought.message_id == self.message.id, + ) ) - .count() + or 0 ) db.session.close() diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 8de5cb16900..6a071192447 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,7 +1,7 @@ import logging from collections.abc import Sequence -from sqlalchemy import select +from sqlalchemy import select, update from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom @@ -70,23 +70,21 @@ class DatasetIndexToolCallbackHandler: ) child_chunk = db.session.scalar(child_chunk_stmt) if child_chunk: - _ = ( - db.session.query(DocumentSegment) + db.session.execute( + update(DocumentSegment) .where(DocumentSegment.id == child_chunk.segment_id) - .update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False - ) + .values(hit_count=DocumentSegment.hit_count + 1) ) else: - query = db.session.query(DocumentSegment).where( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]] if "dataset_id" in document.metadata: - query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + conditions.append(DocumentSegment.dataset_id == 
document.metadata["dataset_id"]) # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + db.session.execute( + update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1) + ) db.session.commit() diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 17345dc203b..20125ec6b30 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -19,7 +19,7 @@ def encrypt_token(tenant_id: str, token: str): from extensions.ext_database import db from models.account import Tenant - if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()): + if not (tenant := db.session.get(Tenant, tenant_id)): raise ValueError(f"Tenant with id {tenant_id} not found") assert tenant.encrypt_public_key is not None encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 3d94f1a5969..d39630ad951 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -10,6 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from sqlalchemy import select from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload @@ -410,8 +411,8 @@ class LLMGenerator: model_config: ModelConfig, ideal_output: str | None, ): - last_run: Message | None = ( - db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() + last_run: Message | None = db.session.scalar( + select(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).limit(1) ) if not last_run: return LLMGenerator.__instruction_modify_common( diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 60d08b26c95..be11d2223ca 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -227,7 +227,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): get app """ try: - app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first() + app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1)) except Exception: raise ValueError("app not found") diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 250dd91bfd2..58190d10894 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -1,4 +1,4 @@ -from sqlalchemy import select +from sqlalchemy import delete, select from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -31,7 +31,7 @@ class ToolLabelManager: raise ValueError("Unsupported tool type") # delete old labels - db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete() + db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id)) # insert new labels for label in labels: diff --git 
a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 584bae39b9d..a58d3103137 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -255,11 +255,11 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f"no default provider for {provider_id}") else: - builtin_provider = ( - db.session.query(BuiltinToolProvider) + builtin_provider = db.session.scalar( + select(BuiltinToolProvider) .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .first() + .limit(1) ) if builtin_provider is None: @@ -818,13 +818,13 @@ class ToolManager: :return: the provider controller, the credentials """ - provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + provider: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.id == provider_id, ApiToolProvider.tenant_id == tenant_id, ) - .first() + .limit(1) ) if provider is None: @@ -872,13 +872,13 @@ class ToolManager: get api provider """ provider_name = provider - provider_obj: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + provider_obj: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider, ) - .first() + .limit(1) ) if provider_obj is None: @@ -964,10 +964,10 @@ class ToolManager: @classmethod def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: - workflow_provider: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) + workflow_provider: WorkflowToolProvider | None = db.session.scalar( + select(WorkflowToolProvider) .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) - .first() + .limit(1) ) if workflow_provider is None: @@ -981,10 +981,10 @@ class ToolManager: @classmethod def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: - api_provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + api_provider: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) - .first() + .limit(1) ) if api_provider is None: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 6a77fda7ef3..e63435db988 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -110,7 +110,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): context_list: list[RetrievalSourceMetadata] = [] resource_number = 1 for segment in sorted_segments: - dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + dataset = db.session.get(Dataset, segment.dataset_id) document_stmt = select(Document).where( Document.id == segment.document_id, Document.enabled == True, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index f3d390ed59d..cbd8bdb36cf 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -205,7 +205,7 @@ class 
DatasetRetrieverTool(DatasetRetrieverBaseTool): if self.return_resource: for record in records: segment = record.segment - dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + dataset = db.session.get(Dataset, segment.dataset_id) dataset_document_stmt = select(DatasetDocument).where( DatasetDocument.id == segment.document_id, DatasetDocument.enabled == True, diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py index 5862239142e..4a5f91cc5d5 100644 --- a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py +++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py @@ -64,18 +64,18 @@ class TestGetActiveAccount: def test_returns_active_account(self, mock_db): mock_account = MagicMock() mock_account.status = "active" - mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account + mock_db.session.scalar.return_value = mock_account result = _get_active_account("user@example.com") assert result is mock_account - mock_db.session.query.return_value.filter_by.assert_called_once_with(email="user@example.com") + mock_db.session.scalar.assert_called_once() @patch("controllers.inner_api.app.dsl.db") def test_returns_none_for_inactive_account(self, mock_db): mock_account = MagicMock() mock_account.status = "banned" - mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account + mock_db.session.scalar.return_value = mock_account result = _get_active_account("banned@example.com") @@ -83,7 +83,7 @@ class TestGetActiveAccount: @patch("controllers.inner_api.app.dsl.db") def test_returns_none_for_nonexistent_email(self, mock_db): - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_db.session.scalar.return_value = None result = _get_active_account("missing@example.com") @@ -205,7 +205,7 @@ class TestEnterpriseAppDSLExport: @patch("controllers.inner_api.app.dsl.db") def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask): mock_app = MagicMock() - mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n" unwrapped = inspect.unwrap(api_instance.get) @@ -221,7 +221,7 @@ class TestEnterpriseAppDSLExport: @patch("controllers.inner_api.app.dsl.db") def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask): mock_app = MagicMock() - mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app mock_dsl_cls.export_dsl.return_value = "yaml-data" unwrapped = inspect.unwrap(api_instance.get) @@ -234,7 +234,7 @@ class TestEnterpriseAppDSLExport: @patch("controllers.inner_api.app.dsl.db") def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask): - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_db.session.get.return_value = None unwrapped = inspect.unwrap(api_instance.get) with app.test_request_context("?include_secret=false"): diff --git a/api/tests/unit_tests/core/agent/test_base_agent_runner.py b/api/tests/unit_tests/core/agent/test_base_agent_runner.py index 683cc0e36f7..db4b293b163 100644 --- a/api/tests/unit_tests/core/agent/test_base_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_base_agent_runner.py @@ -621,7 +621,7 @@ class 
TestConvertDatasetRetrieverTool: class TestBaseAgentRunnerInit: def test_init_sets_stream_tool_call_and_files(self, mocker): session = mocker.MagicMock() - session.query.return_value.where.return_value.count.return_value = 2 + session.scalar.return_value = 2 mocker.patch.object(module.db, "session", session) mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[]) diff --git a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py index b37c4c57a1f..8e5670e9be3 100644 --- a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py +++ b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py @@ -114,13 +114,9 @@ class TestOnToolEnd: document = mocker.Mock() document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} - mock_query = mocker.Mock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - handler.on_tool_end([document]) - mock_query.update.assert_called_once() + mock_db.session.execute.assert_called_once() mock_db.session.commit.assert_called_once() def test_on_tool_end_non_parent_child_index(self, handler, mocker): @@ -138,13 +134,9 @@ class TestOnToolEnd: "dataset_id": "dataset-1", } - mock_query = mocker.Mock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - handler.on_tool_end([document]) - mock_query.update.assert_called_once() + mock_db.session.execute.assert_called_once() mock_db.session.commit.assert_called_once() def test_on_tool_end_empty_documents(self, handler): diff --git a/api/tests/unit_tests/core/helper/test_encrypter.py b/api/tests/unit_tests/core/helper/test_encrypter.py index 58900097428..f3ef7fccd0f 100644 --- a/api/tests/unit_tests/core/helper/test_encrypter.py +++ b/api/tests/unit_tests/core/helper/test_encrypter.py @@ -38,13 +38,13 @@ class TestObfuscatedToken: class TestEncryptToken: - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_successful_encryption(self, mock_encrypt, mock_query): """Test successful token encryption""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_data" result = encrypt_token("tenant-123", "test_token") @@ -52,10 +52,10 @@ class TestEncryptToken: assert result == base64.b64encode(b"encrypted_data").decode() mock_encrypt.assert_called_with("test_token", "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") def test_tenant_not_found(self, mock_query): """Test error when tenant doesn't exist""" - mock_query.return_value.where.return_value.first.return_value = None + mock_query.return_value = None with pytest.raises(ValueError) as exc_info: encrypt_token("invalid-tenant", "test_token") @@ -119,7 +119,7 @@ class TestGetDecryptDecoding: class TestEncryptDecryptIntegration: - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") @patch("libs.rsa.decrypt") def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query): @@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration: # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = 
"mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # Setup mock encryption/decryption original_token = "test_token_123" @@ -146,14 +146,14 @@ class TestEncryptDecryptIntegration: class TestSecurity: """Critical security tests for encryption system""" - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_cross_tenant_isolation(self, mock_encrypt, mock_query): """Ensure tokens encrypted for one tenant cannot be used by another""" # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "tenant1_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_for_tenant1" # Encrypt token for tenant1 @@ -181,12 +181,12 @@ class TestSecurity: with pytest.raises(Exception, match="Decryption error"): decrypt_token("tenant-123", tampered) - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_encryption_randomness(self, mock_encrypt, mock_query): """Ensure same plaintext produces different ciphertext""" mock_tenant = MagicMock(encrypt_public_key="key") - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # Different outputs for same input mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"] @@ -205,13 +205,13 @@ class TestEdgeCases: # Test empty string (which is a valid str type) assert obfuscated_token("") == "" - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query): """Test encryption of empty token""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_empty" result = encrypt_token("tenant-123", "") @@ -219,13 +219,13 @@ class TestEdgeCases: assert result == base64.b64encode(b"encrypted_empty").decode() mock_encrypt.assert_called_with("", "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query): """Test tokens containing special/unicode characters""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_special" # Test various special characters @@ -242,13 +242,13 @@ class TestEdgeCases: assert result == base64.b64encode(b"encrypted_special").decode() mock_encrypt.assert_called_with(token, "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query): """Test behavior when token exceeds RSA encryption limits""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # RSA 2048-bit can only encrypt ~245 bytes # The actual limit 
depends on padding scheme diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 2c0a4411254..62e714deb61 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -314,8 +314,8 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None # Mock __instruction_modify_common call via invoke_llm mock_response = MagicMock() @@ -328,12 +328,12 @@ class TestLLMGenerator: assert result == {"modified": "prompt"} def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: last_run = MagicMock() last_run.query = "q" last_run.answer = "a" last_run.error = "e" - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run + mock_scalar.return_value = last_run mock_response = MagicMock() mock_response.message.get_text_content.return_value = '{"modified": "prompt"}' @@ -483,8 +483,8 @@ class TestLLMGenerator: def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity): # Testing placeholders replacement via instruction_modify_legacy for convenience - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = '{"ok": true}' @@ -504,8 +504,8 @@ class TestLLMGenerator: assert "current_val" in user_msg_dict["instruction"] def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "No braces here" mock_model_instance.invoke_llm.return_value = mock_response @@ -516,8 +516,8 @@ class TestLLMGenerator: assert "Could not find a valid JSON object" in result["error"] def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "[1, 2, 3]" mock_model_instance.invoke_llm.return_value = mock_response @@ -556,8 +556,8 @@ class TestLLMGenerator: ) def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity): - with 
patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed") result = LLMGenerator.instruction_modify_legacy( @@ -566,8 +566,8 @@ class TestLLMGenerator: assert "Failed to generate code" in result["error"] def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_model_instance.invoke_llm.side_effect = Exception("Random error") result = LLMGenerator.instruction_modify_legacy( @@ -576,8 +576,8 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "No JSON here" diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py index c2778f082b8..3feb4159ade 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py @@ -332,27 +332,21 @@ class TestPluginAppBackwardsInvocation: PluginAppBackwardsInvocation._get_user("uid") def test_get_app_returns_app(self, mocker): - query_chain = MagicMock() - query_chain.where.return_value = query_chain app_obj = MagicMock(id="app") - query_chain.first.return_value = app_obj - db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj))) mocker.patch("core.plugin.backwards_invocation.app.db", db) assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj def test_get_app_raises_when_missing(self, mocker): - query_chain = MagicMock() - query_chain.where.return_value = query_chain - query_chain.first.return_value = None - db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None))) mocker.patch("core.plugin.backwards_invocation.app.db", db) with pytest.raises(ValueError, match="app not found"): PluginAppBackwardsInvocation._get_app("app", "tenant") def test_get_app_raises_when_query_fails(self, mocker): - db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down")))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down")))) mocker.patch("core.plugin.backwards_invocation.app.db", db) with pytest.raises(ValueError, match="app not found"): diff --git a/api/tests/unit_tests/core/tools/test_tool_label_manager.py b/api/tests/unit_tests/core/tools/test_tool_label_manager.py index 857f4aa1780..8c0e7e9419e 100644 --- 
a/api/tests/unit_tests/core/tools/test_tool_label_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_label_manager.py @@ -38,11 +38,9 @@ def test_tool_label_manager_filter_tool_labels(): def test_tool_label_manager_update_tool_labels_db(): controller = _api_controller("api-1") with patch("core.tools.tool_label_manager.db") as mock_db: - delete_query = mock_db.session.query.return_value.where.return_value - delete_query.delete.return_value = None ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"]) - delete_query.delete.assert_called_once() + mock_db.session.execute.assert_called_once() # only one valid unique label should be inserted. assert mock_db.session.add.call_count == 1 mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py index 844bc01e29f..31b68f0b3f3 100644 --- a/api/tests/unit_tests/core/tools/test_tool_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -220,9 +220,7 @@ def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks(): with patch.object(ToolManager, "get_builtin_provider", return_value=controller): with patch("core.helper.credential_utils.check_credential_policy_compliance"): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( - builtin_provider - ) + mock_db.session.scalar.return_value = builtin_provider encrypter = Mock() encrypter.decrypt.return_value = {"api_key": "secret"} cache = Mock() @@ -274,7 +272,7 @@ def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials( ) refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456) - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider + mock_db.session.scalar.return_value = builtin_provider encrypter = Mock() encrypter.decrypt.return_value = {"token": "old"} encrypter.encrypt.return_value = {"token": "encrypted"} @@ -698,12 +696,10 @@ def test_get_api_provider_controller_returns_controller_and_credentials(): privacy_policy="privacy", custom_disclaimer="disclaimer", ) - db_query = Mock() - db_query.where.return_value.first.return_value = provider controller = Mock() with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value = db_query + mock_db.session.scalar.return_value = provider with patch( "core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller ) as mock_from_db: @@ -730,12 +726,10 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels(): privacy_policy="privacy", custom_disclaimer="disclaimer", ) - db_query = Mock() - db_query.where.return_value.first.return_value = provider controller = Mock() with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value = db_query + mock_db.session.scalar.return_value = provider with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller): encrypter = Mock() encrypter.decrypt.return_value = {"api_key_value": "secret"} @@ -750,7 +744,7 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels(): def test_get_api_provider_controller_not_found_raises(): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with 
pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"): ToolManager.get_api_provider_controller("tenant-1", "missing") @@ -809,14 +803,14 @@ def test_generate_tool_icon_urls_for_workflow_and_api(): workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}') api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}') with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider] + mock_db.session.scalar.side_effect = [workflow_provider, api_provider] assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"} assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"} def test_generate_tool_icon_urls_missing_workflow_and_api_use_default(): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525" assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525" diff --git a/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py index 4ce73272bf0..a93624123e2 100644 --- a/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py +++ b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py @@ -263,7 +263,7 @@ def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources(): ) db_session = Mock() db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high] - db_session.query.return_value.filter_by.return_value.first.return_value = dataset + db_session.get.return_value = dataset tool = SingleDatasetRetrieverTool( tenant_id="tenant-1", @@ -444,7 +444,7 @@ def test_multi_dataset_retriever_run_orders_segments_and_returns_resources(): ) db_session = Mock() db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1] - db_session.query.return_value.filter_by.return_value.first.side_effect = [ + db_session.get.side_effect = [ SimpleNamespace(id="dataset-2", name="Dataset Two"), SimpleNamespace(id="dataset-1", name="Dataset One"), ] From 6bf8982559bbd5268037508cac2bb324fe877b9b Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 28 Mar 2026 20:28:25 +0800 Subject: [PATCH 12/14] chore(ci): reduce web test shard fan-out (#34215) --- .github/workflows/web-tests.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index d40cd4bfebc..8110a163559 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -22,8 +22,8 @@ jobs: strategy: fail-fast: false matrix: - shardIndex: [1, 2, 3, 4, 5, 6] - shardTotal: [6] + shardIndex: [1, 2, 3, 4] + shardTotal: [4] defaults: run: shell: bash @@ -66,7 +66,6 @@ jobs: - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: - fetch-depth: 0 persist-credentials: false - name: Setup web environment From f06cc339cc8155e1e42be5e49d66fa8d8449f3f0 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 28 Mar 2026 22:04:22 +0800 Subject: [PATCH 13/14] chore(ci): remove duplicate pyrefly work from style lane (#34213) --- .github/workflows/style.yml | 2 +- Makefile | 7 +++++++ 2 
files changed, 8 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index 23ae36f7b10..7b269ccf4e5 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -49,7 +49,7 @@ jobs:
       - name: Run Type Checks
         if: steps.changed-files.outputs.any_changed == 'true'
-        run: make type-check
+        run: make type-check-core
       - name: Dotenv check
         if: steps.changed-files.outputs.any_changed == 'true'
diff --git a/Makefile b/Makefile
index 55871c86a72..c377b7c671f 100644
--- a/Makefile
+++ b/Makefile
@@ -74,6 +74,12 @@ type-check:
 	@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
 	@echo "✅ Type checks complete"
 
+type-check-core:
+	@echo "📝 Running core type checks (basedpyright + mypy)..."
+	@./dev/basedpyright-check $(PATH_TO_CHECK)
+	@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
+	@echo "✅ Core type checks complete"
+
 test:
 	@echo "🧪 Running backend unit tests..."
 	@if [ -n "$(TARGET_TESTS)" ]; then \
@@ -133,6 +139,7 @@ help:
 	@echo " make check - Check code with ruff"
 	@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
 	@echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
+	@echo " make type-check-core - Run core type checks (basedpyright, mypy)"
 	@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)"
 	@echo ""
 	@echo "Docker Build Targets:"

From a1171877a463895f0c2040d0a07775550fc97bb3 Mon Sep 17 00:00:00 2001
From: Jasonfish
Date: Sat, 28 Mar 2026 23:37:51 +0800
Subject: [PATCH 14/14] fix: Fix docker-compose.yaml's ENV variables (#31101)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: -LAN-
---
 api/.env.example | 3 ++-
 docker/.env.example | 3 ++-
 docker/docker-compose-template.yaml | 1 +
 docker/docker-compose.yaml | 2 +-
 4 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/api/.env.example b/api/.env.example
index 9672a99d555..c6541731e64 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -127,7 +127,8 @@ ALIYUN_OSS_AUTH_VERSION=v1
 ALIYUN_OSS_REGION=your-region
 # Don't start with '/'. OSS doesn't support leading slash in object names.
 ALIYUN_OSS_PATH=your-path
-ALIYUN_CLOUDBOX_ID=your-cloudbox-id
+# Optional CloudBox ID for Aliyun OSS, DO NOT enable it if you are not using CloudBox.
+#ALIYUN_CLOUDBOX_ID=your-cloudbox-id
 
 # Google Storage configuration
 GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
diff --git a/docker/.env.example b/docker/.env.example
index 8cf77cf56b0..9fbf9a9e72a 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -488,7 +488,8 @@ ALIYUN_OSS_REGION=ap-southeast-1
 ALIYUN_OSS_AUTH_VERSION=v4
 # Don't start with '/'. OSS doesn't support leading slash in object names.
 ALIYUN_OSS_PATH=your-path
-ALIYUN_CLOUDBOX_ID=your-cloudbox-id
+# Optional CloudBox ID for Aliyun OSS, DO NOT enable it if you are not using CloudBox.
+#ALIYUN_CLOUDBOX_ID=your-cloudbox-id # Tencent COS Configuration # diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 98c2613a07a..e55cf942c32 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -275,6 +275,7 @@ services: # Use the shared environment variables. <<: *shared-api-worker-env DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin} + DB_SSL_MODE: ${DB_SSL_MODE:-disable} SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002} SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi} MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 2a75de1a899..737a62020ca 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -146,7 +146,6 @@ x-shared-env: &shared-api-worker-env ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path} - ALIYUN_CLOUDBOX_ID: ${ALIYUN_CLOUDBOX_ID:-your-cloudbox-id} TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-your-bucket-name} TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-your-secret-key} TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id} @@ -985,6 +984,7 @@ services: # Use the shared environment variables. <<: *shared-api-worker-env DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin} + DB_SSL_MODE: ${DB_SSL_MODE:-disable} SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002} SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi} MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
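
A closing note on [PATCH 14/14]: Compose-style `${VAR:-default}` interpolation substitutes the default whenever the variable is unset or empty, so the old shared-env line `ALIYUN_CLOUDBOX_ID: ${ALIYUN_CLOUDBOX_ID:-your-cloudbox-id}` injected the literal placeholder into the API container for every deployment that does not use CloudBox. That is presumably why this patch drops the mapping from the shared environment block and comments the variable out of both .env examples, while `DB_SSL_MODE: ${DB_SSL_MODE:-disable}` keeps a default because "disable" is a real, safe value. A rough Python model of the substitution rule, illustrative only (Docker Compose's actual parser also handles forms such as `${VAR-default}` and `${VAR:?error}`):

import re

# Toy model of compose-style "${VAR:-default}" interpolation. Illustrative
# only: the real docker compose parser handles escaping and more variants.
_VAR = re.compile(r"\$\{(?P<name>\w+):-(?P<default>[^}]*)\}")


def interpolate(template: str, env: dict[str, str]) -> str:
    # ":-" falls back to the default when the variable is unset OR empty.
    return _VAR.sub(lambda m: env.get(m["name"]) or m["default"], template)


# The host leaves ALIYUN_CLOUDBOX_ID unset, as most deployments do:
env: dict[str, str] = {}
print(interpolate("ALIYUN_CLOUDBOX_ID: ${ALIYUN_CLOUDBOX_ID:-your-cloudbox-id}", env))
# -> "ALIYUN_CLOUDBOX_ID: your-cloudbox-id" (a bogus placeholder reaches the container)

print(interpolate("DB_SSL_MODE: ${DB_SSL_MODE:-disable}", env))
# -> "DB_SSL_MODE: disable" (a meaningful fallback for the plugin daemon)

The asymmetry is the design point: a template default is fine when the fallback is a real value, and harmful when it is documentation filler.
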