feat: chatflow support multimodal (#31293)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wangxiaolei
2026-01-27 00:24:48 +08:00
committed by GitHub
parent 5eaf0c733a
commit e48419937b
14 changed files with 1051 additions and 133 deletions

View File

@@ -236,4 +236,7 @@ class AgentChatAppRunner(AppRunner):
queue_manager=queue_manager, queue_manager=queue_manager,
stream=application_generate_entity.stream, stream=application_generate_entity.stream,
agent=True, agent=True,
message_id=message.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id,
) )

View File

@@ -1,6 +1,8 @@
import base64
import logging import logging
import time import time
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from mimetypes import guess_extension
from typing import TYPE_CHECKING, Any, Union from typing import TYPE_CHECKING, Any, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
@@ -11,10 +13,16 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom, InvokeFrom,
ModelConfigWithCredentialsEntity, ModelConfigWithCredentialsEntity,
) )
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent from core.app.entities.queue_entities import (
QueueAgentMessageEvent,
QueueLLMChunkEvent,
QueueMessageEndEvent,
QueueMessageFileEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.enums import FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -22,6 +30,7 @@ from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
TextPromptMessageContent,
) )
from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.invoke import InvokeBadRequestError
@@ -29,7 +38,10 @@ from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
if TYPE_CHECKING: if TYPE_CHECKING:
from core.file.models import File from core.file.models import File
@@ -203,6 +215,9 @@ class AppRunner:
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
stream: bool, stream: bool,
agent: bool = False, agent: bool = False,
message_id: str | None = None,
user_id: str | None = None,
tenant_id: str | None = None,
): ):
""" """
Handle invoke result Handle invoke result
@@ -210,21 +225,41 @@ class AppRunner:
:param queue_manager: application queue manager :param queue_manager: application queue manager
:param stream: stream :param stream: stream
:param agent: agent :param agent: agent
:param message_id: message id for multimodal output
:param user_id: user id for multimodal output
:param tenant_id: tenant id for multimodal output
:return: :return:
""" """
if not stream and isinstance(invoke_result, LLMResult): if not stream and isinstance(invoke_result, LLMResult):
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager,
)
elif stream and isinstance(invoke_result, Generator): elif stream and isinstance(invoke_result, Generator):
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
)
else: else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool): def _handle_invoke_result_direct(
self,
invoke_result: LLMResult,
queue_manager: AppQueueManager,
):
""" """
Handle invoke result direct Handle invoke result direct
:param invoke_result: invoke result :param invoke_result: invoke result
:param queue_manager: application queue manager :param queue_manager: application queue manager
:param agent: agent :param agent: agent
:param message_id: message id for multimodal output
:param user_id: user id for multimodal output
:param tenant_id: tenant id for multimodal output
:return: :return:
""" """
queue_manager.publish( queue_manager.publish(
@@ -235,13 +270,22 @@ class AppRunner:
) )
def _handle_invoke_result_stream( def _handle_invoke_result_stream(
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool self,
invoke_result: Generator[LLMResultChunk, None, None],
queue_manager: AppQueueManager,
agent: bool,
message_id: str | None = None,
user_id: str | None = None,
tenant_id: str | None = None,
): ):
""" """
Handle invoke result Handle invoke result
:param invoke_result: invoke result :param invoke_result: invoke result
:param queue_manager: application queue manager :param queue_manager: application queue manager
:param agent: agent :param agent: agent
:param message_id: message id for multimodal output
:param user_id: user id for multimodal output
:param tenant_id: tenant id for multimodal output
:return: :return:
""" """
model: str = "" model: str = ""
@@ -259,12 +303,26 @@ class AppRunner:
text += message.content text += message.content
elif isinstance(message.content, list): elif isinstance(message.content, list):
for content in message.content: for content in message.content:
if not isinstance(content, str): if isinstance(content, str):
# TODO(QuantumGhost): Add multimodal output support for easy ui. text += content
_logger.warning("received multimodal output, type=%s", type(content)) elif isinstance(content, TextPromptMessageContent):
text += content.data text += content.data
elif isinstance(content, ImagePromptMessageContent):
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
else:
_logger.warning("Received multimodal output but missing required parameters")
else: else:
text += content # failback to str text += content.data if hasattr(content, "data") else str(content)
if not model: if not model:
model = result.model model = result.model
@@ -289,6 +347,101 @@ class AppRunner:
PublishFrom.APPLICATION_MANAGER, PublishFrom.APPLICATION_MANAGER,
) )
def _handle_multimodal_image_content(
self,
content: ImagePromptMessageContent,
message_id: str,
user_id: str,
tenant_id: str,
queue_manager: AppQueueManager,
):
"""
Handle multimodal image content from LLM response.
Save the image and create a MessageFile record.
:param content: ImagePromptMessageContent instance
:param message_id: message id
:param user_id: user id
:param tenant_id: tenant id
:param queue_manager: queue manager
:return:
"""
_logger.info("Handling multimodal image content for message %s", message_id)
image_url = content.url
base64_data = content.base64_data
_logger.info("Image URL: %s, Base64 data present: %s", image_url, base64_data)
if not image_url and not base64_data:
_logger.warning("Image content has neither URL nor base64 data")
return
tool_file_manager = ToolFileManager()
# Save the image file
try:
if image_url:
# Download image from URL
_logger.info("Downloading image from URL: %s", image_url)
tool_file = tool_file_manager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=image_url,
conversation_id=None,
)
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
elif base64_data:
if base64_data.startswith("data:"):
base64_data = base64_data.split(",", 1)[1]
image_binary = base64.b64decode(base64_data)
mimetype = content.mime_type or "image/png"
extension = guess_extension(mimetype) or ".png"
tool_file = tool_file_manager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_binary=image_binary,
mimetype=mimetype,
filename=f"generated_image{extension}",
)
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
else:
return
except Exception:
_logger.exception("Failed to save image file")
return
# Create MessageFile record
message_file = MessageFile(
message_id=message_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
url=f"/files/tools/{tool_file.id}",
upload_file_id=tool_file.id,
created_by_role=(
CreatorUserRole.ACCOUNT
if queue_manager.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}
else CreatorUserRole.END_USER
),
created_by=user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)
# Publish QueueMessageFileEvent
queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file.id),
PublishFrom.APPLICATION_MANAGER,
)
_logger.info("QueueMessageFileEvent published for message_file_id: %s", message_file.id)
def moderation_for_inputs( def moderation_for_inputs(
self, self,
*, *,

View File

@@ -226,5 +226,10 @@ class ChatAppRunner(AppRunner):
# handle invoke result # handle invoke result
self._handle_invoke_result( self._handle_invoke_result(
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream,
message_id=message.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id,
) )

View File

@@ -184,5 +184,10 @@ class CompletionAppRunner(AppRunner):
# handle invoke result # handle invoke result
self._handle_invoke_result( self._handle_invoke_result(
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream,
message_id=message.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id,
) )

View File

@@ -39,6 +39,7 @@ from core.app.entities.task_entities import (
MessageAudioEndStreamResponse, MessageAudioEndStreamResponse,
MessageAudioStreamResponse, MessageAudioStreamResponse,
MessageEndStreamResponse, MessageEndStreamResponse,
StreamEvent,
StreamResponse, StreamResponse,
) )
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@@ -70,6 +71,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
_task_state: EasyUITaskState _task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
_precomputed_event_type: StreamEvent | None = None
def __init__( def __init__(
self, self,
@@ -342,11 +344,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._task_state.llm_result.message.content = current_content self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent): if isinstance(event, QueueLLMChunkEvent):
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id) # Determine the event type once, on first LLM chunk, and reuse for subsequent chunks
if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None:
self._precomputed_event_type = self._message_cycle_manager.get_message_event_type(
message_id=self._message_id
)
yield self._message_cycle_manager.message_to_stream_response( yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text), answer=cast(str, delta_text),
message_id=self._message_id, message_id=self._message_id,
event_type=event_type, event_type=self._precomputed_event_type,
) )
else: else:
yield self._agent_message_to_stream_response( yield self._agent_message_to_stream_response(

View File

@@ -5,7 +5,7 @@ from threading import Thread
from typing import Union from typing import Union
from flask import Flask, current_app from flask import Flask, current_app
from sqlalchemy import exists, select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
@@ -30,6 +30,7 @@ from core.app.entities.task_entities import (
StreamEvent, StreamEvent,
WorkflowTaskState, WorkflowTaskState,
) )
from core.db.session_factory import session_factory
from core.llm_generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from core.tools.signature import sign_tool_file from core.tools.signature import sign_tool_file
from extensions.ext_database import db from extensions.ext_database import db
@@ -57,13 +58,15 @@ class MessageCycleManager:
self._message_has_file: set[str] = set() self._message_has_file: set[str] = set()
def get_message_event_type(self, message_id: str) -> StreamEvent: def get_message_event_type(self, message_id: str) -> StreamEvent:
# Fast path: cached determination from prior QueueMessageFileEvent
if message_id in self._message_has_file: if message_id in self._message_has_file:
return StreamEvent.MESSAGE_FILE return StreamEvent.MESSAGE_FILE
with Session(db.engine, expire_on_commit=False) as session: # Use SQLAlchemy 2.x style session.scalar(select(...))
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar() with session_factory.create_session() as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))
if has_file: if message_file:
self._message_has_file.add(message_id) self._message_has_file.add(message_id)
return StreamEvent.MESSAGE_FILE return StreamEvent.MESSAGE_FILE
@@ -199,6 +202,8 @@ class MessageCycleManager:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id)) message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None: if message_file and message_file.url is not None:
self._message_has_file.add(message_file.message_id)
# get tool file id # get tool file id
tool_file_id = message_file.url.split("/")[-1] tool_file_id = message_file.url.split("/")[-1]
# trim extension # trim extension

View File

@@ -0,0 +1,454 @@
"""Test multimodal image output handling in BaseAppRunner."""
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueMessageFileEvent
from core.file.enums import FileTransferMethod, FileType
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from models.enums import CreatorUserRole
class TestBaseAppRunnerMultimodal:
"""Test that BaseAppRunner correctly handles multimodal image content."""
@pytest.fixture
def mock_user_id(self):
"""Mock user ID."""
return str(uuid4())
@pytest.fixture
def mock_tenant_id(self):
"""Mock tenant ID."""
return str(uuid4())
@pytest.fixture
def mock_message_id(self):
"""Mock message ID."""
return str(uuid4())
@pytest.fixture
def mock_queue_manager(self):
"""Create a mock queue manager."""
manager = MagicMock()
manager.invoke_from = InvokeFrom.SERVICE_API
return manager
@pytest.fixture
def mock_tool_file(self):
"""Create a mock tool file."""
tool_file = MagicMock()
tool_file.id = str(uuid4())
return tool_file
@pytest.fixture
def mock_message_file(self):
"""Create a mock message file."""
message_file = MagicMock()
message_file.id = str(uuid4())
return message_file
def test_handle_multimodal_image_content_with_url(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test handling image from URL."""
# Arrange
image_url = "http://example.com/image.png"
content = ImagePromptMessageContent(
url=image_url,
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert
# Verify tool file was created from URL
mock_mgr.create_file_by_url.assert_called_once_with(
user_id=mock_user_id,
tenant_id=mock_tenant_id,
file_url=image_url,
conversation_id=None,
)
# Verify message file was created with correct parameters
mock_msg_file_class.assert_called_once()
call_kwargs = mock_msg_file_class.call_args[1]
assert call_kwargs["message_id"] == mock_message_id
assert call_kwargs["type"] == FileType.IMAGE
assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE
assert call_kwargs["belongs_to"] == "assistant"
assert call_kwargs["created_by"] == mock_user_id
# Verify database operations
mock_session.add.assert_called_once_with(mock_message_file)
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once_with(mock_message_file)
# Verify event was published
mock_queue_manager.publish.assert_called_once()
publish_call = mock_queue_manager.publish.call_args
assert isinstance(publish_call[0][0], QueueMessageFileEvent)
assert publish_call[0][0].message_file_id == mock_message_file.id
# publish_from might be passed as positional or keyword argument
assert (
publish_call[0][1] == PublishFrom.APPLICATION_MANAGER
or publish_call.kwargs.get("publish_from") == PublishFrom.APPLICATION_MANAGER
)
def test_handle_multimodal_image_content_with_base64(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test handling image from base64 data."""
# Arrange
import base64
# Create a small test image (1x1 PNG)
test_image_data = base64.b64encode(
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde"
).decode()
content = ImagePromptMessageContent(
base64_data=test_image_data,
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_raw.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert
# Verify tool file was created from base64
mock_mgr.create_file_by_raw.assert_called_once()
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
assert call_kwargs["user_id"] == mock_user_id
assert call_kwargs["tenant_id"] == mock_tenant_id
assert call_kwargs["conversation_id"] is None
assert "file_binary" in call_kwargs
assert call_kwargs["mimetype"] == "image/png"
assert call_kwargs["filename"].startswith("generated_image")
assert call_kwargs["filename"].endswith(".png")
# Verify message file was created
mock_msg_file_class.assert_called_once()
# Verify database operations
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once()
# Verify event was published
mock_queue_manager.publish.assert_called_once()
def test_handle_multimodal_image_content_with_base64_data_uri(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test handling image from base64 data with URI prefix."""
# Arrange
# Data URI format: data:image/png;base64,<base64_data>
test_image_data = (
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
)
content = ImagePromptMessageContent(
base64_data=f"data:image/png;base64,{test_image_data}",
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_raw.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - verify that base64 data was extracted correctly (without prefix)
mock_mgr.create_file_by_raw.assert_called_once()
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
# The base64 data should be decoded, so we check the binary was passed
assert "file_binary" in call_kwargs
def test_handle_multimodal_image_content_without_url_or_base64(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
):
"""Test handling image content without URL or base64 data."""
# Arrange
content = ImagePromptMessageContent(
url="",
base64_data="",
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - should not create any files or publish events
mock_mgr_class.assert_not_called()
mock_msg_file_class.assert_not_called()
mock_session.add.assert_not_called()
mock_queue_manager.publish.assert_not_called()
def test_handle_multimodal_image_content_with_error(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
):
"""Test handling image content when an error occurs."""
# Arrange
image_url = "http://example.com/image.png"
content = ImagePromptMessageContent(
url=image_url,
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock to raise exception
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.side_effect = Exception("Network error")
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
# Should not raise exception, just log it
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - should not create message file or publish event on error
mock_msg_file_class.assert_not_called()
mock_session.add.assert_not_called()
mock_queue_manager.publish.assert_not_called()
def test_handle_multimodal_image_content_debugger_mode(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test that debugger mode sets correct created_by_role."""
# Arrange
image_url = "http://example.com/image.png"
content = ImagePromptMessageContent(
url=image_url,
format="png",
mime_type="image/png",
)
mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - verify created_by_role is ACCOUNT for debugger mode
call_kwargs = mock_msg_file_class.call_args[1]
assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT
def test_handle_multimodal_image_content_service_api_mode(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test that service API mode sets correct created_by_role."""
# Arrange
image_url = "http://example.com/image.png"
content = ImagePromptMessageContent(
url=image_url,
format="png",
mime_type="image/png",
)
mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - verify created_by_role is END_USER for service API
call_kwargs = mock_msg_file_class.call_args[1]
assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER

View File

@@ -1,7 +1,6 @@
"""Unit tests for the message cycle manager optimization.""" """Unit tests for the message cycle manager optimization."""
from types import SimpleNamespace from unittest.mock import Mock, patch
from unittest.mock import ANY, Mock, patch
import pytest import pytest
from flask import current_app from flask import current_app
@@ -28,17 +27,14 @@ class TestMessageCycleManagerOptimization:
def test_get_message_event_type_with_message_file(self, message_cycle_manager): def test_get_message_event_type_with_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE_FILE when message has files.""" """Test get_message_event_type returns MESSAGE_FILE when message has files."""
with ( with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
# Setup mock session and message file # Setup mock session and message file
mock_session = Mock() mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_message_file = Mock() mock_message_file = Mock()
# Current implementation uses session.query(...).scalar() # Current implementation uses session.scalar(select(...))
mock_session.query.return_value.scalar.return_value = mock_message_file mock_session.scalar.return_value = mock_message_file
# Execute # Execute
with current_app.app_context(): with current_app.app_context():
@@ -46,19 +42,16 @@ class TestMessageCycleManagerOptimization:
# Assert # Assert
assert result == StreamEvent.MESSAGE_FILE assert result == StreamEvent.MESSAGE_FILE
mock_session.query.return_value.scalar.assert_called_once() mock_session.scalar.assert_called_once()
def test_get_message_event_type_without_message_file(self, message_cycle_manager): def test_get_message_event_type_without_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE when message has no files.""" """Test get_message_event_type returns MESSAGE when message has no files."""
with ( with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
# Setup mock session and no message file # Setup mock session and no message file
mock_session = Mock() mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
# Current implementation uses session.query(...).scalar() # Current implementation uses session.scalar(select(...))
mock_session.query.return_value.scalar.return_value = None mock_session.scalar.return_value = None
# Execute # Execute
with current_app.app_context(): with current_app.app_context():
@@ -66,21 +59,18 @@ class TestMessageCycleManagerOptimization:
# Assert # Assert
assert result == StreamEvent.MESSAGE assert result == StreamEvent.MESSAGE
mock_session.query.return_value.scalar.assert_called_once() mock_session.scalar.assert_called_once()
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
with ( with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
# Setup mock session and message file # Setup mock session and message file
mock_session = Mock() mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_message_file = Mock() mock_message_file = Mock()
# Current implementation uses session.query(...).scalar() # Current implementation uses session.scalar(select(...))
mock_session.query.return_value.scalar.return_value = mock_message_file mock_session.scalar.return_value = mock_message_file
# Execute: compute event type once, then pass to message_to_stream_response # Execute: compute event type once, then pass to message_to_stream_response
with current_app.app_context(): with current_app.app_context():
@@ -94,11 +84,11 @@ class TestMessageCycleManagerOptimization:
assert result.answer == "Hello world" assert result.answer == "Hello world"
assert result.id == "test-message-id" assert result.id == "test-message-id"
assert result.event == StreamEvent.MESSAGE_FILE assert result.event == StreamEvent.MESSAGE_FILE
mock_session.query.return_value.scalar.assert_called_once() mock_session.scalar.assert_called_once()
def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager): def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
"""Test that message_to_stream_response skips database query when event_type is provided.""" """Test that message_to_stream_response skips database query when event_type is provided."""
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
# Execute with event_type provided # Execute with event_type provided
result = message_cycle_manager.message_to_stream_response( result = message_cycle_manager.message_to_stream_response(
answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
@@ -109,8 +99,8 @@ class TestMessageCycleManagerOptimization:
assert result.answer == "Hello world" assert result.answer == "Hello world"
assert result.id == "test-message-id" assert result.id == "test-message-id"
assert result.event == StreamEvent.MESSAGE assert result.event == StreamEvent.MESSAGE
# Should not query database when event_type is provided # Should not open a session when event_type is provided
mock_session_class.assert_not_called() mock_session_factory.create_session.assert_not_called()
def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager): def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
"""Test message_to_stream_response with from_variable_selector parameter.""" """Test message_to_stream_response with from_variable_selector parameter."""
@@ -130,24 +120,21 @@ class TestMessageCycleManagerOptimization:
def test_optimization_usage_example(self, message_cycle_manager): def test_optimization_usage_example(self, message_cycle_manager):
"""Test the optimization pattern that should be used by callers.""" """Test the optimization pattern that should be used by callers."""
# Step 1: Get event type once (this queries database) # Step 1: Get event type once (this queries database)
with ( with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
mock_session = Mock() mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
# Current implementation uses session.query(...).scalar() # Current implementation uses session.scalar(select(...))
mock_session.query.return_value.scalar.return_value = None # No files mock_session.scalar.return_value = None # No files
with current_app.app_context(): with current_app.app_context():
event_type = message_cycle_manager.get_message_event_type("test-message-id") event_type = message_cycle_manager.get_message_event_type("test-message-id")
# Should query database once # Should open session once
mock_session_class.assert_called_once_with(ANY, expire_on_commit=False) mock_session_factory.create_session.assert_called_once()
assert event_type == StreamEvent.MESSAGE assert event_type == StreamEvent.MESSAGE
# Step 2: Use event_type for multiple calls (no additional queries) # Step 2: Use event_type for multiple calls (no additional queries)
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
mock_session_class.return_value.__enter__.return_value = Mock() mock_session_factory.create_session.return_value.__enter__.return_value = Mock()
chunk1_response = message_cycle_manager.message_to_stream_response( chunk1_response = message_cycle_manager.message_to_stream_response(
answer="Chunk 1", message_id="test-message-id", event_type=event_type answer="Chunk 1", message_id="test-message-id", event_type=event_type
@@ -157,8 +144,8 @@ class TestMessageCycleManagerOptimization:
answer="Chunk 2", message_id="test-message-id", event_type=event_type answer="Chunk 2", message_id="test-message-id", event_type=event_type
) )
# Should not query database again # Should not open session again when event_type provided
mock_session_class.assert_not_called() mock_session_factory.create_session.assert_not_called()
assert chunk1_response.event == StreamEvent.MESSAGE assert chunk1_response.event == StreamEvent.MESSAGE
assert chunk2_response.event == StreamEvent.MESSAGE assert chunk2_response.event == StreamEvent.MESSAGE

View File

@@ -0,0 +1,178 @@
/**
* Tests for multimodal image file handling in chat hooks.
* Tests the file object conversion logic without full hook integration.
*/
describe('Multimodal File Handling', () => {
describe('File type to MIME type mapping', () => {
it('should map image to image/png', () => {
const fileType: string = 'image'
const expectedMime = 'image/png'
const mimeType = fileType === 'image' ? 'image/png' : 'application/octet-stream'
expect(mimeType).toBe(expectedMime)
})
it('should map video to video/mp4', () => {
const fileType: string = 'video'
const expectedMime = 'video/mp4'
const mimeType = fileType === 'video' ? 'video/mp4' : 'application/octet-stream'
expect(mimeType).toBe(expectedMime)
})
it('should map audio to audio/mpeg', () => {
const fileType: string = 'audio'
const expectedMime = 'audio/mpeg'
const mimeType = fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'
expect(mimeType).toBe(expectedMime)
})
it('should map unknown to application/octet-stream', () => {
const fileType: string = 'unknown'
const expectedMime = 'application/octet-stream'
const mimeType = ['image', 'video', 'audio'].includes(fileType) ? 'image/png' : 'application/octet-stream'
expect(mimeType).toBe(expectedMime)
})
})
describe('TransferMethod selection', () => {
it('should select remote_url for images', () => {
const fileType: string = 'image'
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
expect(transferMethod).toBe('remote_url')
})
it('should select local_file for non-images', () => {
const fileType: string = 'video'
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
expect(transferMethod).toBe('local_file')
})
})
describe('File extension mapping', () => {
it('should use .png extension for images', () => {
const fileType: string = 'image'
const expectedExtension = '.png'
const extension = fileType === 'image' ? 'png' : 'bin'
expect(extension).toBe(expectedExtension.replace('.', ''))
})
it('should use .mp4 extension for videos', () => {
const fileType: string = 'video'
const expectedExtension = '.mp4'
const extension = fileType === 'video' ? 'mp4' : 'bin'
expect(extension).toBe(expectedExtension.replace('.', ''))
})
it('should use .mp3 extension for audio', () => {
const fileType: string = 'audio'
const expectedExtension = '.mp3'
const extension = fileType === 'audio' ? 'mp3' : 'bin'
expect(extension).toBe(expectedExtension.replace('.', ''))
})
})
describe('File name generation', () => {
it('should generate correct file name for images', () => {
const fileType: string = 'image'
const expectedName = 'generated_image.png'
const fileName = `generated_${fileType}.${fileType === 'image' ? 'png' : 'bin'}`
expect(fileName).toBe(expectedName)
})
it('should generate correct file name for videos', () => {
const fileType: string = 'video'
const expectedName = 'generated_video.mp4'
const fileName = `generated_${fileType}.${fileType === 'video' ? 'mp4' : 'bin'}`
expect(fileName).toBe(expectedName)
})
it('should generate correct file name for audio', () => {
const fileType: string = 'audio'
const expectedName = 'generated_audio.mp3'
const fileName = `generated_${fileType}.${fileType === 'audio' ? 'mp3' : 'bin'}`
expect(fileName).toBe(expectedName)
})
})
describe('SupportFileType mapping', () => {
it('should map image type to image supportFileType', () => {
const fileType: string = 'image'
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
expect(supportFileType).toBe('image')
})
it('should map video type to video supportFileType', () => {
const fileType: string = 'video'
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
expect(supportFileType).toBe('video')
})
it('should map audio type to audio supportFileType', () => {
const fileType: string = 'audio'
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
expect(supportFileType).toBe('audio')
})
it('should map unknown type to document supportFileType', () => {
const fileType: string = 'unknown'
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
expect(supportFileType).toBe('document')
})
})
describe('File conversion logic', () => {
it('should detect existing transferMethod', () => {
const fileWithTransferMethod = {
id: 'file-123',
transferMethod: 'remote_url' as const,
type: 'image/png',
name: 'test.png',
size: 1024,
supportFileType: 'image',
progress: 100,
}
const hasTransferMethod = 'transferMethod' in fileWithTransferMethod
expect(hasTransferMethod).toBe(true)
})
it('should detect missing transferMethod', () => {
const fileWithoutTransferMethod = {
id: 'file-456',
type: 'image',
url: 'http://example.com/image.png',
belongs_to: 'assistant',
}
const hasTransferMethod = 'transferMethod' in fileWithoutTransferMethod
expect(hasTransferMethod).toBe(false)
})
it('should create file with size 0 for generated files', () => {
const expectedSize = 0
expect(expectedSize).toBe(0)
})
})
describe('Agent vs Non-Agent mode logic', () => {
it('should check for agent_thoughts to determine mode', () => {
const agentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
agent_thoughts: [{}],
}
const isAgentMode = agentResponse.agent_thoughts && agentResponse.agent_thoughts.length > 0
expect(isAgentMode).toBe(true)
})
it('should detect non-agent mode when agent_thoughts is empty', () => {
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
agent_thoughts: [],
}
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
expect(isAgentMode).toBe(false)
})
it('should detect non-agent mode when agent_thoughts is undefined', () => {
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {}
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
expect(isAgentMode).toBeFalsy()
})
})
})

View File

@@ -419,9 +419,40 @@ export const useChat = (
} }
}, },
onFile(file) { onFile(file) {
// Convert simple file type to MIME type for non-agent mode
// Backend sends: { id, type: "image", belongs_to, url }
// Frontend expects: { id, type: "image/png", transferMethod, url, uploadedId, supportFileType, name, size }
// Determine file type for MIME conversion
const fileType = (file as { type?: string }).type || 'image'
// If file already has transferMethod, use it as base and ensure all required fields exist
// Otherwise, create a new complete file object
const baseFile = ('transferMethod' in file) ? (file as Partial<FileEntity>) : null
const convertedFile: FileEntity = {
id: baseFile?.id || (file as { id: string }).id,
type: baseFile?.type || (fileType === 'image' ? 'image/png' : fileType === 'video' ? 'video/mp4' : fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'),
transferMethod: (baseFile?.transferMethod as FileEntity['transferMethod']) || (fileType === 'image' ? 'remote_url' : 'local_file'),
uploadedId: baseFile?.uploadedId || (file as { id: string }).id,
supportFileType: baseFile?.supportFileType || (fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'),
progress: baseFile?.progress ?? 100,
name: baseFile?.name || `generated_${fileType}.${fileType === 'image' ? 'png' : fileType === 'video' ? 'mp4' : fileType === 'audio' ? 'mp3' : 'bin'}`,
url: baseFile?.url || (file as { url?: string }).url,
size: baseFile?.size ?? 0, // Generated files don't have a known size
}
// For agent mode, add files to the last thought
const lastThought = responseItem.agent_thoughts?.[responseItem.agent_thoughts?.length - 1] const lastThought = responseItem.agent_thoughts?.[responseItem.agent_thoughts?.length - 1]
if (lastThought) if (lastThought) {
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(lastThought as any).message_files, file] const thought = lastThought as { message_files?: FileEntity[] }
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(thought.message_files ?? []), convertedFile]
}
// For non-agent mode, add files directly to responseItem.message_files
else {
const currentFiles = (responseItem.message_files as FileEntity[] | undefined) ?? []
responseItem.message_files = [...currentFiles, convertedFile]
}
updateCurrentQAOnTree({ updateCurrentQAOnTree({
placeholderQuestionId, placeholderQuestionId,

View File

@@ -2039,8 +2039,13 @@ describe('Integration: Hit Testing Flow', () => {
renderWithProviders(<HitTestingPage datasetId="dataset-1" />) renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Type query // Type query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } }) fireEvent.change(textarea, { target: { value: 'Test query' } })
// Find submit button by class // Find submit button by class
@@ -2054,8 +2059,13 @@ describe('Integration: Hit Testing Flow', () => {
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />) const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Type query // Type query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } }) fireEvent.change(textarea, { target: { value: 'Test query' } })
// Component should still be functional - check for the main container // Component should still be functional - check for the main container
@@ -2089,10 +2099,15 @@ describe('Integration: Hit Testing Flow', () => {
isLoading: false, isLoading: false,
} as unknown as ReturnType<typeof useDatasetTestingRecords>) } as unknown as ReturnType<typeof useDatasetTestingRecords>)
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />) const { container: _container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for textbox to be rendered with timeout for CI environment
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Type query // Type query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } }) fireEvent.change(textarea, { target: { value: 'Test query' } })
// Submit // Submit
@@ -2101,8 +2116,13 @@ describe('Integration: Hit Testing Flow', () => {
if (submitButton) if (submitButton)
fireEvent.click(submitButton) fireEvent.click(submitButton)
// Verify the component is still rendered after submission // Wait for the mutation to complete
expect(container.firstChild).toBeInTheDocument() await waitFor(
() => {
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
},
{ timeout: 3000 },
)
}) })
it('should render ResultItem components for non-external results', async () => { it('should render ResultItem components for non-external results', async () => {
@@ -2127,10 +2147,15 @@ describe('Integration: Hit Testing Flow', () => {
isLoading: false, isLoading: false,
} as unknown as ReturnType<typeof useDatasetTestingRecords>) } as unknown as ReturnType<typeof useDatasetTestingRecords>)
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />) const { container: _container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for component to be fully rendered with longer timeout
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Submit a query // Submit a query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } }) fireEvent.change(textarea, { target: { value: 'Test query' } })
const buttons = screen.getAllByRole('button') const buttons = screen.getAllByRole('button')
@@ -2138,8 +2163,13 @@ describe('Integration: Hit Testing Flow', () => {
if (submitButton) if (submitButton)
fireEvent.click(submitButton) fireEvent.click(submitButton)
// Verify component is rendered after submission // Wait for mutation to complete with longer timeout
expect(container.firstChild).toBeInTheDocument() await waitFor(
() => {
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
},
{ timeout: 3000 },
)
}) })
it('should render external results when dataset is external', async () => { it('should render external results when dataset is external', async () => {
@@ -2165,8 +2195,14 @@ describe('Integration: Hit Testing Flow', () => {
// Component should render // Component should render
expect(container.firstChild).toBeInTheDocument() expect(container.firstChild).toBeInTheDocument()
// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Type in textarea to verify component is functional // Type in textarea to verify component is functional
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } }) fireEvent.change(textarea, { target: { value: 'Test query' } })
const buttons = screen.getAllByRole('button') const buttons = screen.getAllByRole('button')
@@ -2174,9 +2210,13 @@ describe('Integration: Hit Testing Flow', () => {
if (submitButton) if (submitButton)
fireEvent.click(submitButton) fireEvent.click(submitButton)
await waitFor(() => { // Verify component is still functional after submission
expect(screen.getByRole('textbox')).toBeInTheDocument() await waitFor(
}) () => {
expect(screen.getByRole('textbox')).toBeInTheDocument()
},
{ timeout: 3000 },
)
}) })
}) })
@@ -2260,8 +2300,13 @@ describe('renderHitResults Coverage', () => {
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />) const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Enter query // Enter query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'test query' } }) fireEvent.change(textarea, { target: { value: 'test query' } })
// Submit // Submit
@@ -2386,8 +2431,13 @@ describe('HitTestingPage Internal Functions Coverage', () => {
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />) const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Enter query and submit // Enter query and submit
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'test query' } }) fireEvent.change(textarea, { target: { value: 'test query' } })
const buttons = screen.getAllByRole('button') const buttons = screen.getAllByRole('button')
@@ -2400,7 +2450,7 @@ describe('HitTestingPage Internal Functions Coverage', () => {
// Wait for state updates // Wait for state updates
await waitFor(() => { await waitFor(() => {
expect(container.firstChild).toBeInTheDocument() expect(container.firstChild).toBeInTheDocument()
}, { timeout: 2000 }) }, { timeout: 3000 })
// Verify mutation was called // Verify mutation was called
expect(mockHitTestingMutateAsync).toHaveBeenCalled() expect(mockHitTestingMutateAsync).toHaveBeenCalled()
@@ -2445,8 +2495,13 @@ describe('HitTestingPage Internal Functions Coverage', () => {
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />) const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Submit a query // Submit a query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'test' } }) fireEvent.change(textarea, { target: { value: 'test' } })
const buttons = screen.getAllByRole('button') const buttons = screen.getAllByRole('button')
@@ -2458,7 +2513,7 @@ describe('HitTestingPage Internal Functions Coverage', () => {
// Verify the component renders // Verify the component renders
await waitFor(() => { await waitFor(() => {
expect(container.firstChild).toBeInTheDocument() expect(container.firstChild).toBeInTheDocument()
}) }, { timeout: 3000 })
}) })
}) })

View File

@@ -162,6 +162,44 @@ vi.mock('@/utils/var', () => ({
getMarketplaceUrl: (path: string, _params?: Record<string, string | undefined>) => `https://marketplace.dify.ai${path}`, getMarketplaceUrl: (path: string, _params?: Record<string, string | undefined>) => `https://marketplace.dify.ai${path}`,
})) }))
// Mock marketplace client used by marketplace utils
vi.mock('@/service/client', () => ({
marketplaceClient: {
collections: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
data: {
collections: [
{
name: 'collection-1',
label: { 'en-US': 'Collection 1' },
description: { 'en-US': 'Desc' },
rule: '',
created_at: '2024-01-01',
updated_at: '2024-01-01',
searchable: true,
search_params: { query: '', sort_by: 'install_count', sort_order: 'DESC' },
},
],
},
})),
collectionPlugins: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
data: {
plugins: [
{ type: 'plugin', org: 'test', name: 'plugin1', tags: [] },
],
},
})),
// Some utils paths may call searchAdvanced; provide a minimal stub
searchAdvanced: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
data: {
plugins: [
{ type: 'plugin', org: 'test', name: 'plugin1', tags: [] },
],
total: 1,
},
})),
},
}))
// Mock context/query-client // Mock context/query-client
vi.mock('@/context/query-client', () => ({ vi.mock('@/context/query-client', () => ({
TanstackQueryInitializer: ({ children }: { children: React.ReactNode }) => <div data-testid="query-initializer">{children}</div>, TanstackQueryInitializer: ({ children }: { children: React.ReactNode }) => <div data-testid="query-initializer">{children}</div>,
@@ -1474,7 +1512,24 @@ describe('flatMap Coverage', () => {
// ================================ // ================================
// Async Utils Tests // Async Utils Tests
// ================================ // ================================
// Narrow mock surface and avoid any in tests
// Types are local to this spec to keep scope minimal
type FnMock = ReturnType<typeof vi.fn>
type MarketplaceClientMock = {
collectionPlugins: FnMock
collections: FnMock
}
describe('Async Utils', () => { describe('Async Utils', () => {
let marketplaceClientMock: MarketplaceClientMock
beforeAll(async () => {
const mod = await import('@/service/client')
marketplaceClientMock = mod.marketplaceClient as unknown as MarketplaceClientMock
})
beforeEach(() => { beforeEach(() => {
vi.clearAllMocks() vi.clearAllMocks()
}) })
@@ -1490,12 +1545,10 @@ describe('Async Utils', () => {
{ type: 'plugin', org: 'test', name: 'plugin2' }, { type: 'plugin', org: 'test', name: 'plugin2' },
] ]
globalThis.fetch = vi.fn().mockResolvedValue( // Adjusted to our mocked marketplaceClient instead of fetch
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({
status: 200, data: { plugins: mockPlugins },
headers: { 'Content-Type': 'application/json' }, })
}),
)
const { getMarketplacePluginsByCollectionId } = await import('./utils') const { getMarketplacePluginsByCollectionId } = await import('./utils')
const result = await getMarketplacePluginsByCollectionId('test-collection', { const result = await getMarketplacePluginsByCollectionId('test-collection', {
@@ -1504,12 +1557,13 @@ describe('Async Utils', () => {
type: 'plugin', type: 'plugin',
}) })
expect(globalThis.fetch).toHaveBeenCalled() expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled()
expect(result).toHaveLength(2) expect(result).toHaveLength(2)
}) })
it('should handle fetch error and return empty array', async () => { it('should handle fetch error and return empty array', async () => {
globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error')) // Simulate error from client
marketplaceClientMock.collectionPlugins.mockRejectedValueOnce(new Error('Network error'))
const { getMarketplacePluginsByCollectionId } = await import('./utils') const { getMarketplacePluginsByCollectionId } = await import('./utils')
const result = await getMarketplacePluginsByCollectionId('test-collection') const result = await getMarketplacePluginsByCollectionId('test-collection')
@@ -1519,25 +1573,18 @@ describe('Async Utils', () => {
it('should pass abort signal when provided', async () => { it('should pass abort signal when provided', async () => {
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }] const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
globalThis.fetch = vi.fn().mockResolvedValue( // Our client mock receives the signal as second arg
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({
status: 200, data: { plugins: mockPlugins },
headers: { 'Content-Type': 'application/json' }, })
}),
)
const controller = new AbortController() const controller = new AbortController()
const { getMarketplacePluginsByCollectionId } = await import('./utils') const { getMarketplacePluginsByCollectionId } = await import('./utils')
await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal }) await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal })
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled()
expect(globalThis.fetch).toHaveBeenCalledWith( const call = marketplaceClientMock.collectionPlugins.mock.calls[0]
expect.any(Request), expect(call[1]).toMatchObject({ signal: controller.signal })
expect.any(Object),
)
const call = vi.mocked(globalThis.fetch).mock.calls[0]
const request = call[0] as Request
expect(request.url).toContain('test-collection')
}) })
}) })
@@ -1548,23 +1595,17 @@ describe('Async Utils', () => {
] ]
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }] const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
let callCount = 0 // Simulate two-step client calls: collections then collectionPlugins
globalThis.fetch = vi.fn().mockImplementation(() => { let stage = 0
callCount++ marketplaceClientMock.collections.mockImplementationOnce(async () => {
if (callCount === 1) { stage = 1
return Promise.resolve( return { data: { collections: mockCollections } }
new Response(JSON.stringify({ data: { collections: mockCollections } }), { })
status: 200, marketplaceClientMock.collectionPlugins.mockImplementation(async () => {
headers: { 'Content-Type': 'application/json' }, if (stage === 1) {
}), return { data: { plugins: mockPlugins } }
)
} }
return Promise.resolve( return { data: { plugins: [] } }
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
}) })
const { getMarketplaceCollectionsAndPlugins } = await import('./utils') const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
@@ -1578,7 +1619,8 @@ describe('Async Utils', () => {
}) })
it('should handle fetch error and return empty data', async () => { it('should handle fetch error and return empty data', async () => {
globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error')) // Simulate client error
marketplaceClientMock.collections.mockRejectedValueOnce(new Error('Network error'))
const { getMarketplaceCollectionsAndPlugins } = await import('./utils') const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
const result = await getMarketplaceCollectionsAndPlugins() const result = await getMarketplaceCollectionsAndPlugins()
@@ -1588,24 +1630,16 @@ describe('Async Utils', () => {
}) })
it('should append condition and type to URL when provided', async () => { it('should append condition and type to URL when provided', async () => {
globalThis.fetch = vi.fn().mockResolvedValue( // Assert that the client was called with query containing condition/type
new Response(JSON.stringify({ data: { collections: [] } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
const { getMarketplaceCollectionsAndPlugins } = await import('./utils') const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
await getMarketplaceCollectionsAndPlugins({ await getMarketplaceCollectionsAndPlugins({
condition: 'category=tool', condition: 'category=tool',
type: 'bundle', type: 'bundle',
}) })
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL expect(marketplaceClientMock.collections).toHaveBeenCalled()
expect(globalThis.fetch).toHaveBeenCalled() const call = marketplaceClientMock.collections.mock.calls[0]
const call = vi.mocked(globalThis.fetch).mock.calls[0] expect(call[0]).toMatchObject({ query: expect.objectContaining({ condition: 'category=tool', type: 'bundle' }) })
const request = call[0] as Request
expect(request.url).toContain('condition=category%3Dtool')
}) })
}) })
}) })

View File

@@ -822,7 +822,7 @@
"count": 2 "count": 2
}, },
"ts/no-explicit-any": { "ts/no-explicit-any": {
"count": 15 "count": 14
} }
}, },
"app/components/base/chat/chat/index.tsx": { "app/components/base/chat/chat/index.tsx": {

View File

@@ -152,6 +152,8 @@ export const formatNumberAbbreviated = (num: number) => {
: `${formatted}${units[unitIndex].symbol}` : `${formatted}${units[unitIndex].symbol}`
} }
} }
// Fallback: if no threshold matched, return the number string
return num.toString()
} }
export const formatToLocalTime = (time: Dayjs, local: Locale, format: string) => { export const formatToLocalTime = (time: Dayjs, local: Locale, format: string) => {