chore: remove stale mypy suppressions and align dataset service tests (#34130)

This commit is contained in:
99
2026-03-26 20:34:44 +08:00
committed by GitHub
parent 69c2b422de
commit fcfc96ca05
11 changed files with 195 additions and 128 deletions

View File

@@ -1,5 +1,5 @@
from datetime import UTC, datetime, timedelta
from typing import Literal, cast
from typing import Literal, TypedDict, cast
from flask import request
from flask_restx import Resource, fields, marshal_with
@@ -173,6 +173,23 @@ console_ns.schema_model(
)
class HumanInputPauseTypeResponse(TypedDict):
type: Literal["human_input"]
form_id: str
backstage_input_url: str | None
class PausedNodeResponse(TypedDict):
node_id: str
node_title: str
pause_type: HumanInputPauseTypeResponse
class WorkflowPauseDetailsResponse(TypedDict):
paused_at: str | None
paused_nodes: list[PausedNodeResponse]
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
class AdvancedChatAppWorkflowRunListApi(Resource):
@console_ns.doc("get_advanced_chat_workflow_runs")
@@ -490,10 +507,11 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
# Check if workflow is suspended
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
if not is_paused:
return {
empty_response: WorkflowPauseDetailsResponse = {
"paused_at": None,
"paused_nodes": [],
}, 200
}
return empty_response, 200
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
@@ -503,8 +521,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
# Build response
paused_at = pause_entity.paused_at if pause_entity else None
paused_nodes = []
response = {
paused_nodes: list[PausedNodeResponse] = []
response: WorkflowPauseDetailsResponse = {
"paused_at": paused_at.isoformat() + "Z" if paused_at else None,
"paused_nodes": paused_nodes,
}

View File

@@ -15,6 +15,7 @@ from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
@@ -166,6 +167,7 @@ class ConsoleWorkflowEventsApi(Resource):
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app.mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app.mode == AppMode.WORKFLOW:
@@ -202,7 +204,7 @@ class ConsoleWorkflowEventsApi(Resource):
)
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun) -> App:
query = select(App).where(
App.id == workflow_run.app_id,
App.tenant_id == workflow_run.tenant_id,

View File

@@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -22,7 +22,12 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.advanced_chat.generate_task_pipeline import (
AdvancedChatAppGenerateTaskPipeline,
ConversationSnapshot,
MessageSnapshot,
WorkflowSnapshot,
)
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
from core.app.apps.exc import GenerateTaskStoppedError
@@ -44,7 +49,6 @@ from graphon.runtime import GraphRuntimeState
from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.base import Base
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
@@ -524,19 +528,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
# Capture the scalar fields needed by the response pipeline before
# releasing the request-scoped SQLAlchemy session.
workflow_snapshot = WorkflowSnapshot.from_workflow(workflow)
conversation_snapshot = ConversationSnapshot.from_conversation(conversation)
message_snapshot = MessageSnapshot.from_message(message)
db.session.close()
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
workflow=workflow_snapshot,
queue_manager=queue_manager,
conversation=conversation,
message=message,
conversation=conversation_snapshot,
message=message_snapshot,
user=user,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
@@ -643,10 +648,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
workflow: WorkflowSnapshot,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
conversation: ConversationSnapshot,
message: MessageSnapshot,
user: Union[Account, EndUser],
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
@@ -683,13 +688,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
else:
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
raise e
_T = TypeVar("_T", bound=Base)
def _refresh_model(session, model: _T) -> _T:
with Session(bind=db.engine, expire_on_commit=False) as session:
detach_model = session.get(type(model), model.id)
assert detach_model is not None
return detach_model

View File

@@ -4,6 +4,8 @@ import re
import time
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from threading import Thread
from typing import Any, Union
@@ -79,11 +81,59 @@ from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
from models.execution_extra_content import HumanInputContent
from models.model import AppMode
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True)
class WorkflowSnapshot:
id: str
tenant_id: str
features_dict: Mapping[str, Any]
@classmethod
def from_workflow(cls, workflow: Workflow) -> "WorkflowSnapshot":
return cls(
id=workflow.id,
tenant_id=workflow.tenant_id,
features_dict=dict(workflow.features_dict),
)
@dataclass(frozen=True, slots=True)
class ConversationSnapshot:
id: str
mode: AppMode
@classmethod
def from_conversation(cls, conversation: Conversation) -> "ConversationSnapshot":
return cls(
id=conversation.id,
mode=conversation.mode,
)
@dataclass(frozen=True, slots=True)
class MessageSnapshot:
id: str
query: str
created_at: datetime
status: MessageStatus
answer: str
@classmethod
def from_message(cls, message: Message) -> "MessageSnapshot":
return cls(
id=message.id,
query=message.query,
created_at=message.created_at,
status=message.status,
answer=message.answer,
)
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
@@ -92,10 +142,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
workflow: WorkflowSnapshot,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
conversation: ConversationSnapshot,
message: MessageSnapshot,
user: Union[Account, EndUser],
stream: bool,
dialogue_count: int,
@@ -156,7 +206,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._message_saved_on_pause = False
self._seed_graph_runtime_state_from_queue_manager()
def _seed_task_state_from_message(self, message: Message) -> None:
def _seed_task_state_from_message(self, message: MessageSnapshot) -> None:
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer

View File

@@ -52,6 +52,12 @@ class ReadyQueueProtocol(Protocol):
...
class NodeExecutionProtocol(Protocol):
"""Structural interface for persisted per-node execution state."""
execution_id: str | None
class GraphExecutionProtocol(Protocol):
"""Structural interface for graph execution aggregate.
@@ -67,6 +73,11 @@ class GraphExecutionProtocol(Protocol):
exceptions_count: int
pause_reasons: list[PauseReason]
@property
def node_executions(self) -> Mapping[str, NodeExecutionProtocol]:
"""Return the persisted node execution state keyed by node id."""
...
def start(self) -> None:
"""Transition execution into the running state."""
...

View File

@@ -231,26 +231,6 @@ vdb = [
"holo-search-sdk>=0.4.1",
]
[tool.mypy]
[[tool.mypy.overrides]]
# targeted ignores for current type-check errors
# TODO(QuantumGhost): suppress type errors in HITL related code.
# fix the type error later
module = [
"configs.middleware.cache.redis_pubsub_config",
"extensions.ext_redis",
"tasks.workflow_execution_tasks",
"graphon.nodes.base.node",
"services.human_input_delivery_test_service",
"core.app.apps.advanced_chat.app_generator",
"controllers.console.human_input_form",
"controllers.console.app.workflow_run",
"repositories.sqlalchemy_api_workflow_node_execution_repository",
"extensions.logstore.repositories.logstore_api_workflow_run_repository",
]
ignore_errors = true
[tool.pyrefly]
project-includes = ["."]
project-excludes = [".venv", "migrations/"]

View File

@@ -220,7 +220,7 @@ class EmailDeliveryTestHandler:
stmt = stmt.where(Account.id.in_(unique_ids))
with self._session_factory() as session:
rows = session.execute(stmt).all()
rows = session.execute(stmt).tuples().all()
return dict(rows)
@staticmethod

View File

@@ -9,10 +9,17 @@ from pydantic import BaseModel, ValidationError
from constants import UUID_NIL
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator, _refresh_model
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.advanced_chat.generate_task_pipeline import (
ConversationSnapshot,
MessageSnapshot,
WorkflowSnapshot,
)
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.ops.ops_trace_manager import TraceQueueManager
from libs.datetime_utils import naive_utc_now
from models.enums import MessageStatus
from models.model import AppMode
@@ -363,8 +370,15 @@ class TestAdvancedChatAppGeneratorInternals:
workflow_run_id="run-id",
)
workflow = SimpleNamespace(id="wf-1", tenant_id="tenant", features={"feature": True}, features_dict={})
conversation = SimpleNamespace(id="conv-1", mode=AppMode.ADVANCED_CHAT, override_model_configs=None)
message = SimpleNamespace(id="msg-1")
message = SimpleNamespace(
id="msg-1",
query="hello",
created_at=naive_utc_now(),
status=MessageStatus.NORMAL,
answer="",
)
db_session = SimpleNamespace(commit=MagicMock(), refresh=MagicMock(), close=MagicMock())
captured: dict[str, object] = {}
thread_data: dict[str, object] = {}
@@ -394,19 +408,6 @@ class TestAdvancedChatAppGeneratorInternals:
thread_data["started"] = True
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread)
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model)
class _Session:
def __init__(self, *args, **kwargs):
_ = args, kwargs
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
monkeypatch.setattr(
"core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session)
)
@@ -424,7 +425,7 @@ class TestAdvancedChatAppGeneratorInternals:
pause_state_config = SimpleNamespace(session_factory="session-factory", state_owner_user_id="owner")
response = generator._generate(
workflow=SimpleNamespace(features={"feature": True}),
workflow=workflow,
user=SimpleNamespace(id="user"),
invoke_from=InvokeFrom.WEB_APP,
application_generate_entity=application_generate_entity,
@@ -444,6 +445,9 @@ class TestAdvancedChatAppGeneratorInternals:
db_session.refresh.assert_called_once_with(conversation)
db_session.close.assert_called_once()
assert captured["draft_var_saver_factory"] == "draft-factory"
assert isinstance(captured["workflow"], WorkflowSnapshot)
assert isinstance(captured["conversation"], ConversationSnapshot)
assert isinstance(captured["message"], MessageSnapshot)
def test_generate_internal_flow_with_existing_records_skips_init(self, monkeypatch):
generator = AdvancedChatAppGenerator()
@@ -464,8 +468,15 @@ class TestAdvancedChatAppGeneratorInternals:
workflow_run_id="run-id",
)
workflow = SimpleNamespace(id="wf-2", tenant_id="tenant", features={}, features_dict={})
conversation = SimpleNamespace(id="conv-2", mode=AppMode.ADVANCED_CHAT, override_model_configs=None)
message = SimpleNamespace(id="msg-2")
message = SimpleNamespace(
id="msg-2",
query="hello",
created_at=naive_utc_now(),
status=MessageStatus.NORMAL,
answer="",
)
db_session = SimpleNamespace(close=MagicMock(), commit=MagicMock(), refresh=MagicMock())
init_records = MagicMock()
thread_data: dict[str, object] = {}
@@ -491,19 +502,6 @@ class TestAdvancedChatAppGeneratorInternals:
thread_data["started"] = True
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread)
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model)
class _Session:
def __init__(self, *args, **kwargs):
_ = args, kwargs
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
monkeypatch.setattr(
"core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session)
)
@@ -519,7 +517,7 @@ class TestAdvancedChatAppGeneratorInternals:
)
response = generator._generate(
workflow=SimpleNamespace(features={}),
workflow=workflow,
user=SimpleNamespace(id="user"),
invoke_from=InvokeFrom.WEB_APP,
application_generate_entity=application_generate_entity,
@@ -940,10 +938,16 @@ class TestAdvancedChatAppGeneratorInternals:
with pytest.raises(GenerateTaskStoppedError):
generator._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
workflow=SimpleNamespace(),
workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}),
queue_manager=SimpleNamespace(),
conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT),
message=SimpleNamespace(id="msg"),
conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT),
message=MessageSnapshot(
id="msg",
query="hello",
created_at=naive_utc_now(),
status=MessageStatus.NORMAL,
answer="",
),
user=SimpleNamespace(),
draft_var_saver_factory=lambda **kwargs: None,
stream=False,
@@ -981,10 +985,16 @@ class TestAdvancedChatAppGeneratorInternals:
with pytest.raises(ValueError, match="other error"):
generator._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
workflow=SimpleNamespace(),
workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}),
queue_manager=SimpleNamespace(),
conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT),
message=SimpleNamespace(id="msg"),
conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT),
message=MessageSnapshot(
id="msg",
query="hello",
created_at=naive_utc_now(),
status=MessageStatus.NORMAL,
answer="",
),
user=SimpleNamespace(),
draft_var_saver_factory=lambda **kwargs: None,
stream=False,
@@ -992,31 +1002,6 @@ class TestAdvancedChatAppGeneratorInternals:
logger_exception.assert_called_once()
def test_refresh_model_returns_detached_model(self, monkeypatch):
source_model = SimpleNamespace(id="source-id")
detached_model = SimpleNamespace(id="source-id", detached=True)
class _Session:
def __init__(self, *args, **kwargs):
_ = args, kwargs
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def get(self, model_type, model_id):
_ = model_type
return detached_model if model_id == "source-id" else None
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object()))
refreshed = _refresh_model(session=SimpleNamespace(), model=source_model)
assert refreshed is detached_model
def test_generate_worker_handles_invoke_auth_error(self, monkeypatch):
generator = AdvancedChatAppGenerator()
generator._dialogue_count = 1

View File

@@ -21,7 +21,7 @@ from graphon.entities.pause_reason import HumanInputRequired
from graphon.enums import WorkflowExecutionStatus
from models.enums import MessageStatus
from models.execution_extra_content import HumanInputContent
from models.model import EndUser
from models.model import AppMode, EndUser
def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline:
@@ -159,8 +159,8 @@ def test_resume_appends_chunks_to_paused_answer() -> None:
task_id="task-1",
)
queue_manager = SimpleNamespace(graph_runtime_state=None)
conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat")
message = SimpleNamespace(
conversation = pipeline_module.ConversationSnapshot(id="conversation-1", mode=AppMode.ADVANCED_CHAT)
message = pipeline_module.MessageSnapshot(
id="message-1",
created_at=datetime(2024, 1, 1),
query="hello",
@@ -170,7 +170,7 @@ def test_resume_appends_chunks_to_paused_answer() -> None:
user = EndUser()
user.id = "user-1"
user.session_id = "session-1"
workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={})
workflow = pipeline_module.WorkflowSnapshot(id="workflow-1", tenant_id="tenant-1", features_dict={})
pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
@@ -184,14 +184,33 @@ def test_resume_appends_chunks_to_paused_answer() -> None:
draft_var_saver_factory=SimpleNamespace(),
)
pipeline._get_message = mock.Mock(return_value=message)
stored_message = SimpleNamespace(
id="message-1",
answer="before",
status=MessageStatus.PAUSED,
updated_at=None,
provider_response_latency=0,
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
total_price=0,
currency="USD",
message_metadata=None,
invoke_from=InvokeFrom.WEB_APP,
from_account_id=None,
from_end_user_id="user-1",
)
pipeline._get_message = mock.Mock(return_value=stored_message)
pipeline._recorded_files = []
list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after")))
pipeline._save_message(session=mock.Mock())
assert message.answer == "beforeafter"
assert message.status == MessageStatus.NORMAL
assert stored_message.answer == "beforeafter"
assert stored_message.status == MessageStatus.NORMAL
def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None:

View File

@@ -6,7 +6,12 @@ from types import SimpleNamespace
import pytest
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.advanced_chat.generate_task_pipeline import (
AdvancedChatAppGenerateTaskPipeline,
ConversationSnapshot,
MessageSnapshot,
WorkflowSnapshot,
)
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
@@ -73,15 +78,15 @@ def _make_pipeline():
workflow_run_id="run-id",
)
message = SimpleNamespace(
message = MessageSnapshot(
id="message-id",
query="hello",
created_at=naive_utc_now(),
status=MessageStatus.NORMAL,
answer="",
)
conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT)
workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
conversation = ConversationSnapshot(id="conv-id", mode=AppMode.ADVANCED_CHAT)
workflow = WorkflowSnapshot(id="workflow-id", tenant_id="tenant", features_dict={})
user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session")
pipeline = AdvancedChatAppGenerateTaskPipeline(

View File

@@ -302,8 +302,10 @@ class TestEmailDeliveryTestHandler:
# user_ids is None (all)
mock_execute = MagicMock()
mock_tuples = MagicMock()
mock_session.execute.return_value = mock_execute
mock_execute.all.return_value = [("u1", "u1@example.com")]
mock_execute.tuples.return_value = mock_tuples
mock_tuples.all.return_value = [("u1", "u1@example.com")]
result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=None)
assert result == {"u1": "u1@example.com"}