mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 19:59:21 +08:00
feat: chatflow support multimodal (#31293)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -236,4 +236,7 @@ class AgentChatAppRunner(AppRunner):
|
|||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
stream=application_generate_entity.stream,
|
stream=application_generate_entity.stream,
|
||||||
agent=True,
|
agent=True,
|
||||||
|
message_id=message.id,
|
||||||
|
user_id=application_generate_entity.user_id,
|
||||||
|
tenant_id=app_config.tenant_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
|
import base64
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
|
from mimetypes import guess_extension
|
||||||
from typing import TYPE_CHECKING, Any, Union
|
from typing import TYPE_CHECKING, Any, Union
|
||||||
|
|
||||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||||
@@ -11,10 +13,16 @@ from core.app.entities.app_invoke_entities import (
|
|||||||
InvokeFrom,
|
InvokeFrom,
|
||||||
ModelConfigWithCredentialsEntity,
|
ModelConfigWithCredentialsEntity,
|
||||||
)
|
)
|
||||||
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
|
from core.app.entities.queue_entities import (
|
||||||
|
QueueAgentMessageEvent,
|
||||||
|
QueueLLMChunkEvent,
|
||||||
|
QueueMessageEndEvent,
|
||||||
|
QueueMessageFileEvent,
|
||||||
|
)
|
||||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||||
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
||||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||||
|
from core.file.enums import FileTransferMethod, FileType
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
@@ -22,6 +30,7 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
TextPromptMessageContent,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||||
@@ -29,7 +38,10 @@ from core.moderation.input_moderation import InputModeration
|
|||||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||||
from models.model import App, AppMode, Message, MessageAnnotation
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.enums import CreatorUserRole
|
||||||
|
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
@@ -203,6 +215,9 @@ class AppRunner:
|
|||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
agent: bool = False,
|
agent: bool = False,
|
||||||
|
message_id: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
tenant_id: str | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Handle invoke result
|
Handle invoke result
|
||||||
@@ -210,21 +225,41 @@ class AppRunner:
|
|||||||
:param queue_manager: application queue manager
|
:param queue_manager: application queue manager
|
||||||
:param stream: stream
|
:param stream: stream
|
||||||
:param agent: agent
|
:param agent: agent
|
||||||
|
:param message_id: message id for multimodal output
|
||||||
|
:param user_id: user id for multimodal output
|
||||||
|
:param tenant_id: tenant id for multimodal output
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if not stream and isinstance(invoke_result, LLMResult):
|
if not stream and isinstance(invoke_result, LLMResult):
|
||||||
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
|
self._handle_invoke_result_direct(
|
||||||
|
invoke_result=invoke_result,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
)
|
||||||
elif stream and isinstance(invoke_result, Generator):
|
elif stream and isinstance(invoke_result, Generator):
|
||||||
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
|
self._handle_invoke_result_stream(
|
||||||
|
invoke_result=invoke_result,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
agent=agent,
|
||||||
|
message_id=message_id,
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
|
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
|
||||||
|
|
||||||
def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
|
def _handle_invoke_result_direct(
|
||||||
|
self,
|
||||||
|
invoke_result: LLMResult,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Handle invoke result direct
|
Handle invoke result direct
|
||||||
:param invoke_result: invoke result
|
:param invoke_result: invoke result
|
||||||
:param queue_manager: application queue manager
|
:param queue_manager: application queue manager
|
||||||
:param agent: agent
|
:param agent: agent
|
||||||
|
:param message_id: message id for multimodal output
|
||||||
|
:param user_id: user id for multimodal output
|
||||||
|
:param tenant_id: tenant id for multimodal output
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
queue_manager.publish(
|
queue_manager.publish(
|
||||||
@@ -235,13 +270,22 @@ class AppRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _handle_invoke_result_stream(
|
def _handle_invoke_result_stream(
|
||||||
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
|
self,
|
||||||
|
invoke_result: Generator[LLMResultChunk, None, None],
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
agent: bool,
|
||||||
|
message_id: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
tenant_id: str | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Handle invoke result
|
Handle invoke result
|
||||||
:param invoke_result: invoke result
|
:param invoke_result: invoke result
|
||||||
:param queue_manager: application queue manager
|
:param queue_manager: application queue manager
|
||||||
:param agent: agent
|
:param agent: agent
|
||||||
|
:param message_id: message id for multimodal output
|
||||||
|
:param user_id: user id for multimodal output
|
||||||
|
:param tenant_id: tenant id for multimodal output
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
model: str = ""
|
model: str = ""
|
||||||
@@ -259,12 +303,26 @@ class AppRunner:
|
|||||||
text += message.content
|
text += message.content
|
||||||
elif isinstance(message.content, list):
|
elif isinstance(message.content, list):
|
||||||
for content in message.content:
|
for content in message.content:
|
||||||
if not isinstance(content, str):
|
if isinstance(content, str):
|
||||||
# TODO(QuantumGhost): Add multimodal output support for easy ui.
|
text += content
|
||||||
_logger.warning("received multimodal output, type=%s", type(content))
|
elif isinstance(content, TextPromptMessageContent):
|
||||||
text += content.data
|
text += content.data
|
||||||
|
elif isinstance(content, ImagePromptMessageContent):
|
||||||
|
if message_id and user_id and tenant_id:
|
||||||
|
try:
|
||||||
|
self._handle_multimodal_image_content(
|
||||||
|
content=content,
|
||||||
|
message_id=message_id,
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
_logger.exception("Failed to handle multimodal image output")
|
||||||
|
else:
|
||||||
|
_logger.warning("Received multimodal output but missing required parameters")
|
||||||
else:
|
else:
|
||||||
text += content # failback to str
|
text += content.data if hasattr(content, "data") else str(content)
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
model = result.model
|
model = result.model
|
||||||
@@ -289,6 +347,101 @@ class AppRunner:
|
|||||||
PublishFrom.APPLICATION_MANAGER,
|
PublishFrom.APPLICATION_MANAGER,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _handle_multimodal_image_content(
|
||||||
|
self,
|
||||||
|
content: ImagePromptMessageContent,
|
||||||
|
message_id: str,
|
||||||
|
user_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Handle multimodal image content from LLM response.
|
||||||
|
Save the image and create a MessageFile record.
|
||||||
|
|
||||||
|
:param content: ImagePromptMessageContent instance
|
||||||
|
:param message_id: message id
|
||||||
|
:param user_id: user id
|
||||||
|
:param tenant_id: tenant id
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
_logger.info("Handling multimodal image content for message %s", message_id)
|
||||||
|
|
||||||
|
image_url = content.url
|
||||||
|
base64_data = content.base64_data
|
||||||
|
|
||||||
|
_logger.info("Image URL: %s, Base64 data present: %s", image_url, base64_data)
|
||||||
|
|
||||||
|
if not image_url and not base64_data:
|
||||||
|
_logger.warning("Image content has neither URL nor base64 data")
|
||||||
|
return
|
||||||
|
|
||||||
|
tool_file_manager = ToolFileManager()
|
||||||
|
|
||||||
|
# Save the image file
|
||||||
|
try:
|
||||||
|
if image_url:
|
||||||
|
# Download image from URL
|
||||||
|
_logger.info("Downloading image from URL: %s", image_url)
|
||||||
|
tool_file = tool_file_manager.create_file_by_url(
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
file_url=image_url,
|
||||||
|
conversation_id=None,
|
||||||
|
)
|
||||||
|
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
|
||||||
|
elif base64_data:
|
||||||
|
if base64_data.startswith("data:"):
|
||||||
|
base64_data = base64_data.split(",", 1)[1]
|
||||||
|
|
||||||
|
image_binary = base64.b64decode(base64_data)
|
||||||
|
mimetype = content.mime_type or "image/png"
|
||||||
|
extension = guess_extension(mimetype) or ".png"
|
||||||
|
|
||||||
|
tool_file = tool_file_manager.create_file_by_raw(
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
conversation_id=None,
|
||||||
|
file_binary=image_binary,
|
||||||
|
mimetype=mimetype,
|
||||||
|
filename=f"generated_image{extension}",
|
||||||
|
)
|
||||||
|
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
_logger.exception("Failed to save image file")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create MessageFile record
|
||||||
|
message_file = MessageFile(
|
||||||
|
message_id=message_id,
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||||
|
belongs_to="assistant",
|
||||||
|
url=f"/files/tools/{tool_file.id}",
|
||||||
|
upload_file_id=tool_file.id,
|
||||||
|
created_by_role=(
|
||||||
|
CreatorUserRole.ACCOUNT
|
||||||
|
if queue_manager.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}
|
||||||
|
else CreatorUserRole.END_USER
|
||||||
|
),
|
||||||
|
created_by=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(message_file)
|
||||||
|
db.session.commit()
|
||||||
|
db.session.refresh(message_file)
|
||||||
|
|
||||||
|
# Publish QueueMessageFileEvent
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueMessageFileEvent(message_file_id=message_file.id),
|
||||||
|
PublishFrom.APPLICATION_MANAGER,
|
||||||
|
)
|
||||||
|
|
||||||
|
_logger.info("QueueMessageFileEvent published for message_file_id: %s", message_file.id)
|
||||||
|
|
||||||
def moderation_for_inputs(
|
def moderation_for_inputs(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|||||||
@@ -226,5 +226,10 @@ class ChatAppRunner(AppRunner):
|
|||||||
|
|
||||||
# handle invoke result
|
# handle invoke result
|
||||||
self._handle_invoke_result(
|
self._handle_invoke_result(
|
||||||
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
|
invoke_result=invoke_result,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
stream=application_generate_entity.stream,
|
||||||
|
message_id=message.id,
|
||||||
|
user_id=application_generate_entity.user_id,
|
||||||
|
tenant_id=app_config.tenant_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -184,5 +184,10 @@ class CompletionAppRunner(AppRunner):
|
|||||||
|
|
||||||
# handle invoke result
|
# handle invoke result
|
||||||
self._handle_invoke_result(
|
self._handle_invoke_result(
|
||||||
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
|
invoke_result=invoke_result,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
stream=application_generate_entity.stream,
|
||||||
|
message_id=message.id,
|
||||||
|
user_id=application_generate_entity.user_id,
|
||||||
|
tenant_id=app_config.tenant_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from core.app.entities.task_entities import (
|
|||||||
MessageAudioEndStreamResponse,
|
MessageAudioEndStreamResponse,
|
||||||
MessageAudioStreamResponse,
|
MessageAudioStreamResponse,
|
||||||
MessageEndStreamResponse,
|
MessageEndStreamResponse,
|
||||||
|
StreamEvent,
|
||||||
StreamResponse,
|
StreamResponse,
|
||||||
)
|
)
|
||||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||||
@@ -70,6 +71,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
|
|
||||||
_task_state: EasyUITaskState
|
_task_state: EasyUITaskState
|
||||||
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
||||||
|
_precomputed_event_type: StreamEvent | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -342,11 +344,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
self._task_state.llm_result.message.content = current_content
|
self._task_state.llm_result.message.content = current_content
|
||||||
|
|
||||||
if isinstance(event, QueueLLMChunkEvent):
|
if isinstance(event, QueueLLMChunkEvent):
|
||||||
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
|
# Determine the event type once, on first LLM chunk, and reuse for subsequent chunks
|
||||||
|
if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None:
|
||||||
|
self._precomputed_event_type = self._message_cycle_manager.get_message_event_type(
|
||||||
|
message_id=self._message_id
|
||||||
|
)
|
||||||
yield self._message_cycle_manager.message_to_stream_response(
|
yield self._message_cycle_manager.message_to_stream_response(
|
||||||
answer=cast(str, delta_text),
|
answer=cast(str, delta_text),
|
||||||
message_id=self._message_id,
|
message_id=self._message_id,
|
||||||
event_type=event_type,
|
event_type=self._precomputed_event_type,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield self._agent_message_to_stream_response(
|
yield self._agent_message_to_stream_response(
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from threading import Thread
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from sqlalchemy import exists, select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@@ -30,6 +30,7 @@ from core.app.entities.task_entities import (
|
|||||||
StreamEvent,
|
StreamEvent,
|
||||||
WorkflowTaskState,
|
WorkflowTaskState,
|
||||||
)
|
)
|
||||||
|
from core.db.session_factory import session_factory
|
||||||
from core.llm_generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from core.tools.signature import sign_tool_file
|
from core.tools.signature import sign_tool_file
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@@ -57,13 +58,15 @@ class MessageCycleManager:
|
|||||||
self._message_has_file: set[str] = set()
|
self._message_has_file: set[str] = set()
|
||||||
|
|
||||||
def get_message_event_type(self, message_id: str) -> StreamEvent:
|
def get_message_event_type(self, message_id: str) -> StreamEvent:
|
||||||
|
# Fast path: cached determination from prior QueueMessageFileEvent
|
||||||
if message_id in self._message_has_file:
|
if message_id in self._message_has_file:
|
||||||
return StreamEvent.MESSAGE_FILE
|
return StreamEvent.MESSAGE_FILE
|
||||||
|
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
# Use SQLAlchemy 2.x style session.scalar(select(...))
|
||||||
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
|
with session_factory.create_session() as session:
|
||||||
|
message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))
|
||||||
|
|
||||||
if has_file:
|
if message_file:
|
||||||
self._message_has_file.add(message_id)
|
self._message_has_file.add(message_id)
|
||||||
return StreamEvent.MESSAGE_FILE
|
return StreamEvent.MESSAGE_FILE
|
||||||
|
|
||||||
@@ -199,6 +202,8 @@ class MessageCycleManager:
|
|||||||
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
|
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
|
||||||
|
|
||||||
if message_file and message_file.url is not None:
|
if message_file and message_file.url is not None:
|
||||||
|
self._message_has_file.add(message_file.message_id)
|
||||||
|
|
||||||
# get tool file id
|
# get tool file id
|
||||||
tool_file_id = message_file.url.split("/")[-1]
|
tool_file_id = message_file.url.split("/")[-1]
|
||||||
# trim extension
|
# trim extension
|
||||||
|
|||||||
@@ -0,0 +1,454 @@
|
|||||||
|
"""Test multimodal image output handling in BaseAppRunner."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||||
|
from core.app.apps.base_app_runner import AppRunner
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.app.entities.queue_entities import QueueMessageFileEvent
|
||||||
|
from core.file.enums import FileTransferMethod, FileType
|
||||||
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||||
|
from models.enums import CreatorUserRole
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseAppRunnerMultimodal:
|
||||||
|
"""Test that BaseAppRunner correctly handles multimodal image content."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_user_id(self):
|
||||||
|
"""Mock user ID."""
|
||||||
|
return str(uuid4())
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_tenant_id(self):
|
||||||
|
"""Mock tenant ID."""
|
||||||
|
return str(uuid4())
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_message_id(self):
|
||||||
|
"""Mock message ID."""
|
||||||
|
return str(uuid4())
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_queue_manager(self):
|
||||||
|
"""Create a mock queue manager."""
|
||||||
|
manager = MagicMock()
|
||||||
|
manager.invoke_from = InvokeFrom.SERVICE_API
|
||||||
|
return manager
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_tool_file(self):
|
||||||
|
"""Create a mock tool file."""
|
||||||
|
tool_file = MagicMock()
|
||||||
|
tool_file.id = str(uuid4())
|
||||||
|
return tool_file
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_message_file(self):
|
||||||
|
"""Create a mock message file."""
|
||||||
|
message_file = MagicMock()
|
||||||
|
message_file.id = str(uuid4())
|
||||||
|
return message_file
|
||||||
|
|
||||||
|
def test_handle_multimodal_image_content_with_url(
|
||||||
|
self,
|
||||||
|
mock_user_id,
|
||||||
|
mock_tenant_id,
|
||||||
|
mock_message_id,
|
||||||
|
mock_queue_manager,
|
||||||
|
mock_tool_file,
|
||||||
|
mock_message_file,
|
||||||
|
):
|
||||||
|
"""Test handling image from URL."""
|
||||||
|
# Arrange
|
||||||
|
image_url = "http://example.com/image.png"
|
||||||
|
content = ImagePromptMessageContent(
|
||||||
|
url=image_url,
|
||||||
|
format="png",
|
||||||
|
mime_type="image/png",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||||
|
# Setup mock tool file manager
|
||||||
|
mock_mgr = MagicMock()
|
||||||
|
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||||
|
mock_mgr_class.return_value = mock_mgr
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||||
|
# Setup mock message file
|
||||||
|
mock_msg_file_class.return_value = mock_message_file
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||||
|
mock_session.add = MagicMock()
|
||||||
|
mock_session.commit = MagicMock()
|
||||||
|
mock_session.refresh = MagicMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
# Create a mock runner with the method bound
|
||||||
|
runner = MagicMock()
|
||||||
|
|
||||||
|
method = AppRunner._handle_multimodal_image_content
|
||||||
|
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||||
|
|
||||||
|
runner._handle_multimodal_image_content(
|
||||||
|
content=content,
|
||||||
|
message_id=mock_message_id,
|
||||||
|
user_id=mock_user_id,
|
||||||
|
tenant_id=mock_tenant_id,
|
||||||
|
queue_manager=mock_queue_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# Verify tool file was created from URL
|
||||||
|
mock_mgr.create_file_by_url.assert_called_once_with(
|
||||||
|
user_id=mock_user_id,
|
||||||
|
tenant_id=mock_tenant_id,
|
||||||
|
file_url=image_url,
|
||||||
|
conversation_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify message file was created with correct parameters
|
||||||
|
mock_msg_file_class.assert_called_once()
|
||||||
|
call_kwargs = mock_msg_file_class.call_args[1]
|
||||||
|
assert call_kwargs["message_id"] == mock_message_id
|
||||||
|
assert call_kwargs["type"] == FileType.IMAGE
|
||||||
|
assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE
|
||||||
|
assert call_kwargs["belongs_to"] == "assistant"
|
||||||
|
assert call_kwargs["created_by"] == mock_user_id
|
||||||
|
|
||||||
|
# Verify database operations
|
||||||
|
mock_session.add.assert_called_once_with(mock_message_file)
|
||||||
|
mock_session.commit.assert_called_once()
|
||||||
|
mock_session.refresh.assert_called_once_with(mock_message_file)
|
||||||
|
|
||||||
|
# Verify event was published
|
||||||
|
mock_queue_manager.publish.assert_called_once()
|
||||||
|
publish_call = mock_queue_manager.publish.call_args
|
||||||
|
assert isinstance(publish_call[0][0], QueueMessageFileEvent)
|
||||||
|
assert publish_call[0][0].message_file_id == mock_message_file.id
|
||||||
|
# publish_from might be passed as positional or keyword argument
|
||||||
|
assert (
|
||||||
|
publish_call[0][1] == PublishFrom.APPLICATION_MANAGER
|
||||||
|
or publish_call.kwargs.get("publish_from") == PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_handle_multimodal_image_content_with_base64(
|
||||||
|
self,
|
||||||
|
mock_user_id,
|
||||||
|
mock_tenant_id,
|
||||||
|
mock_message_id,
|
||||||
|
mock_queue_manager,
|
||||||
|
mock_tool_file,
|
||||||
|
mock_message_file,
|
||||||
|
):
|
||||||
|
"""Test handling image from base64 data."""
|
||||||
|
# Arrange
|
||||||
|
import base64
|
||||||
|
|
||||||
|
# Create a small test image (1x1 PNG)
|
||||||
|
test_image_data = base64.b64encode(
|
||||||
|
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde"
|
||||||
|
).decode()
|
||||||
|
content = ImagePromptMessageContent(
|
||||||
|
base64_data=test_image_data,
|
||||||
|
format="png",
|
||||||
|
mime_type="image/png",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||||
|
# Setup mock tool file manager
|
||||||
|
mock_mgr = MagicMock()
|
||||||
|
mock_mgr.create_file_by_raw.return_value = mock_tool_file
|
||||||
|
mock_mgr_class.return_value = mock_mgr
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||||
|
# Setup mock message file
|
||||||
|
mock_msg_file_class.return_value = mock_message_file
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||||
|
mock_session.add = MagicMock()
|
||||||
|
mock_session.commit = MagicMock()
|
||||||
|
mock_session.refresh = MagicMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
# Create a mock runner with the method bound
|
||||||
|
runner = MagicMock()
|
||||||
|
method = AppRunner._handle_multimodal_image_content
|
||||||
|
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||||
|
|
||||||
|
runner._handle_multimodal_image_content(
|
||||||
|
content=content,
|
||||||
|
message_id=mock_message_id,
|
||||||
|
user_id=mock_user_id,
|
||||||
|
tenant_id=mock_tenant_id,
|
||||||
|
queue_manager=mock_queue_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# Verify tool file was created from base64
|
||||||
|
mock_mgr.create_file_by_raw.assert_called_once()
|
||||||
|
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
|
||||||
|
assert call_kwargs["user_id"] == mock_user_id
|
||||||
|
assert call_kwargs["tenant_id"] == mock_tenant_id
|
||||||
|
assert call_kwargs["conversation_id"] is None
|
||||||
|
assert "file_binary" in call_kwargs
|
||||||
|
assert call_kwargs["mimetype"] == "image/png"
|
||||||
|
assert call_kwargs["filename"].startswith("generated_image")
|
||||||
|
assert call_kwargs["filename"].endswith(".png")
|
||||||
|
|
||||||
|
# Verify message file was created
|
||||||
|
mock_msg_file_class.assert_called_once()
|
||||||
|
|
||||||
|
# Verify database operations
|
||||||
|
mock_session.add.assert_called_once()
|
||||||
|
mock_session.commit.assert_called_once()
|
||||||
|
mock_session.refresh.assert_called_once()
|
||||||
|
|
||||||
|
# Verify event was published
|
||||||
|
mock_queue_manager.publish.assert_called_once()
|
||||||
|
|
||||||
|
def test_handle_multimodal_image_content_with_base64_data_uri(
|
||||||
|
self,
|
||||||
|
mock_user_id,
|
||||||
|
mock_tenant_id,
|
||||||
|
mock_message_id,
|
||||||
|
mock_queue_manager,
|
||||||
|
mock_tool_file,
|
||||||
|
mock_message_file,
|
||||||
|
):
|
||||||
|
"""Test handling image from base64 data with URI prefix."""
|
||||||
|
# Arrange
|
||||||
|
# Data URI format: data:image/png;base64,<base64_data>
|
||||||
|
test_image_data = (
|
||||||
|
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||||
|
)
|
||||||
|
content = ImagePromptMessageContent(
|
||||||
|
base64_data=f"data:image/png;base64,{test_image_data}",
|
||||||
|
format="png",
|
||||||
|
mime_type="image/png",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||||
|
# Setup mock tool file manager
|
||||||
|
mock_mgr = MagicMock()
|
||||||
|
mock_mgr.create_file_by_raw.return_value = mock_tool_file
|
||||||
|
mock_mgr_class.return_value = mock_mgr
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||||
|
# Setup mock message file
|
||||||
|
mock_msg_file_class.return_value = mock_message_file
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||||
|
mock_session.add = MagicMock()
|
||||||
|
mock_session.commit = MagicMock()
|
||||||
|
mock_session.refresh = MagicMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
# Create a mock runner with the method bound
|
||||||
|
runner = MagicMock()
|
||||||
|
method = AppRunner._handle_multimodal_image_content
|
||||||
|
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||||
|
|
||||||
|
runner._handle_multimodal_image_content(
|
||||||
|
content=content,
|
||||||
|
message_id=mock_message_id,
|
||||||
|
user_id=mock_user_id,
|
||||||
|
tenant_id=mock_tenant_id,
|
||||||
|
queue_manager=mock_queue_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - verify that base64 data was extracted correctly (without prefix)
|
||||||
|
mock_mgr.create_file_by_raw.assert_called_once()
|
||||||
|
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
|
||||||
|
# The base64 data should be decoded, so we check the binary was passed
|
||||||
|
assert "file_binary" in call_kwargs
|
||||||
|
|
||||||
|
def test_handle_multimodal_image_content_without_url_or_base64(
|
||||||
|
self,
|
||||||
|
mock_user_id,
|
||||||
|
mock_tenant_id,
|
||||||
|
mock_message_id,
|
||||||
|
mock_queue_manager,
|
||||||
|
):
|
||||||
|
"""Test handling image content without URL or base64 data."""
|
||||||
|
# Arrange
|
||||||
|
content = ImagePromptMessageContent(
|
||||||
|
url="",
|
||||||
|
base64_data="",
|
||||||
|
format="png",
|
||||||
|
mime_type="image/png",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||||
|
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||||
|
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||||
|
# Act
|
||||||
|
# Create a mock runner with the method bound
|
||||||
|
runner = MagicMock()
|
||||||
|
method = AppRunner._handle_multimodal_image_content
|
||||||
|
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||||
|
|
||||||
|
runner._handle_multimodal_image_content(
|
||||||
|
content=content,
|
||||||
|
message_id=mock_message_id,
|
||||||
|
user_id=mock_user_id,
|
||||||
|
tenant_id=mock_tenant_id,
|
||||||
|
queue_manager=mock_queue_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - should not create any files or publish events
|
||||||
|
mock_mgr_class.assert_not_called()
|
||||||
|
mock_msg_file_class.assert_not_called()
|
||||||
|
mock_session.add.assert_not_called()
|
||||||
|
mock_queue_manager.publish.assert_not_called()
|
||||||
|
|
||||||
|
def test_handle_multimodal_image_content_with_error(
|
||||||
|
self,
|
||||||
|
mock_user_id,
|
||||||
|
mock_tenant_id,
|
||||||
|
mock_message_id,
|
||||||
|
mock_queue_manager,
|
||||||
|
):
|
||||||
|
"""Test handling image content when an error occurs."""
|
||||||
|
# Arrange
|
||||||
|
image_url = "http://example.com/image.png"
|
||||||
|
content = ImagePromptMessageContent(
|
||||||
|
url=image_url,
|
||||||
|
format="png",
|
||||||
|
mime_type="image/png",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||||
|
# Setup mock to raise exception
|
||||||
|
mock_mgr = MagicMock()
|
||||||
|
mock_mgr.create_file_by_url.side_effect = Exception("Network error")
|
||||||
|
mock_mgr_class.return_value = mock_mgr
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||||
|
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||||
|
# Act
|
||||||
|
# Create a mock runner with the method bound
|
||||||
|
runner = MagicMock()
|
||||||
|
method = AppRunner._handle_multimodal_image_content
|
||||||
|
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||||
|
|
||||||
|
# Should not raise exception, just log it
|
||||||
|
runner._handle_multimodal_image_content(
|
||||||
|
content=content,
|
||||||
|
message_id=mock_message_id,
|
||||||
|
user_id=mock_user_id,
|
||||||
|
tenant_id=mock_tenant_id,
|
||||||
|
queue_manager=mock_queue_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - should not create message file or publish event on error
|
||||||
|
mock_msg_file_class.assert_not_called()
|
||||||
|
mock_session.add.assert_not_called()
|
||||||
|
mock_queue_manager.publish.assert_not_called()
|
||||||
|
|
||||||
|
def test_handle_multimodal_image_content_debugger_mode(
|
||||||
|
self,
|
||||||
|
mock_user_id,
|
||||||
|
mock_tenant_id,
|
||||||
|
mock_message_id,
|
||||||
|
mock_queue_manager,
|
||||||
|
mock_tool_file,
|
||||||
|
mock_message_file,
|
||||||
|
):
|
||||||
|
"""Test that debugger mode sets correct created_by_role."""
|
||||||
|
# Arrange
|
||||||
|
image_url = "http://example.com/image.png"
|
||||||
|
content = ImagePromptMessageContent(
|
||||||
|
url=image_url,
|
||||||
|
format="png",
|
||||||
|
mime_type="image/png",
|
||||||
|
)
|
||||||
|
mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||||
|
# Setup mock tool file manager
|
||||||
|
mock_mgr = MagicMock()
|
||||||
|
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||||
|
mock_mgr_class.return_value = mock_mgr
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||||
|
# Setup mock message file
|
||||||
|
mock_msg_file_class.return_value = mock_message_file
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||||
|
mock_session.add = MagicMock()
|
||||||
|
mock_session.commit = MagicMock()
|
||||||
|
mock_session.refresh = MagicMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
# Create a mock runner with the method bound
|
||||||
|
runner = MagicMock()
|
||||||
|
method = AppRunner._handle_multimodal_image_content
|
||||||
|
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||||
|
|
||||||
|
runner._handle_multimodal_image_content(
|
||||||
|
content=content,
|
||||||
|
message_id=mock_message_id,
|
||||||
|
user_id=mock_user_id,
|
||||||
|
tenant_id=mock_tenant_id,
|
||||||
|
queue_manager=mock_queue_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - verify created_by_role is ACCOUNT for debugger mode
|
||||||
|
call_kwargs = mock_msg_file_class.call_args[1]
|
||||||
|
assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT
|
||||||
|
|
||||||
|
def test_handle_multimodal_image_content_service_api_mode(
|
||||||
|
self,
|
||||||
|
mock_user_id,
|
||||||
|
mock_tenant_id,
|
||||||
|
mock_message_id,
|
||||||
|
mock_queue_manager,
|
||||||
|
mock_tool_file,
|
||||||
|
mock_message_file,
|
||||||
|
):
|
||||||
|
"""Test that service API mode sets correct created_by_role."""
|
||||||
|
# Arrange
|
||||||
|
image_url = "http://example.com/image.png"
|
||||||
|
content = ImagePromptMessageContent(
|
||||||
|
url=image_url,
|
||||||
|
format="png",
|
||||||
|
mime_type="image/png",
|
||||||
|
)
|
||||||
|
mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||||
|
# Setup mock tool file manager
|
||||||
|
mock_mgr = MagicMock()
|
||||||
|
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||||
|
mock_mgr_class.return_value = mock_mgr
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||||
|
# Setup mock message file
|
||||||
|
mock_msg_file_class.return_value = mock_message_file
|
||||||
|
|
||||||
|
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||||
|
mock_session.add = MagicMock()
|
||||||
|
mock_session.commit = MagicMock()
|
||||||
|
mock_session.refresh = MagicMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
# Create a mock runner with the method bound
|
||||||
|
runner = MagicMock()
|
||||||
|
method = AppRunner._handle_multimodal_image_content
|
||||||
|
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||||
|
|
||||||
|
runner._handle_multimodal_image_content(
|
||||||
|
content=content,
|
||||||
|
message_id=mock_message_id,
|
||||||
|
user_id=mock_user_id,
|
||||||
|
tenant_id=mock_tenant_id,
|
||||||
|
queue_manager=mock_queue_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - verify created_by_role is END_USER for service API
|
||||||
|
call_kwargs = mock_msg_file_class.call_args[1]
|
||||||
|
assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Unit tests for the message cycle manager optimization."""
|
"""Unit tests for the message cycle manager optimization."""
|
||||||
|
|
||||||
from types import SimpleNamespace
|
from unittest.mock import Mock, patch
|
||||||
from unittest.mock import ANY, Mock, patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
@@ -28,17 +27,14 @@ class TestMessageCycleManagerOptimization:
|
|||||||
|
|
||||||
def test_get_message_event_type_with_message_file(self, message_cycle_manager):
|
def test_get_message_event_type_with_message_file(self, message_cycle_manager):
|
||||||
"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
|
"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
|
||||||
with (
|
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
|
||||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
|
||||||
):
|
|
||||||
# Setup mock session and message file
|
# Setup mock session and message file
|
||||||
mock_session = Mock()
|
mock_session = Mock()
|
||||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
mock_message_file = Mock()
|
mock_message_file = Mock()
|
||||||
# Current implementation uses session.query(...).scalar()
|
# Current implementation uses session.scalar(select(...))
|
||||||
mock_session.query.return_value.scalar.return_value = mock_message_file
|
mock_session.scalar.return_value = mock_message_file
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
with current_app.app_context():
|
with current_app.app_context():
|
||||||
@@ -46,19 +42,16 @@ class TestMessageCycleManagerOptimization:
|
|||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert result == StreamEvent.MESSAGE_FILE
|
assert result == StreamEvent.MESSAGE_FILE
|
||||||
mock_session.query.return_value.scalar.assert_called_once()
|
mock_session.scalar.assert_called_once()
|
||||||
|
|
||||||
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
|
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
|
||||||
"""Test get_message_event_type returns MESSAGE when message has no files."""
|
"""Test get_message_event_type returns MESSAGE when message has no files."""
|
||||||
with (
|
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
|
||||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
|
||||||
):
|
|
||||||
# Setup mock session and no message file
|
# Setup mock session and no message file
|
||||||
mock_session = Mock()
|
mock_session = Mock()
|
||||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||||
# Current implementation uses session.query(...).scalar()
|
# Current implementation uses session.scalar(select(...))
|
||||||
mock_session.query.return_value.scalar.return_value = None
|
mock_session.scalar.return_value = None
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
with current_app.app_context():
|
with current_app.app_context():
|
||||||
@@ -66,21 +59,18 @@ class TestMessageCycleManagerOptimization:
|
|||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert result == StreamEvent.MESSAGE
|
assert result == StreamEvent.MESSAGE
|
||||||
mock_session.query.return_value.scalar.assert_called_once()
|
mock_session.scalar.assert_called_once()
|
||||||
|
|
||||||
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
|
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
|
||||||
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
|
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
|
||||||
with (
|
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
|
||||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
|
||||||
):
|
|
||||||
# Setup mock session and message file
|
# Setup mock session and message file
|
||||||
mock_session = Mock()
|
mock_session = Mock()
|
||||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
mock_message_file = Mock()
|
mock_message_file = Mock()
|
||||||
# Current implementation uses session.query(...).scalar()
|
# Current implementation uses session.scalar(select(...))
|
||||||
mock_session.query.return_value.scalar.return_value = mock_message_file
|
mock_session.scalar.return_value = mock_message_file
|
||||||
|
|
||||||
# Execute: compute event type once, then pass to message_to_stream_response
|
# Execute: compute event type once, then pass to message_to_stream_response
|
||||||
with current_app.app_context():
|
with current_app.app_context():
|
||||||
@@ -94,11 +84,11 @@ class TestMessageCycleManagerOptimization:
|
|||||||
assert result.answer == "Hello world"
|
assert result.answer == "Hello world"
|
||||||
assert result.id == "test-message-id"
|
assert result.id == "test-message-id"
|
||||||
assert result.event == StreamEvent.MESSAGE_FILE
|
assert result.event == StreamEvent.MESSAGE_FILE
|
||||||
mock_session.query.return_value.scalar.assert_called_once()
|
mock_session.scalar.assert_called_once()
|
||||||
|
|
||||||
def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
|
def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
|
||||||
"""Test that message_to_stream_response skips database query when event_type is provided."""
|
"""Test that message_to_stream_response skips database query when event_type is provided."""
|
||||||
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
|
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||||
# Execute with event_type provided
|
# Execute with event_type provided
|
||||||
result = message_cycle_manager.message_to_stream_response(
|
result = message_cycle_manager.message_to_stream_response(
|
||||||
answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||||
@@ -109,8 +99,8 @@ class TestMessageCycleManagerOptimization:
|
|||||||
assert result.answer == "Hello world"
|
assert result.answer == "Hello world"
|
||||||
assert result.id == "test-message-id"
|
assert result.id == "test-message-id"
|
||||||
assert result.event == StreamEvent.MESSAGE
|
assert result.event == StreamEvent.MESSAGE
|
||||||
# Should not query database when event_type is provided
|
# Should not open a session when event_type is provided
|
||||||
mock_session_class.assert_not_called()
|
mock_session_factory.create_session.assert_not_called()
|
||||||
|
|
||||||
def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
|
def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
|
||||||
"""Test message_to_stream_response with from_variable_selector parameter."""
|
"""Test message_to_stream_response with from_variable_selector parameter."""
|
||||||
@@ -130,24 +120,21 @@ class TestMessageCycleManagerOptimization:
|
|||||||
def test_optimization_usage_example(self, message_cycle_manager):
|
def test_optimization_usage_example(self, message_cycle_manager):
|
||||||
"""Test the optimization pattern that should be used by callers."""
|
"""Test the optimization pattern that should be used by callers."""
|
||||||
# Step 1: Get event type once (this queries database)
|
# Step 1: Get event type once (this queries database)
|
||||||
with (
|
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
|
||||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
|
||||||
):
|
|
||||||
mock_session = Mock()
|
mock_session = Mock()
|
||||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||||
# Current implementation uses session.query(...).scalar()
|
# Current implementation uses session.scalar(select(...))
|
||||||
mock_session.query.return_value.scalar.return_value = None # No files
|
mock_session.scalar.return_value = None # No files
|
||||||
with current_app.app_context():
|
with current_app.app_context():
|
||||||
event_type = message_cycle_manager.get_message_event_type("test-message-id")
|
event_type = message_cycle_manager.get_message_event_type("test-message-id")
|
||||||
|
|
||||||
# Should query database once
|
# Should open session once
|
||||||
mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
|
mock_session_factory.create_session.assert_called_once()
|
||||||
assert event_type == StreamEvent.MESSAGE
|
assert event_type == StreamEvent.MESSAGE
|
||||||
|
|
||||||
# Step 2: Use event_type for multiple calls (no additional queries)
|
# Step 2: Use event_type for multiple calls (no additional queries)
|
||||||
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
|
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||||
mock_session_class.return_value.__enter__.return_value = Mock()
|
mock_session_factory.create_session.return_value.__enter__.return_value = Mock()
|
||||||
|
|
||||||
chunk1_response = message_cycle_manager.message_to_stream_response(
|
chunk1_response = message_cycle_manager.message_to_stream_response(
|
||||||
answer="Chunk 1", message_id="test-message-id", event_type=event_type
|
answer="Chunk 1", message_id="test-message-id", event_type=event_type
|
||||||
@@ -157,8 +144,8 @@ class TestMessageCycleManagerOptimization:
|
|||||||
answer="Chunk 2", message_id="test-message-id", event_type=event_type
|
answer="Chunk 2", message_id="test-message-id", event_type=event_type
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should not query database again
|
# Should not open session again when event_type provided
|
||||||
mock_session_class.assert_not_called()
|
mock_session_factory.create_session.assert_not_called()
|
||||||
|
|
||||||
assert chunk1_response.event == StreamEvent.MESSAGE
|
assert chunk1_response.event == StreamEvent.MESSAGE
|
||||||
assert chunk2_response.event == StreamEvent.MESSAGE
|
assert chunk2_response.event == StreamEvent.MESSAGE
|
||||||
|
|||||||
178
web/app/components/base/chat/chat/hooks.multimodal.spec.ts
Normal file
178
web/app/components/base/chat/chat/hooks.multimodal.spec.ts
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
/**
|
||||||
|
* Tests for multimodal image file handling in chat hooks.
|
||||||
|
* Tests the file object conversion logic without full hook integration.
|
||||||
|
*/
|
||||||
|
|
||||||
|
describe('Multimodal File Handling', () => {
|
||||||
|
describe('File type to MIME type mapping', () => {
|
||||||
|
it('should map image to image/png', () => {
|
||||||
|
const fileType: string = 'image'
|
||||||
|
const expectedMime = 'image/png'
|
||||||
|
const mimeType = fileType === 'image' ? 'image/png' : 'application/octet-stream'
|
||||||
|
expect(mimeType).toBe(expectedMime)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should map video to video/mp4', () => {
|
||||||
|
const fileType: string = 'video'
|
||||||
|
const expectedMime = 'video/mp4'
|
||||||
|
const mimeType = fileType === 'video' ? 'video/mp4' : 'application/octet-stream'
|
||||||
|
expect(mimeType).toBe(expectedMime)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should map audio to audio/mpeg', () => {
|
||||||
|
const fileType: string = 'audio'
|
||||||
|
const expectedMime = 'audio/mpeg'
|
||||||
|
const mimeType = fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'
|
||||||
|
expect(mimeType).toBe(expectedMime)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should map unknown to application/octet-stream', () => {
|
||||||
|
const fileType: string = 'unknown'
|
||||||
|
const expectedMime = 'application/octet-stream'
|
||||||
|
const mimeType = ['image', 'video', 'audio'].includes(fileType) ? 'image/png' : 'application/octet-stream'
|
||||||
|
expect(mimeType).toBe(expectedMime)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('TransferMethod selection', () => {
|
||||||
|
it('should select remote_url for images', () => {
|
||||||
|
const fileType: string = 'image'
|
||||||
|
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
|
||||||
|
expect(transferMethod).toBe('remote_url')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should select local_file for non-images', () => {
|
||||||
|
const fileType: string = 'video'
|
||||||
|
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
|
||||||
|
expect(transferMethod).toBe('local_file')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('File extension mapping', () => {
|
||||||
|
it('should use .png extension for images', () => {
|
||||||
|
const fileType: string = 'image'
|
||||||
|
const expectedExtension = '.png'
|
||||||
|
const extension = fileType === 'image' ? 'png' : 'bin'
|
||||||
|
expect(extension).toBe(expectedExtension.replace('.', ''))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should use .mp4 extension for videos', () => {
|
||||||
|
const fileType: string = 'video'
|
||||||
|
const expectedExtension = '.mp4'
|
||||||
|
const extension = fileType === 'video' ? 'mp4' : 'bin'
|
||||||
|
expect(extension).toBe(expectedExtension.replace('.', ''))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should use .mp3 extension for audio', () => {
|
||||||
|
const fileType: string = 'audio'
|
||||||
|
const expectedExtension = '.mp3'
|
||||||
|
const extension = fileType === 'audio' ? 'mp3' : 'bin'
|
||||||
|
expect(extension).toBe(expectedExtension.replace('.', ''))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('File name generation', () => {
|
||||||
|
it('should generate correct file name for images', () => {
|
||||||
|
const fileType: string = 'image'
|
||||||
|
const expectedName = 'generated_image.png'
|
||||||
|
const fileName = `generated_${fileType}.${fileType === 'image' ? 'png' : 'bin'}`
|
||||||
|
expect(fileName).toBe(expectedName)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should generate correct file name for videos', () => {
|
||||||
|
const fileType: string = 'video'
|
||||||
|
const expectedName = 'generated_video.mp4'
|
||||||
|
const fileName = `generated_${fileType}.${fileType === 'video' ? 'mp4' : 'bin'}`
|
||||||
|
expect(fileName).toBe(expectedName)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should generate correct file name for audio', () => {
|
||||||
|
const fileType: string = 'audio'
|
||||||
|
const expectedName = 'generated_audio.mp3'
|
||||||
|
const fileName = `generated_${fileType}.${fileType === 'audio' ? 'mp3' : 'bin'}`
|
||||||
|
expect(fileName).toBe(expectedName)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('SupportFileType mapping', () => {
|
||||||
|
it('should map image type to image supportFileType', () => {
|
||||||
|
const fileType: string = 'image'
|
||||||
|
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||||
|
expect(supportFileType).toBe('image')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should map video type to video supportFileType', () => {
|
||||||
|
const fileType: string = 'video'
|
||||||
|
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||||
|
expect(supportFileType).toBe('video')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should map audio type to audio supportFileType', () => {
|
||||||
|
const fileType: string = 'audio'
|
||||||
|
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||||
|
expect(supportFileType).toBe('audio')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should map unknown type to document supportFileType', () => {
|
||||||
|
const fileType: string = 'unknown'
|
||||||
|
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||||
|
expect(supportFileType).toBe('document')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('File conversion logic', () => {
|
||||||
|
it('should detect existing transferMethod', () => {
|
||||||
|
const fileWithTransferMethod = {
|
||||||
|
id: 'file-123',
|
||||||
|
transferMethod: 'remote_url' as const,
|
||||||
|
type: 'image/png',
|
||||||
|
name: 'test.png',
|
||||||
|
size: 1024,
|
||||||
|
supportFileType: 'image',
|
||||||
|
progress: 100,
|
||||||
|
}
|
||||||
|
const hasTransferMethod = 'transferMethod' in fileWithTransferMethod
|
||||||
|
expect(hasTransferMethod).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should detect missing transferMethod', () => {
|
||||||
|
const fileWithoutTransferMethod = {
|
||||||
|
id: 'file-456',
|
||||||
|
type: 'image',
|
||||||
|
url: 'http://example.com/image.png',
|
||||||
|
belongs_to: 'assistant',
|
||||||
|
}
|
||||||
|
const hasTransferMethod = 'transferMethod' in fileWithoutTransferMethod
|
||||||
|
expect(hasTransferMethod).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should create file with size 0 for generated files', () => {
|
||||||
|
const expectedSize = 0
|
||||||
|
expect(expectedSize).toBe(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Agent vs Non-Agent mode logic', () => {
|
||||||
|
it('should check for agent_thoughts to determine mode', () => {
|
||||||
|
const agentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
|
||||||
|
agent_thoughts: [{}],
|
||||||
|
}
|
||||||
|
const isAgentMode = agentResponse.agent_thoughts && agentResponse.agent_thoughts.length > 0
|
||||||
|
expect(isAgentMode).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should detect non-agent mode when agent_thoughts is empty', () => {
|
||||||
|
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
|
||||||
|
agent_thoughts: [],
|
||||||
|
}
|
||||||
|
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
|
||||||
|
expect(isAgentMode).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should detect non-agent mode when agent_thoughts is undefined', () => {
|
||||||
|
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {}
|
||||||
|
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
|
||||||
|
expect(isAgentMode).toBeFalsy()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -419,9 +419,40 @@ export const useChat = (
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
onFile(file) {
|
onFile(file) {
|
||||||
|
// Convert simple file type to MIME type for non-agent mode
|
||||||
|
// Backend sends: { id, type: "image", belongs_to, url }
|
||||||
|
// Frontend expects: { id, type: "image/png", transferMethod, url, uploadedId, supportFileType, name, size }
|
||||||
|
|
||||||
|
// Determine file type for MIME conversion
|
||||||
|
const fileType = (file as { type?: string }).type || 'image'
|
||||||
|
|
||||||
|
// If file already has transferMethod, use it as base and ensure all required fields exist
|
||||||
|
// Otherwise, create a new complete file object
|
||||||
|
const baseFile = ('transferMethod' in file) ? (file as Partial<FileEntity>) : null
|
||||||
|
|
||||||
|
const convertedFile: FileEntity = {
|
||||||
|
id: baseFile?.id || (file as { id: string }).id,
|
||||||
|
type: baseFile?.type || (fileType === 'image' ? 'image/png' : fileType === 'video' ? 'video/mp4' : fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'),
|
||||||
|
transferMethod: (baseFile?.transferMethod as FileEntity['transferMethod']) || (fileType === 'image' ? 'remote_url' : 'local_file'),
|
||||||
|
uploadedId: baseFile?.uploadedId || (file as { id: string }).id,
|
||||||
|
supportFileType: baseFile?.supportFileType || (fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'),
|
||||||
|
progress: baseFile?.progress ?? 100,
|
||||||
|
name: baseFile?.name || `generated_${fileType}.${fileType === 'image' ? 'png' : fileType === 'video' ? 'mp4' : fileType === 'audio' ? 'mp3' : 'bin'}`,
|
||||||
|
url: baseFile?.url || (file as { url?: string }).url,
|
||||||
|
size: baseFile?.size ?? 0, // Generated files don't have a known size
|
||||||
|
}
|
||||||
|
|
||||||
|
// For agent mode, add files to the last thought
|
||||||
const lastThought = responseItem.agent_thoughts?.[responseItem.agent_thoughts?.length - 1]
|
const lastThought = responseItem.agent_thoughts?.[responseItem.agent_thoughts?.length - 1]
|
||||||
if (lastThought)
|
if (lastThought) {
|
||||||
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(lastThought as any).message_files, file]
|
const thought = lastThought as { message_files?: FileEntity[] }
|
||||||
|
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(thought.message_files ?? []), convertedFile]
|
||||||
|
}
|
||||||
|
// For non-agent mode, add files directly to responseItem.message_files
|
||||||
|
else {
|
||||||
|
const currentFiles = (responseItem.message_files as FileEntity[] | undefined) ?? []
|
||||||
|
responseItem.message_files = [...currentFiles, convertedFile]
|
||||||
|
}
|
||||||
|
|
||||||
updateCurrentQAOnTree({
|
updateCurrentQAOnTree({
|
||||||
placeholderQuestionId,
|
placeholderQuestionId,
|
||||||
|
|||||||
@@ -2039,8 +2039,13 @@ describe('Integration: Hit Testing Flow', () => {
|
|||||||
|
|
||||||
renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||||
|
|
||||||
|
// Wait for textbox with timeout for CI
|
||||||
|
const textarea = await waitFor(
|
||||||
|
() => screen.getByRole('textbox'),
|
||||||
|
{ timeout: 3000 },
|
||||||
|
)
|
||||||
|
|
||||||
// Type query
|
// Type query
|
||||||
const textarea = screen.getByRole('textbox')
|
|
||||||
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
||||||
|
|
||||||
// Find submit button by class
|
// Find submit button by class
|
||||||
@@ -2054,8 +2059,13 @@ describe('Integration: Hit Testing Flow', () => {
|
|||||||
|
|
||||||
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||||
|
|
||||||
|
// Wait for textbox with timeout for CI
|
||||||
|
const textarea = await waitFor(
|
||||||
|
() => screen.getByRole('textbox'),
|
||||||
|
{ timeout: 3000 },
|
||||||
|
)
|
||||||
|
|
||||||
// Type query
|
// Type query
|
||||||
const textarea = screen.getByRole('textbox')
|
|
||||||
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
||||||
|
|
||||||
// Component should still be functional - check for the main container
|
// Component should still be functional - check for the main container
|
||||||
@@ -2089,10 +2099,15 @@ describe('Integration: Hit Testing Flow', () => {
|
|||||||
isLoading: false,
|
isLoading: false,
|
||||||
} as unknown as ReturnType<typeof useDatasetTestingRecords>)
|
} as unknown as ReturnType<typeof useDatasetTestingRecords>)
|
||||||
|
|
||||||
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
const { container: _container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||||
|
|
||||||
|
// Wait for textbox to be rendered with timeout for CI environment
|
||||||
|
const textarea = await waitFor(
|
||||||
|
() => screen.getByRole('textbox'),
|
||||||
|
{ timeout: 3000 },
|
||||||
|
)
|
||||||
|
|
||||||
// Type query
|
// Type query
|
||||||
const textarea = screen.getByRole('textbox')
|
|
||||||
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
||||||
|
|
||||||
// Submit
|
// Submit
|
||||||
@@ -2101,8 +2116,13 @@ describe('Integration: Hit Testing Flow', () => {
|
|||||||
if (submitButton)
|
if (submitButton)
|
||||||
fireEvent.click(submitButton)
|
fireEvent.click(submitButton)
|
||||||
|
|
||||||
// Verify the component is still rendered after submission
|
// Wait for the mutation to complete
|
||||||
expect(container.firstChild).toBeInTheDocument()
|
await waitFor(
|
||||||
|
() => {
|
||||||
|
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
|
||||||
|
},
|
||||||
|
{ timeout: 3000 },
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should render ResultItem components for non-external results', async () => {
|
it('should render ResultItem components for non-external results', async () => {
|
||||||
@@ -2127,10 +2147,15 @@ describe('Integration: Hit Testing Flow', () => {
|
|||||||
isLoading: false,
|
isLoading: false,
|
||||||
} as unknown as ReturnType<typeof useDatasetTestingRecords>)
|
} as unknown as ReturnType<typeof useDatasetTestingRecords>)
|
||||||
|
|
||||||
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
const { container: _container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||||
|
|
||||||
|
// Wait for component to be fully rendered with longer timeout
|
||||||
|
const textarea = await waitFor(
|
||||||
|
() => screen.getByRole('textbox'),
|
||||||
|
{ timeout: 3000 },
|
||||||
|
)
|
||||||
|
|
||||||
// Submit a query
|
// Submit a query
|
||||||
const textarea = screen.getByRole('textbox')
|
|
||||||
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
||||||
|
|
||||||
const buttons = screen.getAllByRole('button')
|
const buttons = screen.getAllByRole('button')
|
||||||
@@ -2138,8 +2163,13 @@ describe('Integration: Hit Testing Flow', () => {
|
|||||||
if (submitButton)
|
if (submitButton)
|
||||||
fireEvent.click(submitButton)
|
fireEvent.click(submitButton)
|
||||||
|
|
||||||
// Verify component is rendered after submission
|
// Wait for mutation to complete with longer timeout
|
||||||
expect(container.firstChild).toBeInTheDocument()
|
await waitFor(
|
||||||
|
() => {
|
||||||
|
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
|
||||||
|
},
|
||||||
|
{ timeout: 3000 },
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should render external results when dataset is external', async () => {
|
it('should render external results when dataset is external', async () => {
|
||||||
@@ -2165,8 +2195,14 @@ describe('Integration: Hit Testing Flow', () => {
|
|||||||
|
|
||||||
// Component should render
|
// Component should render
|
||||||
expect(container.firstChild).toBeInTheDocument()
|
expect(container.firstChild).toBeInTheDocument()
|
||||||
|
|
||||||
|
// Wait for textbox with timeout for CI
|
||||||
|
const textarea = await waitFor(
|
||||||
|
() => screen.getByRole('textbox'),
|
||||||
|
{ timeout: 3000 },
|
||||||
|
)
|
||||||
|
|
||||||
// Type in textarea to verify component is functional
|
// Type in textarea to verify component is functional
|
||||||
const textarea = screen.getByRole('textbox')
|
|
||||||
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
||||||
|
|
||||||
const buttons = screen.getAllByRole('button')
|
const buttons = screen.getAllByRole('button')
|
||||||
@@ -2174,9 +2210,13 @@ describe('Integration: Hit Testing Flow', () => {
|
|||||||
if (submitButton)
|
if (submitButton)
|
||||||
fireEvent.click(submitButton)
|
fireEvent.click(submitButton)
|
||||||
|
|
||||||
await waitFor(() => {
|
// Verify component is still functional after submission
|
||||||
expect(screen.getByRole('textbox')).toBeInTheDocument()
|
await waitFor(
|
||||||
})
|
() => {
|
||||||
|
expect(screen.getByRole('textbox')).toBeInTheDocument()
|
||||||
|
},
|
||||||
|
{ timeout: 3000 },
|
||||||
|
)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -2260,8 +2300,13 @@ describe('renderHitResults Coverage', () => {
|
|||||||
|
|
||||||
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||||
|
|
||||||
|
// Wait for textbox with timeout for CI
|
||||||
|
const textarea = await waitFor(
|
||||||
|
() => screen.getByRole('textbox'),
|
||||||
|
{ timeout: 3000 },
|
||||||
|
)
|
||||||
|
|
||||||
// Enter query
|
// Enter query
|
||||||
const textarea = screen.getByRole('textbox')
|
|
||||||
fireEvent.change(textarea, { target: { value: 'test query' } })
|
fireEvent.change(textarea, { target: { value: 'test query' } })
|
||||||
|
|
||||||
// Submit
|
// Submit
|
||||||
@@ -2386,8 +2431,13 @@ describe('HitTestingPage Internal Functions Coverage', () => {
|
|||||||
|
|
||||||
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||||
|
|
||||||
|
// Wait for textbox with timeout for CI
|
||||||
|
const textarea = await waitFor(
|
||||||
|
() => screen.getByRole('textbox'),
|
||||||
|
{ timeout: 3000 },
|
||||||
|
)
|
||||||
|
|
||||||
// Enter query and submit
|
// Enter query and submit
|
||||||
const textarea = screen.getByRole('textbox')
|
|
||||||
fireEvent.change(textarea, { target: { value: 'test query' } })
|
fireEvent.change(textarea, { target: { value: 'test query' } })
|
||||||
|
|
||||||
const buttons = screen.getAllByRole('button')
|
const buttons = screen.getAllByRole('button')
|
||||||
@@ -2400,7 +2450,7 @@ describe('HitTestingPage Internal Functions Coverage', () => {
|
|||||||
// Wait for state updates
|
// Wait for state updates
|
||||||
await waitFor(() => {
|
await waitFor(() => {
|
||||||
expect(container.firstChild).toBeInTheDocument()
|
expect(container.firstChild).toBeInTheDocument()
|
||||||
}, { timeout: 2000 })
|
}, { timeout: 3000 })
|
||||||
|
|
||||||
// Verify mutation was called
|
// Verify mutation was called
|
||||||
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
|
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
|
||||||
@@ -2445,8 +2495,13 @@ describe('HitTestingPage Internal Functions Coverage', () => {
|
|||||||
|
|
||||||
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||||
|
|
||||||
|
// Wait for textbox with timeout for CI
|
||||||
|
const textarea = await waitFor(
|
||||||
|
() => screen.getByRole('textbox'),
|
||||||
|
{ timeout: 3000 },
|
||||||
|
)
|
||||||
|
|
||||||
// Submit a query
|
// Submit a query
|
||||||
const textarea = screen.getByRole('textbox')
|
|
||||||
fireEvent.change(textarea, { target: { value: 'test' } })
|
fireEvent.change(textarea, { target: { value: 'test' } })
|
||||||
|
|
||||||
const buttons = screen.getAllByRole('button')
|
const buttons = screen.getAllByRole('button')
|
||||||
@@ -2458,7 +2513,7 @@ describe('HitTestingPage Internal Functions Coverage', () => {
|
|||||||
// Verify the component renders
|
// Verify the component renders
|
||||||
await waitFor(() => {
|
await waitFor(() => {
|
||||||
expect(container.firstChild).toBeInTheDocument()
|
expect(container.firstChild).toBeInTheDocument()
|
||||||
})
|
}, { timeout: 3000 })
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -162,6 +162,44 @@ vi.mock('@/utils/var', () => ({
|
|||||||
getMarketplaceUrl: (path: string, _params?: Record<string, string | undefined>) => `https://marketplace.dify.ai${path}`,
|
getMarketplaceUrl: (path: string, _params?: Record<string, string | undefined>) => `https://marketplace.dify.ai${path}`,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
// Mock marketplace client used by marketplace utils
|
||||||
|
vi.mock('@/service/client', () => ({
|
||||||
|
marketplaceClient: {
|
||||||
|
collections: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
|
||||||
|
data: {
|
||||||
|
collections: [
|
||||||
|
{
|
||||||
|
name: 'collection-1',
|
||||||
|
label: { 'en-US': 'Collection 1' },
|
||||||
|
description: { 'en-US': 'Desc' },
|
||||||
|
rule: '',
|
||||||
|
created_at: '2024-01-01',
|
||||||
|
updated_at: '2024-01-01',
|
||||||
|
searchable: true,
|
||||||
|
search_params: { query: '', sort_by: 'install_count', sort_order: 'DESC' },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
collectionPlugins: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
|
||||||
|
data: {
|
||||||
|
plugins: [
|
||||||
|
{ type: 'plugin', org: 'test', name: 'plugin1', tags: [] },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
// Some utils paths may call searchAdvanced; provide a minimal stub
|
||||||
|
searchAdvanced: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
|
||||||
|
data: {
|
||||||
|
plugins: [
|
||||||
|
{ type: 'plugin', org: 'test', name: 'plugin1', tags: [] },
|
||||||
|
],
|
||||||
|
total: 1,
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
// Mock context/query-client
|
// Mock context/query-client
|
||||||
vi.mock('@/context/query-client', () => ({
|
vi.mock('@/context/query-client', () => ({
|
||||||
TanstackQueryInitializer: ({ children }: { children: React.ReactNode }) => <div data-testid="query-initializer">{children}</div>,
|
TanstackQueryInitializer: ({ children }: { children: React.ReactNode }) => <div data-testid="query-initializer">{children}</div>,
|
||||||
@@ -1474,7 +1512,24 @@ describe('flatMap Coverage', () => {
|
|||||||
// ================================
|
// ================================
|
||||||
// Async Utils Tests
|
// Async Utils Tests
|
||||||
// ================================
|
// ================================
|
||||||
|
|
||||||
|
// Narrow mock surface and avoid any in tests
|
||||||
|
// Types are local to this spec to keep scope minimal
|
||||||
|
|
||||||
|
type FnMock = ReturnType<typeof vi.fn>
|
||||||
|
|
||||||
|
type MarketplaceClientMock = {
|
||||||
|
collectionPlugins: FnMock
|
||||||
|
collections: FnMock
|
||||||
|
}
|
||||||
|
|
||||||
describe('Async Utils', () => {
|
describe('Async Utils', () => {
|
||||||
|
let marketplaceClientMock: MarketplaceClientMock
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
const mod = await import('@/service/client')
|
||||||
|
marketplaceClientMock = mod.marketplaceClient as unknown as MarketplaceClientMock
|
||||||
|
})
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
})
|
})
|
||||||
@@ -1490,12 +1545,10 @@ describe('Async Utils', () => {
|
|||||||
{ type: 'plugin', org: 'test', name: 'plugin2' },
|
{ type: 'plugin', org: 'test', name: 'plugin2' },
|
||||||
]
|
]
|
||||||
|
|
||||||
globalThis.fetch = vi.fn().mockResolvedValue(
|
// Adjusted to our mocked marketplaceClient instead of fetch
|
||||||
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
|
marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({
|
||||||
status: 200,
|
data: { plugins: mockPlugins },
|
||||||
headers: { 'Content-Type': 'application/json' },
|
})
|
||||||
}),
|
|
||||||
)
|
|
||||||
|
|
||||||
const { getMarketplacePluginsByCollectionId } = await import('./utils')
|
const { getMarketplacePluginsByCollectionId } = await import('./utils')
|
||||||
const result = await getMarketplacePluginsByCollectionId('test-collection', {
|
const result = await getMarketplacePluginsByCollectionId('test-collection', {
|
||||||
@@ -1504,12 +1557,13 @@ describe('Async Utils', () => {
|
|||||||
type: 'plugin',
|
type: 'plugin',
|
||||||
})
|
})
|
||||||
|
|
||||||
expect(globalThis.fetch).toHaveBeenCalled()
|
expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled()
|
||||||
expect(result).toHaveLength(2)
|
expect(result).toHaveLength(2)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle fetch error and return empty array', async () => {
|
it('should handle fetch error and return empty array', async () => {
|
||||||
globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error'))
|
// Simulate error from client
|
||||||
|
marketplaceClientMock.collectionPlugins.mockRejectedValueOnce(new Error('Network error'))
|
||||||
|
|
||||||
const { getMarketplacePluginsByCollectionId } = await import('./utils')
|
const { getMarketplacePluginsByCollectionId } = await import('./utils')
|
||||||
const result = await getMarketplacePluginsByCollectionId('test-collection')
|
const result = await getMarketplacePluginsByCollectionId('test-collection')
|
||||||
@@ -1519,25 +1573,18 @@ describe('Async Utils', () => {
|
|||||||
|
|
||||||
it('should pass abort signal when provided', async () => {
|
it('should pass abort signal when provided', async () => {
|
||||||
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
|
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
|
||||||
globalThis.fetch = vi.fn().mockResolvedValue(
|
// Our client mock receives the signal as second arg
|
||||||
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
|
marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({
|
||||||
status: 200,
|
data: { plugins: mockPlugins },
|
||||||
headers: { 'Content-Type': 'application/json' },
|
})
|
||||||
}),
|
|
||||||
)
|
|
||||||
|
|
||||||
const controller = new AbortController()
|
const controller = new AbortController()
|
||||||
const { getMarketplacePluginsByCollectionId } = await import('./utils')
|
const { getMarketplacePluginsByCollectionId } = await import('./utils')
|
||||||
await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal })
|
await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal })
|
||||||
|
|
||||||
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
|
expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled()
|
||||||
expect(globalThis.fetch).toHaveBeenCalledWith(
|
const call = marketplaceClientMock.collectionPlugins.mock.calls[0]
|
||||||
expect.any(Request),
|
expect(call[1]).toMatchObject({ signal: controller.signal })
|
||||||
expect.any(Object),
|
|
||||||
)
|
|
||||||
const call = vi.mocked(globalThis.fetch).mock.calls[0]
|
|
||||||
const request = call[0] as Request
|
|
||||||
expect(request.url).toContain('test-collection')
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -1548,23 +1595,17 @@ describe('Async Utils', () => {
|
|||||||
]
|
]
|
||||||
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
|
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
|
||||||
|
|
||||||
let callCount = 0
|
// Simulate two-step client calls: collections then collectionPlugins
|
||||||
globalThis.fetch = vi.fn().mockImplementation(() => {
|
let stage = 0
|
||||||
callCount++
|
marketplaceClientMock.collections.mockImplementationOnce(async () => {
|
||||||
if (callCount === 1) {
|
stage = 1
|
||||||
return Promise.resolve(
|
return { data: { collections: mockCollections } }
|
||||||
new Response(JSON.stringify({ data: { collections: mockCollections } }), {
|
})
|
||||||
status: 200,
|
marketplaceClientMock.collectionPlugins.mockImplementation(async () => {
|
||||||
headers: { 'Content-Type': 'application/json' },
|
if (stage === 1) {
|
||||||
}),
|
return { data: { plugins: mockPlugins } }
|
||||||
)
|
|
||||||
}
|
}
|
||||||
return Promise.resolve(
|
return { data: { plugins: [] } }
|
||||||
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
|
|
||||||
status: 200,
|
|
||||||
headers: { 'Content-Type': 'application/json' },
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
|
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
|
||||||
@@ -1578,7 +1619,8 @@ describe('Async Utils', () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
it('should handle fetch error and return empty data', async () => {
|
it('should handle fetch error and return empty data', async () => {
|
||||||
globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error'))
|
// Simulate client error
|
||||||
|
marketplaceClientMock.collections.mockRejectedValueOnce(new Error('Network error'))
|
||||||
|
|
||||||
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
|
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
|
||||||
const result = await getMarketplaceCollectionsAndPlugins()
|
const result = await getMarketplaceCollectionsAndPlugins()
|
||||||
@@ -1588,24 +1630,16 @@ describe('Async Utils', () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
it('should append condition and type to URL when provided', async () => {
|
it('should append condition and type to URL when provided', async () => {
|
||||||
globalThis.fetch = vi.fn().mockResolvedValue(
|
// Assert that the client was called with query containing condition/type
|
||||||
new Response(JSON.stringify({ data: { collections: [] } }), {
|
|
||||||
status: 200,
|
|
||||||
headers: { 'Content-Type': 'application/json' },
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
|
|
||||||
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
|
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
|
||||||
await getMarketplaceCollectionsAndPlugins({
|
await getMarketplaceCollectionsAndPlugins({
|
||||||
condition: 'category=tool',
|
condition: 'category=tool',
|
||||||
type: 'bundle',
|
type: 'bundle',
|
||||||
})
|
})
|
||||||
|
|
||||||
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
|
expect(marketplaceClientMock.collections).toHaveBeenCalled()
|
||||||
expect(globalThis.fetch).toHaveBeenCalled()
|
const call = marketplaceClientMock.collections.mock.calls[0]
|
||||||
const call = vi.mocked(globalThis.fetch).mock.calls[0]
|
expect(call[0]).toMatchObject({ query: expect.objectContaining({ condition: 'category=tool', type: 'bundle' }) })
|
||||||
const request = call[0] as Request
|
|
||||||
expect(request.url).toContain('condition=category%3Dtool')
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -822,7 +822,7 @@
|
|||||||
"count": 2
|
"count": 2
|
||||||
},
|
},
|
||||||
"ts/no-explicit-any": {
|
"ts/no-explicit-any": {
|
||||||
"count": 15
|
"count": 14
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"app/components/base/chat/chat/index.tsx": {
|
"app/components/base/chat/chat/index.tsx": {
|
||||||
|
|||||||
@@ -152,6 +152,8 @@ export const formatNumberAbbreviated = (num: number) => {
|
|||||||
: `${formatted}${units[unitIndex].symbol}`
|
: `${formatted}${units[unitIndex].symbol}`
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Fallback: if no threshold matched, return the number string
|
||||||
|
return num.toString()
|
||||||
}
|
}
|
||||||
|
|
||||||
export const formatToLocalTime = (time: Dayjs, local: Locale, format: string) => {
|
export const formatToLocalTime = (time: Dayjs, local: Locale, format: string) => {
|
||||||
|
|||||||
Reference in New Issue
Block a user