From fcfc96ca052789a63e3df6c3abf36f539c928cb5 Mon Sep 17 00:00:00 2001 From: 99 Date: Thu, 26 Mar 2026 20:34:44 +0800 Subject: [PATCH] chore: remove stale mypy suppressions and align dataset service tests (#34130) --- api/controllers/console/app/workflow_run.py | 28 ++++- api/controllers/console/human_input_form.py | 4 +- .../app/apps/advanced_chat/app_generator.py | 41 +++---- .../advanced_chat/generate_task_pipeline.py | 58 +++++++++- api/graphon/runtime/graph_runtime_state.py | 11 ++ api/pyproject.toml | 20 ---- .../human_input_delivery_test_service.py | 2 +- .../apps/advanced_chat/test_app_generator.py | 109 ++++++++---------- .../test_generate_task_pipeline.py | 33 ++++-- .../test_generate_task_pipeline_core.py | 13 ++- .../test_human_input_delivery_test_service.py | 4 +- 11 files changed, 195 insertions(+), 128 deletions(-) diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 29fa96c4e63..d1df7227293 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,5 +1,5 @@ from datetime import UTC, datetime, timedelta -from typing import Literal, cast +from typing import Literal, TypedDict, cast from flask import request from flask_restx import Resource, fields, marshal_with @@ -173,6 +173,23 @@ console_ns.schema_model( ) +class HumanInputPauseTypeResponse(TypedDict): + type: Literal["human_input"] + form_id: str + backstage_input_url: str | None + + +class PausedNodeResponse(TypedDict): + node_id: str + node_title: str + pause_type: HumanInputPauseTypeResponse + + +class WorkflowPauseDetailsResponse(TypedDict): + paused_at: str | None + paused_nodes: list[PausedNodeResponse] + + @console_ns.route("/apps//advanced-chat/workflow-runs") class AdvancedChatAppWorkflowRunListApi(Resource): @console_ns.doc("get_advanced_chat_workflow_runs") @@ -490,10 +507,11 @@ class ConsoleWorkflowPauseDetailsApi(Resource): # Check if workflow is suspended is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED if not is_paused: - return { + empty_response: WorkflowPauseDetailsResponse = { "paused_at": None, "paused_nodes": [], - }, 200 + } + return empty_response, 200 pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) pause_reasons = pause_entity.get_pause_reasons() if pause_entity else [] @@ -503,8 +521,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource): # Build response paused_at = pause_entity.paused_at if pause_entity else None - paused_nodes = [] - response = { + paused_nodes: list[PausedNodeResponse] = [] + response: WorkflowPauseDetailsResponse = { "paused_at": paused_at.isoformat() + "Z" if paused_at else None, "paused_nodes": paused_nodes, } diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 7207f7fd1d5..e37e78c966f 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -15,6 +15,7 @@ from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.apps.message_generator import MessageGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator @@ -166,6 +167,7 @@ class ConsoleWorkflowEventsApi(Resource): else: msg_generator = MessageGenerator() + generator: BaseAppGenerator if app.mode == AppMode.ADVANCED_CHAT: generator = AdvancedChatAppGenerator() elif app.mode == AppMode.WORKFLOW: @@ -202,7 +204,7 @@ class ConsoleWorkflowEventsApi(Resource): ) -def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun): +def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun) -> App: query = select(App).where( App.id == workflow_run.app_id, App.tenant_id == workflow_run.tenant_id, diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 853cbb426ce..d69a80e4a93 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -5,7 +5,7 @@ import logging import threading import uuid from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -22,7 +22,12 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter -from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.advanced_chat.generate_task_pipeline import ( + AdvancedChatAppGenerateTaskPipeline, + ConversationSnapshot, + MessageSnapshot, + WorkflowSnapshot, +) from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError @@ -44,7 +49,6 @@ from graphon.runtime import GraphRuntimeState from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom -from models.base import Base from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService from services.workflow_draft_variable_service import ( @@ -524,19 +528,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): worker_thread.start() - # release database connection, because the following new thread operations may take a long time - with Session(bind=db.engine, expire_on_commit=False) as session: - workflow = _refresh_model(session, workflow) - message = _refresh_model(session, message) + # Capture the scalar fields needed by the response pipeline before + # releasing the request-scoped SQLAlchemy session. + workflow_snapshot = WorkflowSnapshot.from_workflow(workflow) + conversation_snapshot = ConversationSnapshot.from_conversation(conversation) + message_snapshot = MessageSnapshot.from_message(message) db.session.close() # return response or stream generator response = self._handle_advanced_chat_response( application_generate_entity=application_generate_entity, - workflow=workflow, + workflow=workflow_snapshot, queue_manager=queue_manager, - conversation=conversation, - message=message, + conversation=conversation_snapshot, + message=message_snapshot, user=user, stream=stream, draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user), @@ -643,10 +648,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): self, *, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, + workflow: WorkflowSnapshot, queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, + conversation: ConversationSnapshot, + message: MessageSnapshot, user: Union[Account, EndUser], draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, @@ -683,13 +688,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): else: logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id) raise e - - -_T = TypeVar("_T", bound=Base) - - -def _refresh_model(session, model: _T) -> _T: - with Session(bind=db.engine, expire_on_commit=False) as session: - detach_model = session.get(type(model), model.id) - assert detach_model is not None - return detach_model diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 51febed32ab..3577ae139bf 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,6 +4,8 @@ import re import time from collections.abc import Callable, Generator, Mapping from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime from threading import Thread from typing import Any, Union @@ -79,11 +81,59 @@ from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus from models.execution_extra_content import HumanInputContent +from models.model import AppMode from models.workflow import Workflow logger = logging.getLogger(__name__) +@dataclass(frozen=True, slots=True) +class WorkflowSnapshot: + id: str + tenant_id: str + features_dict: Mapping[str, Any] + + @classmethod + def from_workflow(cls, workflow: Workflow) -> "WorkflowSnapshot": + return cls( + id=workflow.id, + tenant_id=workflow.tenant_id, + features_dict=dict(workflow.features_dict), + ) + + +@dataclass(frozen=True, slots=True) +class ConversationSnapshot: + id: str + mode: AppMode + + @classmethod + def from_conversation(cls, conversation: Conversation) -> "ConversationSnapshot": + return cls( + id=conversation.id, + mode=conversation.mode, + ) + + +@dataclass(frozen=True, slots=True) +class MessageSnapshot: + id: str + query: str + created_at: datetime + status: MessageStatus + answer: str + + @classmethod + def from_message(cls, message: Message) -> "MessageSnapshot": + return cls( + id=message.id, + query=message.query, + created_at=message.created_at, + status=message.status, + answer=message.answer, + ) + + class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. @@ -92,10 +142,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def __init__( self, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, + workflow: WorkflowSnapshot, queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, + conversation: ConversationSnapshot, + message: MessageSnapshot, user: Union[Account, EndUser], stream: bool, dialogue_count: int, @@ -156,7 +206,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._message_saved_on_pause = False self._seed_graph_runtime_state_from_queue_manager() - def _seed_task_state_from_message(self, message: Message) -> None: + def _seed_task_state_from_message(self, message: MessageSnapshot) -> None: if message.status == MessageStatus.PAUSED and message.answer: self._task_state.answer = message.answer diff --git a/api/graphon/runtime/graph_runtime_state.py b/api/graphon/runtime/graph_runtime_state.py index 6e4ed202b5a..8453830f284 100644 --- a/api/graphon/runtime/graph_runtime_state.py +++ b/api/graphon/runtime/graph_runtime_state.py @@ -52,6 +52,12 @@ class ReadyQueueProtocol(Protocol): ... +class NodeExecutionProtocol(Protocol): + """Structural interface for persisted per-node execution state.""" + + execution_id: str | None + + class GraphExecutionProtocol(Protocol): """Structural interface for graph execution aggregate. @@ -67,6 +73,11 @@ class GraphExecutionProtocol(Protocol): exceptions_count: int pause_reasons: list[PauseReason] + @property + def node_executions(self) -> Mapping[str, NodeExecutionProtocol]: + """Return the persisted node execution state keyed by node id.""" + ... + def start(self) -> None: """Transition execution into the running state.""" ... diff --git a/api/pyproject.toml b/api/pyproject.toml index 0398376ee26..b1f1f4bb2e8 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -231,26 +231,6 @@ vdb = [ "holo-search-sdk>=0.4.1", ] -[tool.mypy] - -[[tool.mypy.overrides]] -# targeted ignores for current type-check errors -# TODO(QuantumGhost): suppress type errors in HITL related code. -# fix the type error later -module = [ - "configs.middleware.cache.redis_pubsub_config", - "extensions.ext_redis", - "tasks.workflow_execution_tasks", - "graphon.nodes.base.node", - "services.human_input_delivery_test_service", - "core.app.apps.advanced_chat.app_generator", - "controllers.console.human_input_form", - "controllers.console.app.workflow_run", - "repositories.sqlalchemy_api_workflow_node_execution_repository", - "extensions.logstore.repositories.logstore_api_workflow_run_repository", -] -ignore_errors = true - [tool.pyrefly] project-includes = ["."] project-excludes = [".venv", "migrations/"] diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 861d952c932..68ef67dec1c 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -220,7 +220,7 @@ class EmailDeliveryTestHandler: stmt = stmt.where(Account.id.in_(unique_ids)) with self._session_factory() as session: - rows = session.execute(stmt).all() + rows = session.execute(stmt).tuples().all() return dict(rows) @staticmethod diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py index 8b0ff7b6c19..af5d203f126 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py @@ -9,10 +9,17 @@ from pydantic import BaseModel, ValidationError from constants import UUID_NIL from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator, _refresh_model +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.advanced_chat.generate_task_pipeline import ( + ConversationSnapshot, + MessageSnapshot, + WorkflowSnapshot, +) from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager +from libs.datetime_utils import naive_utc_now +from models.enums import MessageStatus from models.model import AppMode @@ -363,8 +370,15 @@ class TestAdvancedChatAppGeneratorInternals: workflow_run_id="run-id", ) + workflow = SimpleNamespace(id="wf-1", tenant_id="tenant", features={"feature": True}, features_dict={}) conversation = SimpleNamespace(id="conv-1", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) - message = SimpleNamespace(id="msg-1") + message = SimpleNamespace( + id="msg-1", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ) db_session = SimpleNamespace(commit=MagicMock(), refresh=MagicMock(), close=MagicMock()) captured: dict[str, object] = {} thread_data: dict[str, object] = {} @@ -394,19 +408,6 @@ class TestAdvancedChatAppGeneratorInternals: thread_data["started"] = True monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) monkeypatch.setattr( "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) ) @@ -424,7 +425,7 @@ class TestAdvancedChatAppGeneratorInternals: pause_state_config = SimpleNamespace(session_factory="session-factory", state_owner_user_id="owner") response = generator._generate( - workflow=SimpleNamespace(features={"feature": True}), + workflow=workflow, user=SimpleNamespace(id="user"), invoke_from=InvokeFrom.WEB_APP, application_generate_entity=application_generate_entity, @@ -444,6 +445,9 @@ class TestAdvancedChatAppGeneratorInternals: db_session.refresh.assert_called_once_with(conversation) db_session.close.assert_called_once() assert captured["draft_var_saver_factory"] == "draft-factory" + assert isinstance(captured["workflow"], WorkflowSnapshot) + assert isinstance(captured["conversation"], ConversationSnapshot) + assert isinstance(captured["message"], MessageSnapshot) def test_generate_internal_flow_with_existing_records_skips_init(self, monkeypatch): generator = AdvancedChatAppGenerator() @@ -464,8 +468,15 @@ class TestAdvancedChatAppGeneratorInternals: workflow_run_id="run-id", ) + workflow = SimpleNamespace(id="wf-2", tenant_id="tenant", features={}, features_dict={}) conversation = SimpleNamespace(id="conv-2", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) - message = SimpleNamespace(id="msg-2") + message = SimpleNamespace( + id="msg-2", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ) db_session = SimpleNamespace(close=MagicMock(), commit=MagicMock(), refresh=MagicMock()) init_records = MagicMock() thread_data: dict[str, object] = {} @@ -491,19 +502,6 @@ class TestAdvancedChatAppGeneratorInternals: thread_data["started"] = True monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) monkeypatch.setattr( "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) ) @@ -519,7 +517,7 @@ class TestAdvancedChatAppGeneratorInternals: ) response = generator._generate( - workflow=SimpleNamespace(features={}), + workflow=workflow, user=SimpleNamespace(id="user"), invoke_from=InvokeFrom.WEB_APP, application_generate_entity=application_generate_entity, @@ -940,10 +938,16 @@ class TestAdvancedChatAppGeneratorInternals: with pytest.raises(GenerateTaskStoppedError): generator._handle_advanced_chat_response( application_generate_entity=application_generate_entity, - workflow=SimpleNamespace(), + workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}), queue_manager=SimpleNamespace(), - conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), - message=SimpleNamespace(id="msg"), + conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT), + message=MessageSnapshot( + id="msg", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ), user=SimpleNamespace(), draft_var_saver_factory=lambda **kwargs: None, stream=False, @@ -981,10 +985,16 @@ class TestAdvancedChatAppGeneratorInternals: with pytest.raises(ValueError, match="other error"): generator._handle_advanced_chat_response( application_generate_entity=application_generate_entity, - workflow=SimpleNamespace(), + workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}), queue_manager=SimpleNamespace(), - conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), - message=SimpleNamespace(id="msg"), + conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT), + message=MessageSnapshot( + id="msg", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ), user=SimpleNamespace(), draft_var_saver_factory=lambda **kwargs: None, stream=False, @@ -992,31 +1002,6 @@ class TestAdvancedChatAppGeneratorInternals: logger_exception.assert_called_once() - def test_refresh_model_returns_detached_model(self, monkeypatch): - source_model = SimpleNamespace(id="source-id") - detached_model = SimpleNamespace(id="source-id", detached=True) - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def get(self, model_type, model_id): - _ = model_type - return detached_model if model_id == "source-id" else None - - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object())) - - refreshed = _refresh_model(session=SimpleNamespace(), model=source_model) - - assert refreshed is detached_model - def test_generate_worker_handles_invoke_auth_error(self, monkeypatch): generator = AdvancedChatAppGenerator() generator._dialogue_count = 1 diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index 56919d7f651..99a386cd45e 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -21,7 +21,7 @@ from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent -from models.model import EndUser +from models.model import AppMode, EndUser def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline: @@ -159,8 +159,8 @@ def test_resume_appends_chunks_to_paused_answer() -> None: task_id="task-1", ) queue_manager = SimpleNamespace(graph_runtime_state=None) - conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat") - message = SimpleNamespace( + conversation = pipeline_module.ConversationSnapshot(id="conversation-1", mode=AppMode.ADVANCED_CHAT) + message = pipeline_module.MessageSnapshot( id="message-1", created_at=datetime(2024, 1, 1), query="hello", @@ -170,7 +170,7 @@ def test_resume_appends_chunks_to_paused_answer() -> None: user = EndUser() user.id = "user-1" user.session_id = "session-1" - workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={}) + workflow = pipeline_module.WorkflowSnapshot(id="workflow-1", tenant_id="tenant-1", features_dict={}) pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline( application_generate_entity=application_generate_entity, @@ -184,14 +184,33 @@ def test_resume_appends_chunks_to_paused_answer() -> None: draft_var_saver_factory=SimpleNamespace(), ) - pipeline._get_message = mock.Mock(return_value=message) + stored_message = SimpleNamespace( + id="message-1", + answer="before", + status=MessageStatus.PAUSED, + updated_at=None, + provider_response_latency=0, + message_tokens=0, + message_unit_price=0, + message_price_unit=0, + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0, + total_price=0, + currency="USD", + message_metadata=None, + invoke_from=InvokeFrom.WEB_APP, + from_account_id=None, + from_end_user_id="user-1", + ) + pipeline._get_message = mock.Mock(return_value=stored_message) pipeline._recorded_files = [] list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after"))) pipeline._save_message(session=mock.Mock()) - assert message.answer == "beforeafter" - assert message.status == MessageStatus.NORMAL + assert stored_message.answer == "beforeafter" + assert stored_message.status == MessageStatus.NORMAL def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 3baefd64d64..29fd63c063a 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -6,7 +6,12 @@ from types import SimpleNamespace import pytest from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig -from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.advanced_chat.generate_task_pipeline import ( + AdvancedChatAppGenerateTaskPipeline, + ConversationSnapshot, + MessageSnapshot, + WorkflowSnapshot, +) from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.queue_entities import ( QueueAdvancedChatMessageEndEvent, @@ -73,15 +78,15 @@ def _make_pipeline(): workflow_run_id="run-id", ) - message = SimpleNamespace( + message = MessageSnapshot( id="message-id", query="hello", created_at=naive_utc_now(), status=MessageStatus.NORMAL, answer="", ) - conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT) - workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}) + conversation = ConversationSnapshot(id="conv-id", mode=AppMode.ADVANCED_CHAT) + workflow = WorkflowSnapshot(id="workflow-id", tenant_id="tenant", features_dict={}) user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session") pipeline = AdvancedChatAppGenerateTaskPipeline( diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py index 2c0f5618601..ce40756f12d 100644 --- a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py @@ -302,8 +302,10 @@ class TestEmailDeliveryTestHandler: # user_ids is None (all) mock_execute = MagicMock() + mock_tuples = MagicMock() mock_session.execute.return_value = mock_execute - mock_execute.all.return_value = [("u1", "u1@example.com")] + mock_execute.tuples.return_value = mock_tuples + mock_tuples.all.return_value = [("u1", "u1@example.com")] result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=None) assert result == {"u1": "u1@example.com"}