chore: remove stale mypy suppressions and align dataset service tests (#34130)

This commit is contained in:
99
2026-03-26 20:34:44 +08:00
committed by GitHub
parent 69c2b422de
commit fcfc96ca05
11 changed files with 195 additions and 128 deletions

View File

@@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -22,7 +22,12 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.advanced_chat.generate_task_pipeline import (
AdvancedChatAppGenerateTaskPipeline,
ConversationSnapshot,
MessageSnapshot,
WorkflowSnapshot,
)
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
from core.app.apps.exc import GenerateTaskStoppedError
@@ -44,7 +49,6 @@ from graphon.runtime import GraphRuntimeState
from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.base import Base
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
@@ -524,19 +528,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
# Capture the scalar fields needed by the response pipeline before
# releasing the request-scoped SQLAlchemy session.
workflow_snapshot = WorkflowSnapshot.from_workflow(workflow)
conversation_snapshot = ConversationSnapshot.from_conversation(conversation)
message_snapshot = MessageSnapshot.from_message(message)
db.session.close()
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
workflow=workflow_snapshot,
queue_manager=queue_manager,
conversation=conversation,
message=message,
conversation=conversation_snapshot,
message=message_snapshot,
user=user,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
@@ -643,10 +648,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
workflow: WorkflowSnapshot,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
conversation: ConversationSnapshot,
message: MessageSnapshot,
user: Union[Account, EndUser],
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
@@ -683,13 +688,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
else:
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
raise e
_T = TypeVar("_T", bound=Base)
def _refresh_model(session, model: _T) -> _T:
with Session(bind=db.engine, expire_on_commit=False) as session:
detach_model = session.get(type(model), model.id)
assert detach_model is not None
return detach_model

View File

@@ -4,6 +4,8 @@ import re
import time
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from threading import Thread
from typing import Any, Union
@@ -79,11 +81,59 @@ from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
from models.execution_extra_content import HumanInputContent
from models.model import AppMode
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True)
class WorkflowSnapshot:
id: str
tenant_id: str
features_dict: Mapping[str, Any]
@classmethod
def from_workflow(cls, workflow: Workflow) -> "WorkflowSnapshot":
return cls(
id=workflow.id,
tenant_id=workflow.tenant_id,
features_dict=dict(workflow.features_dict),
)
@dataclass(frozen=True, slots=True)
class ConversationSnapshot:
id: str
mode: AppMode
@classmethod
def from_conversation(cls, conversation: Conversation) -> "ConversationSnapshot":
return cls(
id=conversation.id,
mode=conversation.mode,
)
@dataclass(frozen=True, slots=True)
class MessageSnapshot:
id: str
query: str
created_at: datetime
status: MessageStatus
answer: str
@classmethod
def from_message(cls, message: Message) -> "MessageSnapshot":
return cls(
id=message.id,
query=message.query,
created_at=message.created_at,
status=message.status,
answer=message.answer,
)
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
@@ -92,10 +142,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
workflow: WorkflowSnapshot,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
conversation: ConversationSnapshot,
message: MessageSnapshot,
user: Union[Account, EndUser],
stream: bool,
dialogue_count: int,
@@ -156,7 +206,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._message_saved_on_pause = False
self._seed_graph_runtime_state_from_queue_manager()
def _seed_task_state_from_message(self, message: Message) -> None:
def _seed_task_state_from_message(self, message: MessageSnapshot) -> None:
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer