mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 06:06:13 +08:00
chore: remove stale mypy suppressions and align dataset service tests (#34130)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal, cast
|
||||
from typing import Literal, TypedDict, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
@@ -173,6 +173,23 @@ console_ns.schema_model(
|
||||
)
|
||||
|
||||
|
||||
class HumanInputPauseTypeResponse(TypedDict):
|
||||
type: Literal["human_input"]
|
||||
form_id: str
|
||||
backstage_input_url: str | None
|
||||
|
||||
|
||||
class PausedNodeResponse(TypedDict):
|
||||
node_id: str
|
||||
node_title: str
|
||||
pause_type: HumanInputPauseTypeResponse
|
||||
|
||||
|
||||
class WorkflowPauseDetailsResponse(TypedDict):
|
||||
paused_at: str | None
|
||||
paused_nodes: list[PausedNodeResponse]
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||
class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
@console_ns.doc("get_advanced_chat_workflow_runs")
|
||||
@@ -490,10 +507,11 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
# Check if workflow is suspended
|
||||
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
if not is_paused:
|
||||
return {
|
||||
empty_response: WorkflowPauseDetailsResponse = {
|
||||
"paused_at": None,
|
||||
"paused_nodes": [],
|
||||
}, 200
|
||||
}
|
||||
return empty_response, 200
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
|
||||
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
|
||||
@@ -503,8 +521,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
|
||||
# Build response
|
||||
paused_at = pause_entity.paused_at if pause_entity else None
|
||||
paused_nodes = []
|
||||
response = {
|
||||
paused_nodes: list[PausedNodeResponse] = []
|
||||
response: WorkflowPauseDetailsResponse = {
|
||||
"paused_at": paused_at.isoformat() + "Z" if paused_at else None,
|
||||
"paused_nodes": paused_nodes,
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
@@ -166,6 +167,7 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
else:
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
@@ -202,7 +204,7 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
|
||||
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun) -> App:
|
||||
query = select(App).where(
|
||||
App.id == workflow_run.app_id,
|
||||
App.tenant_id == workflow_run.tenant_id,
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -22,7 +22,12 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import (
|
||||
AdvancedChatAppGenerateTaskPipeline,
|
||||
ConversationSnapshot,
|
||||
MessageSnapshot,
|
||||
WorkflowSnapshot,
|
||||
)
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
@@ -44,7 +49,6 @@ from graphon.runtime import GraphRuntimeState
|
||||
from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.base import Base
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.workflow_draft_variable_service import (
|
||||
@@ -524,19 +528,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
workflow = _refresh_model(session, workflow)
|
||||
message = _refresh_model(session, message)
|
||||
# Capture the scalar fields needed by the response pipeline before
|
||||
# releasing the request-scoped SQLAlchemy session.
|
||||
workflow_snapshot = WorkflowSnapshot.from_workflow(workflow)
|
||||
conversation_snapshot = ConversationSnapshot.from_conversation(conversation)
|
||||
message_snapshot = MessageSnapshot.from_message(message)
|
||||
db.session.close()
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_advanced_chat_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
workflow=workflow_snapshot,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
conversation=conversation_snapshot,
|
||||
message=message_snapshot,
|
||||
user=user,
|
||||
stream=stream,
|
||||
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
|
||||
@@ -643,10 +648,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
workflow: WorkflowSnapshot,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
conversation: ConversationSnapshot,
|
||||
message: MessageSnapshot,
|
||||
user: Union[Account, EndUser],
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
@@ -683,13 +688,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
else:
|
||||
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
|
||||
raise e
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=Base)
|
||||
|
||||
|
||||
def _refresh_model(session, model: _T) -> _T:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
detach_model = session.get(type(model), model.id)
|
||||
assert detach_model is not None
|
||||
return detach_model
|
||||
|
||||
@@ -4,6 +4,8 @@ import re
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from threading import Thread
|
||||
from typing import Any, Union
|
||||
|
||||
@@ -79,11 +81,59 @@ from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class WorkflowSnapshot:
|
||||
id: str
|
||||
tenant_id: str
|
||||
features_dict: Mapping[str, Any]
|
||||
|
||||
@classmethod
|
||||
def from_workflow(cls, workflow: Workflow) -> "WorkflowSnapshot":
|
||||
return cls(
|
||||
id=workflow.id,
|
||||
tenant_id=workflow.tenant_id,
|
||||
features_dict=dict(workflow.features_dict),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ConversationSnapshot:
|
||||
id: str
|
||||
mode: AppMode
|
||||
|
||||
@classmethod
|
||||
def from_conversation(cls, conversation: Conversation) -> "ConversationSnapshot":
|
||||
return cls(
|
||||
id=conversation.id,
|
||||
mode=conversation.mode,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class MessageSnapshot:
|
||||
id: str
|
||||
query: str
|
||||
created_at: datetime
|
||||
status: MessageStatus
|
||||
answer: str
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message: Message) -> "MessageSnapshot":
|
||||
return cls(
|
||||
id=message.id,
|
||||
query=message.query,
|
||||
created_at=message.created_at,
|
||||
status=message.status,
|
||||
answer=message.answer,
|
||||
)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
"""
|
||||
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
@@ -92,10 +142,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
workflow: WorkflowSnapshot,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
conversation: ConversationSnapshot,
|
||||
message: MessageSnapshot,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
dialogue_count: int,
|
||||
@@ -156,7 +206,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._message_saved_on_pause = False
|
||||
self._seed_graph_runtime_state_from_queue_manager()
|
||||
|
||||
def _seed_task_state_from_message(self, message: Message) -> None:
|
||||
def _seed_task_state_from_message(self, message: MessageSnapshot) -> None:
|
||||
if message.status == MessageStatus.PAUSED and message.answer:
|
||||
self._task_state.answer = message.answer
|
||||
|
||||
|
||||
@@ -52,6 +52,12 @@ class ReadyQueueProtocol(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class NodeExecutionProtocol(Protocol):
|
||||
"""Structural interface for persisted per-node execution state."""
|
||||
|
||||
execution_id: str | None
|
||||
|
||||
|
||||
class GraphExecutionProtocol(Protocol):
|
||||
"""Structural interface for graph execution aggregate.
|
||||
|
||||
@@ -67,6 +73,11 @@ class GraphExecutionProtocol(Protocol):
|
||||
exceptions_count: int
|
||||
pause_reasons: list[PauseReason]
|
||||
|
||||
@property
|
||||
def node_executions(self) -> Mapping[str, NodeExecutionProtocol]:
|
||||
"""Return the persisted node execution state keyed by node id."""
|
||||
...
|
||||
|
||||
def start(self) -> None:
|
||||
"""Transition execution into the running state."""
|
||||
...
|
||||
|
||||
@@ -231,26 +231,6 @@ vdb = [
|
||||
"holo-search-sdk>=0.4.1",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
# targeted ignores for current type-check errors
|
||||
# TODO(QuantumGhost): suppress type errors in HITL related code.
|
||||
# fix the type error later
|
||||
module = [
|
||||
"configs.middleware.cache.redis_pubsub_config",
|
||||
"extensions.ext_redis",
|
||||
"tasks.workflow_execution_tasks",
|
||||
"graphon.nodes.base.node",
|
||||
"services.human_input_delivery_test_service",
|
||||
"core.app.apps.advanced_chat.app_generator",
|
||||
"controllers.console.human_input_form",
|
||||
"controllers.console.app.workflow_run",
|
||||
"repositories.sqlalchemy_api_workflow_node_execution_repository",
|
||||
"extensions.logstore.repositories.logstore_api_workflow_run_repository",
|
||||
]
|
||||
ignore_errors = true
|
||||
|
||||
[tool.pyrefly]
|
||||
project-includes = ["."]
|
||||
project-excludes = [".venv", "migrations/"]
|
||||
|
||||
@@ -220,7 +220,7 @@ class EmailDeliveryTestHandler:
|
||||
stmt = stmt.where(Account.id.in_(unique_ids))
|
||||
|
||||
with self._session_factory() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
rows = session.execute(stmt).tuples().all()
|
||||
return dict(rows)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -9,10 +9,17 @@ from pydantic import BaseModel, ValidationError
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator, _refresh_model
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import (
|
||||
ConversationSnapshot,
|
||||
MessageSnapshot,
|
||||
WorkflowSnapshot,
|
||||
)
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import MessageStatus
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
@@ -363,8 +370,15 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
workflow_run_id="run-id",
|
||||
)
|
||||
|
||||
workflow = SimpleNamespace(id="wf-1", tenant_id="tenant", features={"feature": True}, features_dict={})
|
||||
conversation = SimpleNamespace(id="conv-1", mode=AppMode.ADVANCED_CHAT, override_model_configs=None)
|
||||
message = SimpleNamespace(id="msg-1")
|
||||
message = SimpleNamespace(
|
||||
id="msg-1",
|
||||
query="hello",
|
||||
created_at=naive_utc_now(),
|
||||
status=MessageStatus.NORMAL,
|
||||
answer="",
|
||||
)
|
||||
db_session = SimpleNamespace(commit=MagicMock(), refresh=MagicMock(), close=MagicMock())
|
||||
captured: dict[str, object] = {}
|
||||
thread_data: dict[str, object] = {}
|
||||
@@ -394,19 +408,6 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
thread_data["started"] = True
|
||||
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread)
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model)
|
||||
|
||||
class _Session:
|
||||
def __init__(self, *args, **kwargs):
|
||||
_ = args, kwargs
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session)
|
||||
)
|
||||
@@ -424,7 +425,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
pause_state_config = SimpleNamespace(session_factory="session-factory", state_owner_user_id="owner")
|
||||
|
||||
response = generator._generate(
|
||||
workflow=SimpleNamespace(features={"feature": True}),
|
||||
workflow=workflow,
|
||||
user=SimpleNamespace(id="user"),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
application_generate_entity=application_generate_entity,
|
||||
@@ -444,6 +445,9 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
db_session.refresh.assert_called_once_with(conversation)
|
||||
db_session.close.assert_called_once()
|
||||
assert captured["draft_var_saver_factory"] == "draft-factory"
|
||||
assert isinstance(captured["workflow"], WorkflowSnapshot)
|
||||
assert isinstance(captured["conversation"], ConversationSnapshot)
|
||||
assert isinstance(captured["message"], MessageSnapshot)
|
||||
|
||||
def test_generate_internal_flow_with_existing_records_skips_init(self, monkeypatch):
|
||||
generator = AdvancedChatAppGenerator()
|
||||
@@ -464,8 +468,15 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
workflow_run_id="run-id",
|
||||
)
|
||||
|
||||
workflow = SimpleNamespace(id="wf-2", tenant_id="tenant", features={}, features_dict={})
|
||||
conversation = SimpleNamespace(id="conv-2", mode=AppMode.ADVANCED_CHAT, override_model_configs=None)
|
||||
message = SimpleNamespace(id="msg-2")
|
||||
message = SimpleNamespace(
|
||||
id="msg-2",
|
||||
query="hello",
|
||||
created_at=naive_utc_now(),
|
||||
status=MessageStatus.NORMAL,
|
||||
answer="",
|
||||
)
|
||||
db_session = SimpleNamespace(close=MagicMock(), commit=MagicMock(), refresh=MagicMock())
|
||||
init_records = MagicMock()
|
||||
thread_data: dict[str, object] = {}
|
||||
@@ -491,19 +502,6 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
thread_data["started"] = True
|
||||
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread)
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model)
|
||||
|
||||
class _Session:
|
||||
def __init__(self, *args, **kwargs):
|
||||
_ = args, kwargs
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session)
|
||||
)
|
||||
@@ -519,7 +517,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
)
|
||||
|
||||
response = generator._generate(
|
||||
workflow=SimpleNamespace(features={}),
|
||||
workflow=workflow,
|
||||
user=SimpleNamespace(id="user"),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
application_generate_entity=application_generate_entity,
|
||||
@@ -940,10 +938,16 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
generator._handle_advanced_chat_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=SimpleNamespace(),
|
||||
workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}),
|
||||
queue_manager=SimpleNamespace(),
|
||||
conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT),
|
||||
message=SimpleNamespace(id="msg"),
|
||||
conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT),
|
||||
message=MessageSnapshot(
|
||||
id="msg",
|
||||
query="hello",
|
||||
created_at=naive_utc_now(),
|
||||
status=MessageStatus.NORMAL,
|
||||
answer="",
|
||||
),
|
||||
user=SimpleNamespace(),
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
stream=False,
|
||||
@@ -981,10 +985,16 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
with pytest.raises(ValueError, match="other error"):
|
||||
generator._handle_advanced_chat_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=SimpleNamespace(),
|
||||
workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}),
|
||||
queue_manager=SimpleNamespace(),
|
||||
conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT),
|
||||
message=SimpleNamespace(id="msg"),
|
||||
conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT),
|
||||
message=MessageSnapshot(
|
||||
id="msg",
|
||||
query="hello",
|
||||
created_at=naive_utc_now(),
|
||||
status=MessageStatus.NORMAL,
|
||||
answer="",
|
||||
),
|
||||
user=SimpleNamespace(),
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
stream=False,
|
||||
@@ -992,31 +1002,6 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
|
||||
logger_exception.assert_called_once()
|
||||
|
||||
def test_refresh_model_returns_detached_model(self, monkeypatch):
|
||||
source_model = SimpleNamespace(id="source-id")
|
||||
detached_model = SimpleNamespace(id="source-id", detached=True)
|
||||
|
||||
class _Session:
|
||||
def __init__(self, *args, **kwargs):
|
||||
_ = args, kwargs
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def get(self, model_type, model_id):
|
||||
_ = model_type
|
||||
return detached_model if model_id == "source-id" else None
|
||||
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object()))
|
||||
|
||||
refreshed = _refresh_model(session=SimpleNamespace(), model=source_model)
|
||||
|
||||
assert refreshed is detached_model
|
||||
|
||||
def test_generate_worker_handles_invoke_auth_error(self, monkeypatch):
|
||||
generator = AdvancedChatAppGenerator()
|
||||
generator._dialogue_count = 1
|
||||
|
||||
@@ -21,7 +21,7 @@ from graphon.entities.pause_reason import HumanInputRequired
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from models.enums import MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.model import EndUser
|
||||
from models.model import AppMode, EndUser
|
||||
|
||||
|
||||
def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline:
|
||||
@@ -159,8 +159,8 @@ def test_resume_appends_chunks_to_paused_answer() -> None:
|
||||
task_id="task-1",
|
||||
)
|
||||
queue_manager = SimpleNamespace(graph_runtime_state=None)
|
||||
conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat")
|
||||
message = SimpleNamespace(
|
||||
conversation = pipeline_module.ConversationSnapshot(id="conversation-1", mode=AppMode.ADVANCED_CHAT)
|
||||
message = pipeline_module.MessageSnapshot(
|
||||
id="message-1",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
query="hello",
|
||||
@@ -170,7 +170,7 @@ def test_resume_appends_chunks_to_paused_answer() -> None:
|
||||
user = EndUser()
|
||||
user.id = "user-1"
|
||||
user.session_id = "session-1"
|
||||
workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={})
|
||||
workflow = pipeline_module.WorkflowSnapshot(id="workflow-1", tenant_id="tenant-1", features_dict={})
|
||||
|
||||
pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
@@ -184,14 +184,33 @@ def test_resume_appends_chunks_to_paused_answer() -> None:
|
||||
draft_var_saver_factory=SimpleNamespace(),
|
||||
)
|
||||
|
||||
pipeline._get_message = mock.Mock(return_value=message)
|
||||
stored_message = SimpleNamespace(
|
||||
id="message-1",
|
||||
answer="before",
|
||||
status=MessageStatus.PAUSED,
|
||||
updated_at=None,
|
||||
provider_response_latency=0,
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
total_price=0,
|
||||
currency="USD",
|
||||
message_metadata=None,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
from_account_id=None,
|
||||
from_end_user_id="user-1",
|
||||
)
|
||||
pipeline._get_message = mock.Mock(return_value=stored_message)
|
||||
pipeline._recorded_files = []
|
||||
|
||||
list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after")))
|
||||
pipeline._save_message(session=mock.Mock())
|
||||
|
||||
assert message.answer == "beforeafter"
|
||||
assert message.status == MessageStatus.NORMAL
|
||||
assert stored_message.answer == "beforeafter"
|
||||
assert stored_message.status == MessageStatus.NORMAL
|
||||
|
||||
|
||||
def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None:
|
||||
|
||||
@@ -6,7 +6,12 @@ from types import SimpleNamespace
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import (
|
||||
AdvancedChatAppGenerateTaskPipeline,
|
||||
ConversationSnapshot,
|
||||
MessageSnapshot,
|
||||
WorkflowSnapshot,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
@@ -73,15 +78,15 @@ def _make_pipeline():
|
||||
workflow_run_id="run-id",
|
||||
)
|
||||
|
||||
message = SimpleNamespace(
|
||||
message = MessageSnapshot(
|
||||
id="message-id",
|
||||
query="hello",
|
||||
created_at=naive_utc_now(),
|
||||
status=MessageStatus.NORMAL,
|
||||
answer="",
|
||||
)
|
||||
conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT)
|
||||
workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
|
||||
conversation = ConversationSnapshot(id="conv-id", mode=AppMode.ADVANCED_CHAT)
|
||||
workflow = WorkflowSnapshot(id="workflow-id", tenant_id="tenant", features_dict={})
|
||||
user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session")
|
||||
|
||||
pipeline = AdvancedChatAppGenerateTaskPipeline(
|
||||
|
||||
@@ -302,8 +302,10 @@ class TestEmailDeliveryTestHandler:
|
||||
|
||||
# user_ids is None (all)
|
||||
mock_execute = MagicMock()
|
||||
mock_tuples = MagicMock()
|
||||
mock_session.execute.return_value = mock_execute
|
||||
mock_execute.all.return_value = [("u1", "u1@example.com")]
|
||||
mock_execute.tuples.return_value = mock_tuples
|
||||
mock_tuples.all.return_value = [("u1", "u1@example.com")]
|
||||
|
||||
result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=None)
|
||||
assert result == {"u1": "u1@example.com"}
|
||||
|
||||
Reference in New Issue
Block a user