From 57f358a96b68f72b490fb55342228af965a21685 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Wed, 1 Apr 2026 09:19:32 +0800 Subject: [PATCH 01/42] perf: use global httpx client instead of per request create new one (#34311) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/plugin/impl/base.py | 10 +++++-- .../vdb/tidb_on_qdrant/tidb_service.py | 25 ++++++++++++---- api/libs/oauth.py | 18 +++++++---- api/libs/oauth_data_source.py | 18 +++++++---- api/services/auth/jina.py | 8 ++++- api/services/auth/jina/jina.py | 8 ++++- api/services/billing_service.py | 8 ++++- api/services/website_service.py | 25 +++++++++++----- .../services/auth/test_auth_integration.py | 2 +- .../core/datasource/test_website_crawl.py | 7 +++-- .../core/plugin/impl/test_base_client_impl.py | 2 +- .../core/plugin/test_endpoint_client.py | 11 +++++++ .../core/plugin/test_plugin_runtime.py | 14 +++++++++ .../unit_tests/libs/test_oauth_clients.py | 18 +++++------ .../services/auth/test_jina_auth.py | 14 ++++----- .../auth/test_jina_auth_standalone_module.py | 8 ++--- .../services/test_billing_service.py | 2 +- .../test_datasource_provider_service.py | 3 ++ .../services/test_website_service.py | 30 ++++++++++++------- 19 files changed, 167 insertions(+), 64 deletions(-) diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 2d0ab3fcd73..706ae248f0a 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -17,6 +17,7 @@ from pydantic import BaseModel from yarl import URL from configs import dify_config +from core.helper.http_client_pooling import get_pooled_http_client from core.plugin.endpoint.exc import EndpointSetupFailedError from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError from core.plugin.impl.exc import ( @@ -54,6 +55,11 @@ T = TypeVar("T", bound=(BaseModel | dict[str, Any] | list[Any] | bool | str)) logger = logging.getLogger(__name__) +_httpx_client: httpx.Client = get_pooled_http_client( + "plugin_daemon", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100), trust_env=False), +) + class BasePluginClient: def _request( @@ -84,7 +90,7 @@ class BasePluginClient: request_kwargs["content"] = prepared_data try: - response = httpx.request(**request_kwargs) + response = _httpx_client.request(**request_kwargs) except httpx.RequestError: logger.exception("Request to Plugin Daemon Service failed") raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed") @@ -171,7 +177,7 @@ class BasePluginClient: stream_kwargs["content"] = prepared_data try: - with httpx.stream(**stream_kwargs) as response: + with _httpx_client.stream(**stream_kwargs) as response: for raw_line in response.iter_lines(): if not raw_line: continue diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 06b17b9e62c..37114be6e72 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -6,11 +6,18 @@ import httpx from httpx import DigestAuth from configs import dify_config +from core.helper.http_client_pooling import get_pooled_http_client from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import TidbAuthBinding from models.enums import TidbAuthBindingStatus +# Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn +_tidb_http_client: httpx.Client = get_pooled_http_client( + "tidb:cloud", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + class TidbService: @staticmethod @@ -50,7 +57,9 @@ class TidbService: "rootPassword": password, } - response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key)) + response = _tidb_http_client.post( + f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key) + ) if response.status_code == 200: response_data = response.json() @@ -84,7 +93,9 @@ class TidbService: :return: The response from the API. """ - response = httpx.delete(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)) + response = _tidb_http_client.delete( + f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key) + ) if response.status_code == 200: return response.json() @@ -103,7 +114,7 @@ class TidbService: :return: The response from the API. """ - response = httpx.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)) + response = _tidb_http_client.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)) if response.status_code == 200: return response.json() @@ -128,7 +139,7 @@ class TidbService: body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []} - response = httpx.patch( + response = _tidb_http_client.patch( f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}", json=body, auth=DigestAuth(public_key, private_key), @@ -162,7 +173,9 @@ class TidbService: tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} cluster_ids = [item.cluster_id for item in tidb_serverless_list] params = {"clusterIds": cluster_ids, "view": "BASIC"} - response = httpx.get(f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key)) + response = _tidb_http_client.get( + f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key) + ) if response.status_code == 200: response_data = response.json() @@ -223,7 +236,7 @@ class TidbService: clusters.append(cluster_data) request_body = {"requests": clusters} - response = httpx.post( + response = _tidb_http_client.post( f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key) ) diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 76e741301cd..a2f11140333 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -7,6 +7,8 @@ from typing import NotRequired import httpx from pydantic import TypeAdapter, ValidationError +from core.helper.http_client_pooling import get_pooled_http_client + if sys.version_info >= (3, 12): from typing import TypedDict else: @@ -20,6 +22,12 @@ JsonObjectList = list[JsonObject] JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject) JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList) +# Reuse a pooled httpx.Client for OAuth flows (public endpoints, no SSRF proxy). +_http_client: httpx.Client = get_pooled_http_client( + "oauth:default", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + class AccessTokenResponse(TypedDict, total=False): access_token: str @@ -115,7 +123,7 @@ class GitHubOAuth(OAuth): "redirect_uri": self.redirect_uri, } headers = {"Accept": "application/json"} - response = httpx.post(self._TOKEN_URL, data=data, headers=headers) + response = _http_client.post(self._TOKEN_URL, data=data, headers=headers) response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response)) access_token = response_json.get("access_token") @@ -127,7 +135,7 @@ class GitHubOAuth(OAuth): def get_raw_user_info(self, token: str) -> JsonObject: headers = {"Authorization": f"token {token}"} - response = httpx.get(self._USER_INFO_URL, headers=headers) + response = _http_client.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response)) @@ -147,7 +155,7 @@ class GitHubOAuth(OAuth): Returns an empty string when no usable email is found. """ try: - email_response = httpx.get(GitHubOAuth._EMAIL_INFO_URL, headers=headers) + email_response = _http_client.get(GitHubOAuth._EMAIL_INFO_URL, headers=headers) email_response.raise_for_status() email_records = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) except (httpx.HTTPStatusError, ValidationError): @@ -204,7 +212,7 @@ class GoogleOAuth(OAuth): "redirect_uri": self.redirect_uri, } headers = {"Accept": "application/json"} - response = httpx.post(self._TOKEN_URL, data=data, headers=headers) + response = _http_client.post(self._TOKEN_URL, data=data, headers=headers) response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response)) access_token = response_json.get("access_token") @@ -216,7 +224,7 @@ class GoogleOAuth(OAuth): def get_raw_user_info(self, token: str) -> JsonObject: headers = {"Authorization": f"Bearer {token}"} - response = httpx.get(self._USER_INFO_URL, headers=headers) + response = _http_client.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() return _json_object(response) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index d5dc35ac977..190558e1f39 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -7,6 +7,7 @@ from flask_login import current_user from pydantic import TypeAdapter from sqlalchemy import select +from core.helper.http_client_pooling import get_pooled_http_client from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding @@ -38,6 +39,13 @@ NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo) NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary) +# Reuse a small pooled client for OAuth data source flows. +_http_client: httpx.Client = get_pooled_http_client( + "oauth:notion", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + + class OAuthDataSource: client_id: str client_secret: str @@ -75,7 +83,7 @@ class NotionOAuth(OAuthDataSource): data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} headers = {"Accept": "application/json"} auth = (self.client_id, self.client_secret) - response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) + response = _http_client.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) response_json = response.json() access_token = response_json.get("access_token") @@ -268,7 +276,7 @@ class NotionOAuth(OAuthDataSource): "Notion-Version": "2022-06-28", } - response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) + response = _http_client.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response_json = response.json() results.extend(response_json.get("results", [])) @@ -283,7 +291,7 @@ class NotionOAuth(OAuthDataSource): "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", } - response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) + response = _http_client.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) response_json = response.json() if response.status_code != 200: message = response_json.get("message", "unknown error") @@ -299,7 +307,7 @@ class NotionOAuth(OAuthDataSource): "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", } - response = httpx.get(url=self._NOTION_BOT_USER, headers=headers) + response = _http_client.get(url=self._NOTION_BOT_USER, headers=headers) response_json = response.json() if "object" in response_json and response_json["object"] == "user": user_type = response_json["type"] @@ -323,7 +331,7 @@ class NotionOAuth(OAuthDataSource): "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", } - response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) + response = _http_client.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response_json = response.json() results.extend(response_json.get("results", [])) diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index e5e2319ce13..e63c9a3a4db 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -2,8 +2,14 @@ import json import httpx +from core.helper.http_client_pooling import get_pooled_http_client from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials +_http_client: httpx.Client = get_pooled_http_client( + "auth:jina_standalone", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + class JinaAuth(ApiKeyAuthBase): def __init__(self, credentials: AuthCredentials): @@ -31,7 +37,7 @@ class JinaAuth(ApiKeyAuthBase): return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers): - return httpx.post(url, headers=headers, json=data) + return _http_client.post(url, headers=headers, json=data) def _handle_error(self, response): if response.status_code in {402, 409, 500}: diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index e5e2319ce13..8ea0b6cd69c 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -2,8 +2,14 @@ import json import httpx +from core.helper.http_client_pooling import get_pooled_http_client from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials +_http_client: httpx.Client = get_pooled_http_client( + "auth:jina", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + class JinaAuth(ApiKeyAuthBase): def __init__(self, credentials: AuthCredentials): @@ -31,7 +37,7 @@ class JinaAuth(ApiKeyAuthBase): return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers): - return httpx.post(url, headers=headers, json=data) + return _http_client.post(url, headers=headers, json=data) def _handle_error(self, response): if response.status_code in {402, 409, 500}: diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 70d4ce1ee6b..54c595e0cbd 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -10,6 +10,7 @@ from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fix from typing_extensions import TypedDict from werkzeug.exceptions import InternalServerError +from core.helper.http_client_pooling import get_pooled_http_client from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -18,6 +19,11 @@ from models import Account, TenantAccountJoin, TenantAccountRole logger = logging.getLogger(__name__) +_http_client: httpx.Client = get_pooled_http_client( + "billing:default", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + class SubscriptionPlan(TypedDict): """Tenant subscriptionplan information.""" @@ -131,7 +137,7 @@ class BillingService: headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" - response = httpx.request(method, url, json=json, params=params, headers=headers, follow_redirects=True) + response = _http_client.request(method, url, json=json, params=params, headers=headers, follow_redirects=True) if method == "GET" and response.status_code != httpx.codes.OK: raise ValueError("Unable to retrieve billing information. Please try again later or contact support.") if method == "PUT": diff --git a/api/services/website_service.py b/api/services/website_service.py index b2917ba1529..6a521a9cc0c 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -9,12 +9,23 @@ import httpx from flask_login import current_user from core.helper import encrypter +from core.helper.http_client_pooling import get_pooled_http_client from core.rag.extractor.firecrawl.firecrawl_app import CrawlStatusResponse, FirecrawlApp, FirecrawlDocumentData from core.rag.extractor.watercrawl.provider import WaterCrawlProvider from extensions.ext_redis import redis_client from extensions.ext_storage import storage from services.datasource_provider_service import DatasourceProviderService +# Reuse pooled HTTP clients to avoid creating new connections per request and ease testing. +_jina_http_client: httpx.Client = get_pooled_http_client( + "website:jinareader", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) +_adaptive_http_client: httpx.Client = get_pooled_http_client( + "website:adaptivecrawl", + lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)), +) + @dataclass class CrawlOptions: @@ -225,7 +236,7 @@ class WebsiteService: @classmethod def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]: if not request.options.crawl_sub_pages: - response = httpx.get( + response = _jina_http_client.get( f"https://r.jina.ai/{request.url}", headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, ) @@ -233,7 +244,7 @@ class WebsiteService: raise ValueError("Failed to crawl:") return {"status": "active", "data": response.json().get("data")} else: - response = httpx.post( + response = _adaptive_http_client.post( "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", json={ "url": request.url, @@ -296,7 +307,7 @@ class WebsiteService: @classmethod def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]: - response = httpx.post( + response = _adaptive_http_client.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id}, @@ -312,7 +323,7 @@ class WebsiteService: } if crawl_status_data["status"] == "completed": - response = httpx.post( + response = _adaptive_http_client.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, @@ -374,7 +385,7 @@ class WebsiteService: @classmethod def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None: if not job_id: - response = httpx.get( + response = _jina_http_client.get( f"https://r.jina.ai/{url}", headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, ) @@ -383,7 +394,7 @@ class WebsiteService: return dict(response.json().get("data", {})) else: # Get crawl status first - status_response = httpx.post( + status_response = _adaptive_http_client.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id}, @@ -393,7 +404,7 @@ class WebsiteService: raise ValueError("Crawl job is not completed") # Get processed data - data_response = httpx.post( + data_response = _adaptive_http_client.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())}, diff --git a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py index dc4c0fda1d4..f48c6da690f 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py @@ -79,7 +79,7 @@ class TestAuthIntegration: @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") @patch("services.auth.firecrawl.firecrawl.httpx.post") - @patch("services.auth.jina.jina.httpx.post") + @patch("services.auth.jina.jina._http_client.post") def test_multi_tenant_isolation( self, mock_jina_http, diff --git a/api/tests/unit_tests/core/datasource/test_website_crawl.py b/api/tests/unit_tests/core/datasource/test_website_crawl.py index 1d79db2640a..53000881ddc 100644 --- a/api/tests/unit_tests/core/datasource/test_website_crawl.py +++ b/api/tests/unit_tests/core/datasource/test_website_crawl.py @@ -560,7 +560,10 @@ class TestWebsiteService: mock_response = Mock() mock_response.json.return_value = {"code": 200, "data": {"taskId": "task-789"}} - mock_httpx_post = mocker.patch("services.website_service.httpx.post", return_value=mock_response) + mock_httpx_post = mocker.patch( + "services.website_service._adaptive_http_client.post", + return_value=mock_response, + ) from services.website_service import WebsiteCrawlApiRequest @@ -1340,7 +1343,7 @@ class TestProviderSpecificFeatures: "url": "https://example.com/page", }, } - mocker.patch("services.website_service.httpx.get", return_value=mock_response) + mocker.patch("services.website_service._jina_http_client.get", return_value=mock_response) from services.website_service import WebsiteCrawlApiRequest diff --git a/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py b/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py index c216906d68e..23894bd417b 100644 --- a/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py +++ b/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py @@ -57,7 +57,7 @@ class TestBasePluginClientImpl: def test_stream_request_handles_data_lines_and_dict_payload(self, mocker): client = BasePluginClient() stream_mock = mocker.patch( - "core.plugin.impl.base.httpx.stream", + "httpx.Client.stream", return_value=_StreamContext([b"", b"data: hello", "world"]), ) diff --git a/api/tests/unit_tests/core/plugin/test_endpoint_client.py b/api/tests/unit_tests/core/plugin/test_endpoint_client.py index 48e30e9c2fc..ff9deb918af 100644 --- a/api/tests/unit_tests/core/plugin/test_endpoint_client.py +++ b/api/tests/unit_tests/core/plugin/test_endpoint_client.py @@ -10,12 +10,23 @@ Tests follow the Arrange-Act-Assert pattern for clarity. from unittest.mock import MagicMock, patch +import httpx import pytest from core.plugin.impl.endpoint import PluginEndpointClient from core.plugin.impl.exc import PluginDaemonInternalServerError +@pytest.fixture(autouse=True) +def _patch_shared_httpx_client(): + """Patch module-level client methods to delegate to module httpx.request/stream.""" + with ( + patch("core.plugin.impl.base._httpx_client.request", side_effect=lambda **kw: httpx.request(**kw)), + patch("core.plugin.impl.base._httpx_client.stream", side_effect=lambda **kw: httpx.stream(**kw)), + ): + yield + + class TestPluginEndpointClientDelete: """Unit tests for PluginEndpointClient delete_endpoint operation. diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index 3063ca01970..a3b1e5f6b0e 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -47,6 +47,20 @@ from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager +@pytest.fixture(autouse=True) +def _patch_shared_httpx_client(): + """Make BasePluginClient's module-level httpx client delegate to patched httpx.request/stream. + + After refactor, code uses core.plugin.impl.base._httpx_client directly. + Patch its request/stream to route through module-level httpx so existing mocks still apply. + """ + with ( + patch("core.plugin.impl.base._httpx_client.request", side_effect=lambda **kw: httpx.request(**kw)), + patch("core.plugin.impl.base._httpx_client.stream", side_effect=lambda **kw: httpx.stream(**kw)), + ): + yield + + class TestPluginRuntimeExecution: """Unit tests for plugin execution functionality. diff --git a/api/tests/unit_tests/libs/test_oauth_clients.py b/api/tests/unit_tests/libs/test_oauth_clients.py index ab468c86874..830284e6979 100644 --- a/api/tests/unit_tests/libs/test_oauth_clients.py +++ b/api/tests/unit_tests/libs/test_oauth_clients.py @@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest): ({}, None, True), ], ) - @patch("httpx.post", autospec=True) + @patch("libs.oauth._http_client.post", autospec=True) def test_should_retrieve_access_token( self, mock_post, oauth, mock_response, response_data, expected_token, should_raise ): @@ -109,7 +109,7 @@ class TestGitHubOAuth(BaseOAuthTest): ), ], ) - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email): user_response = MagicMock() user_response.json.return_value = user_data @@ -127,7 +127,7 @@ class TestGitHubOAuth(BaseOAuthTest): # The profile email is absent/null, so /user/emails should be called assert mock_get.call_count == 2 - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_skip_email_endpoint_when_profile_email_present(self, mock_get, oauth): """When the /user profile already contains an email, do not call /user/emails.""" user_response = MagicMock() @@ -162,7 +162,7 @@ class TestGitHubOAuth(BaseOAuthTest): ), ], ) - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_use_noreply_email_when_no_usable_email(self, mock_get, oauth, user_data, email_data): user_response = MagicMock() user_response.json.return_value = user_data @@ -177,7 +177,7 @@ class TestGitHubOAuth(BaseOAuthTest): assert user_info.id == str(user_data["id"]) assert user_info.email == "12345@users.noreply.github.com" - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_use_noreply_email_when_email_endpoint_fails(self, mock_get, oauth): user_response = MagicMock() user_response.json.return_value = {"id": 12345, "login": "testuser", "name": "Test User"} @@ -194,7 +194,7 @@ class TestGitHubOAuth(BaseOAuthTest): assert user_info.id == "12345" assert user_info.email == "12345@users.noreply.github.com" - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_handle_network_errors(self, mock_get, oauth): mock_get.side_effect = httpx.RequestError("Network error") @@ -240,7 +240,7 @@ class TestGoogleOAuth(BaseOAuthTest): ({}, None, True), ], ) - @patch("httpx.post", autospec=True) + @patch("libs.oauth._http_client.post", autospec=True) def test_should_retrieve_access_token( self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise ): @@ -274,7 +274,7 @@ class TestGoogleOAuth(BaseOAuthTest): ({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string ], ) - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name): mock_response.json.return_value = user_data mock_get.return_value = mock_response @@ -295,7 +295,7 @@ class TestGoogleOAuth(BaseOAuthTest): httpx.TimeoutException, ], ) - @patch("httpx.get", autospec=True) + @patch("libs.oauth._http_client.get", autospec=True) def test_should_handle_http_errors(self, mock_get, oauth, exception_type): mock_response = MagicMock() mock_response.raise_for_status.side_effect = exception_type("Error") diff --git a/api/tests/unit_tests/services/auth/test_jina_auth.py b/api/tests/unit_tests/services/auth/test_jina_auth.py index 67f252390d5..2c34d46f1e4 100644 --- a/api/tests/unit_tests/services/auth/test_jina_auth.py +++ b/api/tests/unit_tests/services/auth/test_jina_auth.py @@ -35,7 +35,7 @@ class TestJinaAuth: JinaAuth(credentials) assert str(exc_info.value) == "No API key provided" - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_validate_valid_credentials_successfully(self, mock_post): """Test successful credential validation""" mock_response = MagicMock() @@ -53,7 +53,7 @@ class TestJinaAuth: json={"url": "https://example.com"}, ) - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_handle_http_402_error(self, mock_post): """Test handling of 402 Payment Required error""" mock_response = MagicMock() @@ -68,7 +68,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required" - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_handle_http_409_error(self, mock_post): """Test handling of 409 Conflict error""" mock_response = MagicMock() @@ -83,7 +83,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error" - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_handle_http_500_error(self, mock_post): """Test handling of 500 Internal Server Error""" mock_response = MagicMock() @@ -98,7 +98,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error" - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_handle_unexpected_error_with_text_response(self, mock_post): """Test handling of unexpected errors with text response""" mock_response = MagicMock() @@ -114,7 +114,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden" - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_handle_unexpected_error_without_text(self, mock_post): """Test handling of unexpected errors without text response""" mock_response = MagicMock() @@ -130,7 +130,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404" - @patch("services.auth.jina.jina.httpx.post", autospec=True) + @patch("services.auth.jina.jina._http_client.post", autospec=True) def test_should_handle_network_errors(self, mock_post): """Test handling of network connection errors""" mock_post.side_effect = httpx.ConnectError("Network error") diff --git a/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py index c2fcd71875c..4b5a97bf3fd 100644 --- a/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py +++ b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py @@ -60,7 +60,7 @@ def test_prepare_headers_includes_bearer_api_key(jina_module: ModuleType) -> Non def test_post_request_calls_httpx(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None: auth = jina_module.JinaAuth(_credentials(api_key="k")) post_mock = MagicMock(name="httpx.post") - monkeypatch.setattr(jina_module.httpx, "post", post_mock) + monkeypatch.setattr(jina_module._http_client, "post", post_mock) auth._post_request("https://r.jina.ai", {"url": "https://example.com"}, {"h": "v"}) post_mock.assert_called_once_with("https://r.jina.ai", headers={"h": "v"}, json={"url": "https://example.com"}) @@ -72,7 +72,7 @@ def test_validate_credentials_success(jina_module: ModuleType, monkeypatch: pyte response = MagicMock() response.status_code = 200 post_mock = MagicMock(return_value=response) - monkeypatch.setattr(jina_module.httpx, "post", post_mock) + monkeypatch.setattr(jina_module._http_client, "post", post_mock) assert auth.validate_credentials() is True post_mock.assert_called_once_with( @@ -90,7 +90,7 @@ def test_validate_credentials_non_200_raises_via_handle_error( response = MagicMock() response.status_code = 402 response.json.return_value = {"error": "Payment required"} - monkeypatch.setattr(jina_module.httpx, "post", MagicMock(return_value=response)) + monkeypatch.setattr(jina_module._http_client, "post", MagicMock(return_value=response)) with pytest.raises(Exception, match="Status code: 402.*Payment required"): auth.validate_credentials() @@ -151,7 +151,7 @@ def test_validate_credentials_propagates_network_errors( jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch ) -> None: auth = jina_module.JinaAuth(_credentials(api_key="k")) - monkeypatch.setattr(jina_module.httpx, "post", MagicMock(side_effect=httpx.ConnectError("boom"))) + monkeypatch.setattr(jina_module._http_client, "post", MagicMock(side_effect=httpx.ConnectError("boom"))) with pytest.raises(httpx.ConnectError, match="boom"): auth.validate_credentials() diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 316381f0ca1..b3d2e608025 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -38,7 +38,7 @@ class TestBillingServiceSendRequest: @pytest.fixture def mock_httpx_request(self): """Mock httpx.request for testing.""" - with patch("services.billing_service.httpx.request") as mock_request: + with patch("services.billing_service._http_client.request") as mock_request: yield mock_request @pytest.fixture diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index 3df7d500cf2..da414816ff1 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock, patch +import httpx import pytest from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy.orm import Session @@ -71,6 +72,8 @@ class TestDatasourceProviderService: @pytest.fixture(autouse=True) def patch_externals(self): with ( + patch("core.plugin.impl.base._httpx_client.request", side_effect=lambda **kw: httpx.request(**kw)), + patch("core.plugin.impl.base._httpx_client.stream", side_effect=lambda **kw: httpx.stream(**kw)), patch("httpx.request") as mock_httpx, patch("services.datasource_provider_service.dify_config") as mock_cfg, patch("services.datasource_provider_service.encrypter") as mock_enc, diff --git a/api/tests/unit_tests/services/test_website_service.py b/api/tests/unit_tests/services/test_website_service.py index e973da7d564..b0ddc7388a1 100644 --- a/api/tests/unit_tests/services/test_website_service.py +++ b/api/tests/unit_tests/services/test_website_service.py @@ -343,7 +343,7 @@ def test_crawl_with_watercrawl_passes_options_dict(monkeypatch: pytest.MonkeyPat def test_crawl_with_jinareader_single_page_success(monkeypatch: pytest.MonkeyPatch) -> None: get_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"title": "t"}})) - monkeypatch.setattr(website_service_module.httpx, "get", get_mock) + monkeypatch.setattr(website_service_module._jina_http_client, "get", get_mock) req = WebsiteCrawlApiRequest( provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False} @@ -356,7 +356,11 @@ def test_crawl_with_jinareader_single_page_success(monkeypatch: pytest.MonkeyPat def test_crawl_with_jinareader_single_page_failure(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500}))) + monkeypatch.setattr( + website_service_module._jina_http_client, + "get", + MagicMock(return_value=_DummyHttpxResponse({"code": 500})), + ) req = WebsiteCrawlApiRequest( provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False} ).to_crawl_request() @@ -368,7 +372,7 @@ def test_crawl_with_jinareader_single_page_failure(monkeypatch: pytest.MonkeyPat def test_crawl_with_jinareader_multi_page_success(monkeypatch: pytest.MonkeyPatch) -> None: post_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"taskId": "t1"}})) - monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock) req = WebsiteCrawlApiRequest( provider="jinareader", @@ -384,7 +388,7 @@ def test_crawl_with_jinareader_multi_page_success(monkeypatch: pytest.MonkeyPatc def test_crawl_with_jinareader_multi_page_failure(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( - website_service_module.httpx, "post", MagicMock(return_value=_DummyHttpxResponse({"code": 400})) + website_service_module._adaptive_http_client, "post", MagicMock(return_value=_DummyHttpxResponse({"code": 400})) ) req = WebsiteCrawlApiRequest( provider="jinareader", @@ -482,7 +486,7 @@ def test_get_jinareader_status_active(monkeypatch: pytest.MonkeyPatch) -> None: } ) ) - monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock) result = WebsiteService._get_jinareader_status("job-1", "k") assert result["status"] == "active" @@ -518,7 +522,7 @@ def test_get_jinareader_status_completed_formats_processed_items(monkeypatch: py } } post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) - monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock) result = WebsiteService._get_jinareader_status("job-1", "k") assert result["status"] == "completed" @@ -619,7 +623,7 @@ def test_get_watercrawl_url_data_delegates(monkeypatch: pytest.MonkeyPatch) -> N def test_get_jinareader_url_data_without_job_id_success(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( - website_service_module.httpx, + website_service_module._jina_http_client, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"url": "u"}})), ) @@ -627,7 +631,11 @@ def test_get_jinareader_url_data_without_job_id_success(monkeypatch: pytest.Monk def test_get_jinareader_url_data_without_job_id_failure(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500}))) + monkeypatch.setattr( + website_service_module._jina_http_client, + "get", + MagicMock(return_value=_DummyHttpxResponse({"code": 500})), + ) with pytest.raises(ValueError, match="Failed to crawl$"): WebsiteService._get_jinareader_url_data("", "u", "k") @@ -637,7 +645,7 @@ def test_get_jinareader_url_data_with_job_id_completed_returns_matching_item(mon processed_payload = {"data": {"processed": {"u1": {"data": {"url": "u", "title": "t"}}}}} post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) - monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock) assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") == {"url": "u", "title": "t"} assert post_mock.call_count == 2 @@ -645,7 +653,7 @@ def test_get_jinareader_url_data_with_job_id_completed_returns_matching_item(mon def test_get_jinareader_url_data_with_job_id_not_completed_raises(monkeypatch: pytest.MonkeyPatch) -> None: post_mock = MagicMock(return_value=_DummyHttpxResponse({"data": {"status": "active"}})) - monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock) with pytest.raises(ValueError, match=r"Crawl job is no\s*t completed"): WebsiteService._get_jinareader_url_data("job-1", "u", "k") @@ -658,7 +666,7 @@ def test_get_jinareader_url_data_with_job_id_completed_but_not_found_returns_non processed_payload = {"data": {"processed": {"u1": {"data": {"url": "other"}}}}} post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) - monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + monkeypatch.setattr(website_service_module._adaptive_http_client, "post", post_mock) assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") is None From d2baacdd4b7f3a717fa7e19820300d9ae63d4d5f Mon Sep 17 00:00:00 2001 From: lif <1835304752@qq.com> Date: Wed, 1 Apr 2026 09:31:42 +0800 Subject: [PATCH 02/42] feat(docker): add healthcheck for api, worker, and worker_beat services (#34345) Signed-off-by: majiayu000 <1835304752@qq.com> --- docker/docker-compose-template.yaml | 18 ++++++++++++++++++ docker/docker-compose.yaml | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index e55cf942c32..57584cb8295 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -56,6 +56,12 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:5001/health"] + interval: 30s + timeout: 5s + retries: 3 + start_period: 30s networks: - ssrf_proxy_network - default @@ -95,6 +101,12 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage + healthcheck: + test: ["CMD-SHELL", "celery -A celery_entrypoint.celery inspect ping"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 60s networks: - ssrf_proxy_network - default @@ -126,6 +138,12 @@ services: required: false redis: condition: service_started + healthcheck: + test: ["CMD-SHELL", "celery -A app.celery inspect ping"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 60s networks: - ssrf_proxy_network - default diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index ed68107f46c..097fadc9593 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -765,6 +765,12 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:5001/health"] + interval: 30s + timeout: 5s + retries: 3 + start_period: 30s networks: - ssrf_proxy_network - default @@ -804,6 +810,12 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage + healthcheck: + test: ["CMD-SHELL", "celery -A celery_entrypoint.celery inspect ping"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 60s networks: - ssrf_proxy_network - default @@ -835,6 +847,12 @@ services: required: false redis: condition: service_started + healthcheck: + test: ["CMD-SHELL", "celery -A app.celery inspect ping"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 60s networks: - ssrf_proxy_network - default From 324b47507c0781567b8b441477aeb78216241719 Mon Sep 17 00:00:00 2001 From: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Date: Wed, 1 Apr 2026 09:50:02 +0800 Subject: [PATCH 03/42] refactor: enhance ELK layout handling (#34334) --- .../utils/__tests__/elk-layout.spec.ts | 275 ++++++++++++++++++ .../components/workflow/utils/elk-layout.ts | 135 +++++---- 2 files changed, 357 insertions(+), 53 deletions(-) diff --git a/web/app/components/workflow/utils/__tests__/elk-layout.spec.ts b/web/app/components/workflow/utils/__tests__/elk-layout.spec.ts index 1a3c52ec2d5..54eb289abe3 100644 --- a/web/app/components/workflow/utils/__tests__/elk-layout.spec.ts +++ b/web/app/components/workflow/utils/__tests__/elk-layout.spec.ts @@ -486,6 +486,242 @@ describe('getLayoutByELK', () => { expect(hiNode.ports).toHaveLength(2) }) + it('should build ports for QuestionClassifier sorted by classes order', async () => { + const nodes = [ + makeWorkflowNode({ + id: 'qc-1', + data: { + type: BlockEnum.QuestionClassifier, + title: '', + desc: '', + classes: [{ id: 'cls-a', name: 'A' }, { id: 'cls-b', name: 'B' }, { id: 'cls-c', name: 'C' }], + }, + }), + makeWorkflowNode({ id: 'x', data: { type: BlockEnum.Code, title: '', desc: '' } }), + makeWorkflowNode({ id: 'y', data: { type: BlockEnum.Code, title: '', desc: '' } }), + makeWorkflowNode({ id: 'z', data: { type: BlockEnum.Code, title: '', desc: '' } }), + ] + const edges = [ + makeWorkflowEdge({ id: 'e-c', source: 'qc-1', target: 'z', sourceHandle: 'cls-c' }), + makeWorkflowEdge({ id: 'e-a', source: 'qc-1', target: 'x', sourceHandle: 'cls-a' }), + makeWorkflowEdge({ id: 'e-b', source: 'qc-1', target: 'y', sourceHandle: 'cls-b' }), + ] + + await getLayoutByELK(nodes, edges) + const qcNode = layoutCallArgs!.children!.find((c: ElkChild) => c.id === 'qc-1')! + const portIds = qcNode.ports!.map((p: { id: string }) => p.id) + expect(portIds).toEqual([ + 'qc-1-out-cls-a', + 'qc-1-out-cls-b', + 'qc-1-out-cls-c', + ]) + }) + + it('should build ports for QuestionClassifier with single class', async () => { + const nodes = [ + makeWorkflowNode({ + id: 'qc-1', + data: { + type: BlockEnum.QuestionClassifier, + title: '', + desc: '', + classes: [{ id: 'cls-only', name: 'Only' }], + }, + }), + makeWorkflowNode({ id: 'x', data: { type: BlockEnum.Code, title: '', desc: '' } }), + ] + const edges = [ + makeWorkflowEdge({ id: 'e1', source: 'qc-1', target: 'x', sourceHandle: 'cls-only' }), + ] + + await getLayoutByELK(nodes, edges) + const qcNode = layoutCallArgs!.children!.find((c: ElkChild) => c.id === 'qc-1')! + expect(qcNode.ports).toHaveLength(1) + expect(qcNode.ports![0].layoutOptions!['elk.port.side']).toBe('EAST') + }) + + it('should only create output (EAST) ports, not input (WEST) ports', async () => { + const nodes = [ + makeWorkflowNode({ id: 'a', data: { type: BlockEnum.Start, title: '', desc: '' } }), + makeWorkflowNode({ id: 'b', data: { type: BlockEnum.Code, title: '', desc: '' } }), + makeWorkflowNode({ id: 'c', data: { type: BlockEnum.End, title: '', desc: '' } }), + ] + const edges = [ + makeWorkflowEdge({ id: 'e1', source: 'a', target: 'b' }), + makeWorkflowEdge({ id: 'e2', source: 'b', target: 'c' }), + ] + + await getLayoutByELK(nodes, edges) + layoutCallArgs!.children!.forEach((child: ElkChild) => { + if (child.ports) { + child.ports.forEach((port) => { + expect(port.layoutOptions!['elk.port.side']).toBe('EAST') + }) + } + }) + const endNode = layoutCallArgs!.children!.find((c: ElkChild) => c.id === 'c')! + expect(endNode.ports).toBeUndefined() + }) + + it('should order children array by DFS following port order', async () => { + const nodes = [ + makeWorkflowNode({ + id: 'if-1', + data: { + type: BlockEnum.IfElse, + title: '', + desc: '', + cases: [{ case_id: 'case-a', logical_operator: 'and', conditions: [] }], + }, + }), + makeWorkflowNode({ id: 'start', data: { type: BlockEnum.Start, title: '', desc: '' } }), + makeWorkflowNode({ id: 'branch-a', data: { type: BlockEnum.Code, title: '', desc: '' } }), + makeWorkflowNode({ id: 'branch-else', data: { type: BlockEnum.Code, title: '', desc: '' } }), + makeWorkflowNode({ id: 'end', data: { type: BlockEnum.End, title: '', desc: '' } }), + ] + const edges = [ + makeWorkflowEdge({ source: 'start', target: 'if-1' }), + makeWorkflowEdge({ id: 'e-else', source: 'if-1', target: 'branch-else', sourceHandle: 'false' }), + makeWorkflowEdge({ id: 'e-a', source: 'if-1', target: 'branch-a', sourceHandle: 'case-a' }), + makeWorkflowEdge({ source: 'branch-a', target: 'end' }), + makeWorkflowEdge({ source: 'branch-else', target: 'end' }), + ] + + await getLayoutByELK(nodes, edges) + const childIds = layoutCallArgs!.children!.map((c: ElkChild) => c.id) + // DFS from start: start → if-1 → branch-a (case-a first) → end → branch-else + const idxA = childIds.indexOf('branch-a') + const idxElse = childIds.indexOf('branch-else') + expect(idxA).toBeLessThan(idxElse) + }) + + it('should order children by DFS across nested branching nodes', async () => { + const nodes = [ + makeWorkflowNode({ id: 'start', data: { type: BlockEnum.Start, title: '', desc: '' } }), + makeWorkflowNode({ + id: 'qc-1', + data: { + type: BlockEnum.QuestionClassifier, + title: '', + desc: '', + classes: [{ id: 'c1', name: 'C1' }, { id: 'c2', name: 'C2' }], + }, + }), + makeWorkflowNode({ id: 'upper', data: { type: BlockEnum.Code, title: '', desc: '' } }), + makeWorkflowNode({ id: 'lower', data: { type: BlockEnum.Code, title: '', desc: '' } }), + ] + const edges = [ + makeWorkflowEdge({ source: 'start', target: 'qc-1' }), + makeWorkflowEdge({ id: 'e-c2', source: 'qc-1', target: 'lower', sourceHandle: 'c2' }), + makeWorkflowEdge({ id: 'e-c1', source: 'qc-1', target: 'upper', sourceHandle: 'c1' }), + ] + + await getLayoutByELK(nodes, edges) + const childIds = layoutCallArgs!.children!.map((c: ElkChild) => c.id) + // DFS: start → qc-1 → upper (c1 first) → lower (c2 second) + expect(childIds.indexOf('upper')).toBeLessThan(childIds.indexOf('lower')) + }) + + it('should handle QuestionClassifier with no classes property', async () => { + const nodes = [ + makeWorkflowNode({ id: 'qc-1', data: { type: BlockEnum.QuestionClassifier, title: '', desc: '' } }), + makeWorkflowNode({ id: 'b', data: { type: BlockEnum.Code, title: '', desc: '' } }), + makeWorkflowNode({ id: 'c', data: { type: BlockEnum.Code, title: '', desc: '' } }), + ] + const edges = [ + makeWorkflowEdge({ id: 'e1', source: 'qc-1', target: 'b', sourceHandle: 'cls-1' }), + makeWorkflowEdge({ id: 'e2', source: 'qc-1', target: 'c', sourceHandle: 'cls-2' }), + ] + + await getLayoutByELK(nodes, edges) + const qcNode = layoutCallArgs!.children!.find((c: ElkChild) => c.id === 'qc-1')! + expect(qcNode.ports).toHaveLength(2) + }) + + it('should handle QuestionClassifier edges where handle not found in classes', async () => { + const nodes = [ + makeWorkflowNode({ + id: 'qc-1', + data: { type: BlockEnum.QuestionClassifier, title: '', desc: '', classes: [{ id: 'known', name: 'K' }] }, + }), + makeWorkflowNode({ id: 'b', data: { type: BlockEnum.Code, title: '', desc: '' } }), + makeWorkflowNode({ id: 'c', data: { type: BlockEnum.Code, title: '', desc: '' } }), + ] + const edges = [ + makeWorkflowEdge({ id: 'e1', source: 'qc-1', target: 'b', sourceHandle: 'unknown-1' }), + makeWorkflowEdge({ id: 'e2', source: 'qc-1', target: 'c', sourceHandle: 'unknown-2' }), + ] + + await getLayoutByELK(nodes, edges) + const qcNode = layoutCallArgs!.children!.find((c: ElkChild) => c.id === 'qc-1')! + expect(qcNode.ports).toHaveLength(2) + }) + + it('should include disconnected nodes in the layout', async () => { + const nodes = [ + makeWorkflowNode({ id: 'start', data: { type: BlockEnum.Start, title: '', desc: '' } }), + makeWorkflowNode({ id: 'connected', data: { type: BlockEnum.Code, title: '', desc: '' } }), + makeWorkflowNode({ id: 'isolated', data: { type: BlockEnum.Code, title: '', desc: '' } }), + ] + const edges = [ + makeWorkflowEdge({ source: 'start', target: 'connected' }), + ] + + await getLayoutByELK(nodes, edges) + const childIds = layoutCallArgs!.children!.map((c: ElkChild) => c.id) + expect(childIds).toContain('isolated') + expect(childIds).toHaveLength(3) + }) + + it('should build edges in DFS order matching port order', async () => { + const nodes = [ + makeWorkflowNode({ id: 'start', data: { type: BlockEnum.Start, title: '', desc: '' } }), + makeWorkflowNode({ + id: 'if-1', + data: { type: BlockEnum.IfElse, title: '', desc: '', cases: [{ case_id: 'case-a' }] }, + }), + makeWorkflowNode({ id: 'a', data: { type: BlockEnum.Code, title: '', desc: '' } }), + makeWorkflowNode({ id: 'b', data: { type: BlockEnum.Code, title: '', desc: '' } }), + ] + const edges = [ + makeWorkflowEdge({ source: 'start', target: 'if-1' }), + makeWorkflowEdge({ id: 'e-else', source: 'if-1', target: 'b', sourceHandle: 'false' }), + makeWorkflowEdge({ id: 'e-a', source: 'if-1', target: 'a', sourceHandle: 'case-a' }), + ] + + await getLayoutByELK(nodes, edges) + const elkEdges = layoutCallArgs!.edges as Array<{ sources: string[], targets: string[] }> + const ifEdges = elkEdges.filter(e => e.sources[0] === 'if-1') + expect(ifEdges[0].targets[0]).toBe('a') + expect(ifEdges[1].targets[0]).toBe('b') + }) + + it('should keep edges for components where every node has an incoming edge', async () => { + const nodes = [ + makeWorkflowNode({ + id: 'if-1', + data: { type: BlockEnum.IfElse, title: '', desc: '', cases: [{ case_id: 'case-a' }] }, + }), + makeWorkflowNode({ id: 'a', data: { type: BlockEnum.Code, title: '', desc: '' } }), + makeWorkflowNode({ id: 'b', data: { type: BlockEnum.Code, title: '', desc: '' } }), + ] + const edges = [ + makeWorkflowEdge({ id: 'e-a', source: 'if-1', target: 'a', sourceHandle: 'case-a' }), + makeWorkflowEdge({ id: 'e-b', source: 'if-1', target: 'b', sourceHandle: 'false' }), + makeWorkflowEdge({ id: 'e-back', source: 'a', target: 'if-1' }), + ] + + await getLayoutByELK(nodes, edges) + + const elkEdges = layoutCallArgs!.edges as Array<{ sources: string[], targets: string[] }> + expect(elkEdges).toHaveLength(3) + expect(elkEdges).toEqual(expect.arrayContaining([ + expect.objectContaining({ sources: ['if-1'], targets: ['a'] }), + expect.objectContaining({ sources: ['if-1'], targets: ['b'] }), + expect.objectContaining({ sources: ['a'], targets: ['if-1'] }), + ])) + }) + it('should filter loop internal edges', async () => { const nodes = [ makeWorkflowNode({ id: 'a', data: { type: BlockEnum.Start, title: '', desc: '' } }), @@ -650,6 +886,45 @@ describe('getLayoutForChildNodes', () => { expect(result!.nodes.size).toBe(2) }) + it('should build ports and DFS-order for branching nodes inside iteration', async () => { + const nodes = [ + makeWorkflowNode({ id: 'parent', data: { type: BlockEnum.Iteration, title: '', desc: '' } }), + makeWorkflowNode({ + id: 'iter-start', + type: CUSTOM_ITERATION_START_NODE, + parentId: 'parent', + data: { type: BlockEnum.IterationStart, title: '', desc: '' }, + }), + makeWorkflowNode({ + id: 'qc-child', + parentId: 'parent', + data: { + type: BlockEnum.QuestionClassifier, + title: '', + desc: '', + classes: [{ id: 'cls-1', name: 'C1' }, { id: 'cls-2', name: 'C2' }], + }, + }), + makeWorkflowNode({ id: 'upper', parentId: 'parent', data: { type: BlockEnum.Code, title: '', desc: '' } }), + makeWorkflowNode({ id: 'lower', parentId: 'parent', data: { type: BlockEnum.Code, title: '', desc: '' } }), + ] + const edges = [ + makeWorkflowEdge({ source: 'iter-start', target: 'qc-child', data: { isInIteration: true, iteration_id: 'parent' } }), + makeWorkflowEdge({ id: 'e-c2', source: 'qc-child', target: 'lower', sourceHandle: 'cls-2', data: { isInIteration: true, iteration_id: 'parent' } }), + makeWorkflowEdge({ id: 'e-c1', source: 'qc-child', target: 'upper', sourceHandle: 'cls-1', data: { isInIteration: true, iteration_id: 'parent' } }), + ] + + await getLayoutForChildNodes('parent', nodes, edges) + + const qcElk = layoutCallArgs!.children!.find((c: ElkChild) => c.id === 'qc-child')! + expect(qcElk.ports).toHaveLength(2) + expect(qcElk.ports![0].id).toContain('cls-1') + expect(qcElk.ports![1].id).toContain('cls-2') + + const childIds = layoutCallArgs!.children!.map((c: ElkChild) => c.id) + expect(childIds.indexOf('upper')).toBeLessThan(childIds.indexOf('lower')) + }) + it('should return original layout when bounds are not finite', async () => { mockReturnOverride = (graph: ElkGraph) => ({ ...graph, diff --git a/web/app/components/workflow/utils/elk-layout.ts b/web/app/components/workflow/utils/elk-layout.ts index 9860bbc7708..781416f3c40 100644 --- a/web/app/components/workflow/utils/elk-layout.ts +++ b/web/app/components/workflow/utils/elk-layout.ts @@ -1,6 +1,7 @@ import type { ElkNode, LayoutOptions } from 'elkjs/lib/elk-api' import type { HumanInputNodeType } from '@/app/components/workflow/nodes/human-input/types' import type { CaseItem, IfElseNodeType } from '@/app/components/workflow/nodes/if-else/types' +import type { QuestionClassifierNodeType, Topic } from '@/app/components/workflow/nodes/question-classifier/types' import type { Edge, Node, @@ -37,13 +38,13 @@ const ROOT_LAYOUT_OPTIONS = { // === Port Configuration === 'elk.portConstraints': 'FIXED_ORDER', - 'elk.layered.considerModelOrder.strategy': 'PREFER_EDGES', + 'elk.layered.considerModelOrder.strategy': 'NODES_AND_EDGES', + 'elk.layered.crossingMinimization.forceNodeModelOrder': 'true', - // === Node Placement - Best quality === - 'elk.layered.nodePlacement.strategy': 'NETWORK_SIMPLEX', + // === Node Placement - Balanced centering === + 'elk.layered.nodePlacement.strategy': 'BRANDES_KOEPF', 'elk.layered.nodePlacement.favorStraightEdges': 'true', - 'elk.layered.nodePlacement.linearSegments.deflectionDampening': '0.5', - 'elk.layered.nodePlacement.networkSimplex.nodeFlexibility': 'NODE_SIZE', + 'elk.layered.nodePlacement.bk.fixedAlignment': 'BALANCED', // === Edge Routing - Maximum quality === 'elk.edgeRouting': 'SPLINES', @@ -56,7 +57,7 @@ const ROOT_LAYOUT_OPTIONS = { 'elk.layered.crossingMinimization.strategy': 'LAYER_SWEEP', 'elk.layered.crossingMinimization.greedySwitch.type': 'TWO_SIDED', 'elk.layered.crossingMinimization.greedySwitchHierarchical.type': 'TWO_SIDED', - 'elk.layered.crossingMinimization.semiInteractive': 'true', + 'elk.layered.crossingMinimization.semiInteractive': 'false', 'elk.layered.crossingMinimization.hierarchicalSweepiness': '0.9', // === Layering Strategy - Best quality === @@ -115,11 +116,15 @@ const CHILD_LAYOUT_OPTIONS = { 'elk.spacing.edgeLabel': '8', 'elk.spacing.portPort': '15', - // === Node Placement - Best quality === - 'elk.layered.nodePlacement.strategy': 'NETWORK_SIMPLEX', + // === Port Configuration === + 'elk.portConstraints': 'FIXED_ORDER', + 'elk.layered.considerModelOrder.strategy': 'NODES_AND_EDGES', + 'elk.layered.crossingMinimization.forceNodeModelOrder': 'true', + + // === Node Placement - Balanced centering === + 'elk.layered.nodePlacement.strategy': 'BRANDES_KOEPF', 'elk.layered.nodePlacement.favorStraightEdges': 'true', - 'elk.layered.nodePlacement.linearSegments.deflectionDampening': '0.5', - 'elk.layered.nodePlacement.networkSimplex.nodeFlexibility': 'NODE_SIZE', + 'elk.layered.nodePlacement.bk.fixedAlignment': 'BALANCED', // === Edge Routing - Maximum quality === 'elk.edgeRouting': 'SPLINES', @@ -129,7 +134,7 @@ const CHILD_LAYOUT_OPTIONS = { // === Crossing Minimization - Aggressive === 'elk.layered.crossingMinimization.strategy': 'LAYER_SWEEP', 'elk.layered.crossingMinimization.greedySwitch.type': 'TWO_SIDED', - 'elk.layered.crossingMinimization.semiInteractive': 'true', + 'elk.layered.crossingMinimization.semiInteractive': 'false', // === Layering Strategy === 'elk.layered.layering.strategy': 'NETWORK_SIMPLEX', @@ -197,12 +202,6 @@ type ElkEdgeShape = { targetPort?: string } -const toElkNode = (node: Node): ElkNodeShape => ({ - id: node.id, - width: node.width ?? DEFAULT_NODE_WIDTH, - height: node.height ?? DEFAULT_NODE_HEIGHT, -}) - let edgeCounter = 0 const nextEdgeId = () => `elk-edge-${edgeCounter++}` @@ -297,6 +296,24 @@ const sortIfElseOutEdges = (ifElseNode: Node, outEdges: Edge[]): Edge[] => { }) } +const sortQuestionClassifierOutEdges = (classifierNode: Node, outEdges: Edge[]): Edge[] => { + return [...outEdges].sort((edgeA, edgeB) => { + const handleA = edgeA.sourceHandle + const handleB = edgeB.sourceHandle + + if (handleA && handleB) { + const classes = (classifierNode.data as QuestionClassifierNodeType).classes || [] + const indexA = classes.findIndex((t: Topic) => t.id === handleA) + const indexB = classes.findIndex((t: Topic) => t.id === handleB) + + if (indexA !== -1 && indexB !== -1) + return indexA - indexB + } + + return 0 + }) +} + const sortHumanInputOutEdges = (humanInputNode: Node, outEdges: Edge[]): Edge[] => { return [...outEdges].sort((edgeA, edgeB) => { const handleA = edgeA.sourceHandle @@ -352,63 +369,45 @@ const normaliseBounds = (layout: LayoutResult): LayoutResult => { } } -export const getLayoutByELK = async (originNodes: Node[], originEdges: Edge[]): Promise => { - edgeCounter = 0 - const nodes = cloneDeep(originNodes).filter(node => !node.parentId && node.type === CUSTOM_NODE) - const edges = cloneDeep(originEdges).filter(edge => (!edge.data?.isInIteration && !edge.data?.isInLoop)) - +/** + * Build ELK nodes with output ports (sorted for branching types) + * and edges ordered by a DFS traversal that follows port order. + */ +const buildPortAwareGraph = (nodes: Node[], edges: Edge[]) => { const outEdgesByNode = new Map() - const inEdgesByNode = new Map() edges.forEach((edge) => { if (!outEdgesByNode.has(edge.source)) outEdgesByNode.set(edge.source, []) outEdgesByNode.get(edge.source)!.push(edge) - if (!inEdgesByNode.has(edge.target)) - inEdgesByNode.set(edge.target, []) - inEdgesByNode.get(edge.target)!.push(edge) }) const elkNodes: ElkNodeShape[] = [] const elkEdges: ElkEdgeShape[] = [] const sourcePortMap = new Map() - const targetPortMap = new Map() const sortedOutEdgesByNode = new Map() nodes.forEach((node) => { - const inEdges = inEdgesByNode.get(node.id) || [] let outEdges = outEdgesByNode.get(node.id) || [] if (node.data.type === BlockEnum.IfElse) outEdges = sortIfElseOutEdges(node, outEdges) + else if (node.data.type === BlockEnum.QuestionClassifier) + outEdges = sortQuestionClassifierOutEdges(node, outEdges) else if (node.data.type === BlockEnum.HumanInput) outEdges = sortHumanInputOutEdges(node, outEdges) sortedOutEdgesByNode.set(node.id, outEdges) - const ports: ElkPortShape[] = [] - - inEdges.forEach((edge, index) => { - const portId = `${node.id}-in-${index}` - ports.push({ - id: portId, - layoutOptions: { - 'elk.port.side': 'WEST', - 'elk.port.index': String(index), - }, - }) - targetPortMap.set(edge.id, portId) - }) - - outEdges.forEach((edge, index) => { + const ports: ElkPortShape[] = outEdges.map((edge, index) => { const portId = `${node.id}-out-${edge.sourceHandle || index}` - ports.push({ + sourcePortMap.set(edge.id, portId) + return { id: portId, layoutOptions: { 'elk.port.side': 'EAST', 'elk.port.index': String(index), }, - }) - sourcePortMap.set(edge.id, portId) + } }) elkNodes.push({ @@ -422,19 +421,51 @@ export const getLayoutByELK = async (originNodes: Node[], originEdges: Edge[]): }) }) - // Build edges in sorted per-node order so PREFER_EDGES aligns with port order - nodes.forEach((node) => { - const outEdges = sortedOutEdgesByNode.get(node.id) || [] + // DFS in port order to determine the definitive vertical ordering of nodes. + // forceNodeModelOrder makes ELK respect the children-array order within each layer. + const nodeIdSet = new Set(nodes.map(n => n.id)) + const visited = new Set() + const orderedIds: string[] = [] + + const dfs = (id: string) => { + if (visited.has(id) || !nodeIdSet.has(id)) + return + visited.add(id) + orderedIds.push(id) + const outEdges = sortedOutEdgesByNode.get(id) || [] + outEdges.forEach(e => dfs(e.target)) + } + + nodes.forEach((n) => { + if (!edges.some(e => e.target === n.id)) + dfs(n.id) + }) + nodes.forEach(n => dfs(n.id)) + + const nodeOrder = new Map(orderedIds.map((id, i) => [id, i])) + elkNodes.sort((a, b) => (nodeOrder.get(a.id) ?? 0) - (nodeOrder.get(b.id) ?? 0)) + + orderedIds.forEach((id) => { + const outEdges = sortedOutEdgesByNode.get(id) || [] outEdges.forEach((edge) => { elkEdges.push(createEdge( edge.source, edge.target, sourcePortMap.get(edge.id), - targetPortMap.get(edge.id), )) }) }) + return { elkNodes, elkEdges } +} + +export const getLayoutByELK = async (originNodes: Node[], originEdges: Edge[]): Promise => { + edgeCounter = 0 + const nodes = cloneDeep(originNodes).filter(node => !node.parentId && node.type === CUSTOM_NODE) + const edges = cloneDeep(originEdges).filter(edge => (!edge.data?.isInIteration && !edge.data?.isInLoop)) + + const { elkNodes, elkEdges } = buildPortAwareGraph(nodes, edges) + const graph = { id: 'workflow-root', layoutOptions: ROOT_LAYOUT_OPTIONS, @@ -443,7 +474,6 @@ export const getLayoutByELK = async (originNodes: Node[], originEdges: Edge[]): } const layoutedGraph = await elk.layout(graph) - // No need to filter dummy nodes anymore, as we're using ports const layout = collectLayout(layoutedGraph, () => true) return normaliseBounds(layout) } @@ -532,8 +562,7 @@ export const getLayoutForChildNodes = async ( || (edge.data?.isInLoop && edge.data?.loop_id === parentNodeId), ) - const elkNodes: ElkNodeShape[] = nodes.map(toElkNode) - const elkEdges: ElkEdgeShape[] = edges.map(edge => createEdge(edge.source, edge.target)) + const { elkNodes, elkEdges } = buildPortAwareGraph(nodes, edges) const graph = { id: parentNodeId, From 4bd388669aedc342ccc76ae7529785d615b10323 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Tue, 31 Mar 2026 21:20:56 -0500 Subject: [PATCH 04/42] refactor: core/app pipeline, core/datasource, and core/indexing_runner (#34359) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../app/apps/pipeline/pipeline_generator.py | 2 +- api/core/app/apps/pipeline/pipeline_runner.py | 17 ++- .../datasource/datasource_file_manager.py | 8 +- api/core/indexing_runner.py | 105 ++++++++++-------- .../apps/pipeline/test_pipeline_generator.py | 2 +- .../app/apps/pipeline/test_pipeline_runner.py | 25 +---- .../test_datasource_file_manager.py | 50 +++------ .../core/rag/indexing/test_indexing_runner.py | 73 ++++++------ 8 files changed, 131 insertions(+), 151 deletions(-) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index fa242003a25..9cc1a197d5f 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -302,7 +302,7 @@ class PipelineGenerator(BaseAppGenerator): """ with preserve_flask_contexts(flask_app, context_vars=context): # init queue manager - workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first() + workflow = db.session.get(Workflow, workflow_id) if not workflow: raise ValueError(f"Workflow not found: {workflow_id}") queue_manager = PipelineQueueManager( diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 4c188dac68d..b4d2310da85 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -9,6 +9,7 @@ from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variable_loader import VariableLoader from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput +from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig @@ -84,13 +85,13 @@ class PipelineRunner(WorkflowBasedAppRunner): user_id = None if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: - end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first() + end_user = db.session.get(EndUser, self.application_generate_entity.user_id) if end_user: user_id = end_user.session_id else: user_id = self.application_generate_entity.user_id - pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first() + pipeline = db.session.get(Pipeline, app_config.app_id) if not pipeline: raise ValueError("Pipeline not found") @@ -213,10 +214,10 @@ class PipelineRunner(WorkflowBasedAppRunner): Get workflow """ # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id) - .first() + .limit(1) ) # return workflow @@ -297,10 +298,8 @@ class PipelineRunner(WorkflowBasedAppRunner): """ if isinstance(event, GraphRunFailedEvent): if document_id and dataset_id: - document = ( - db.session.query(Document) - .where(Document.id == document_id, Document.dataset_id == dataset_id) - .first() + document = db.session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) ) if document: document.indexing_status = "error" diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index fe40d8f0e58..492b507aa99 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -153,7 +153,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ - upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == id).first() + upload_file: UploadFile | None = db.session.get(UploadFile, id) if not upload_file: return None @@ -171,7 +171,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ - message_file: MessageFile | None = db.session.query(MessageFile).where(MessageFile.id == id).first() + message_file: MessageFile | None = db.session.get(MessageFile, id) # Check if message_file is not None if message_file is not None: @@ -185,7 +185,7 @@ class DatasourceFileManager: else: tool_file_id = None - tool_file: ToolFile | None = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first() + tool_file: ToolFile | None = db.session.get(ToolFile, tool_file_id) if not tool_file: return None @@ -203,7 +203,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ - upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + upload_file: UploadFile | None = db.session.get(UploadFile, upload_file_id) if not upload_file: return None, None diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 3ec17bc9864..b8d5ca2f50f 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -10,7 +10,7 @@ from typing import Any from flask import Flask, current_app from graphon.model_runtime.entities.model_entities import ModelType -from sqlalchemy import select +from sqlalchemy import delete, func, select, update from sqlalchemy.orm.exc import ObjectDeletedError from configs import dify_config @@ -78,7 +78,7 @@ class IndexingRunner: continue # get dataset - dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() + dataset = db.session.get(Dataset, requeried_document.dataset_id) if not dataset: raise ValueError("no dataset found") @@ -95,7 +95,7 @@ class IndexingRunner: text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform - current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first() + current_user = db.session.get(Account, requeried_document.created_by) if not current_user: raise ValueError("no current user found") current_user.set_tenant_id(dataset.tenant_id) @@ -137,23 +137,24 @@ class IndexingRunner: return # get dataset - dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() + dataset = db.session.get(Dataset, requeried_document.dataset_id) if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete - document_segments = ( - db.session.query(DocumentSegment) - .filter_by(dataset_id=dataset.id, document_id=requeried_document.id) - .all() - ) + document_segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == requeried_document.id, + ) + ).all() for document_segment in document_segments: db.session.delete(document_segment) if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: # delete child chunks - db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() + db.session.execute(delete(ChildChunk).where(ChildChunk.segment_id == document_segment.id)) db.session.commit() # get the process rule stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id) @@ -167,7 +168,7 @@ class IndexingRunner: text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform - current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first() + current_user = db.session.get(Account, requeried_document.created_by) if not current_user: raise ValueError("no current user found") current_user.set_tenant_id(dataset.tenant_id) @@ -207,17 +208,18 @@ class IndexingRunner: return # get dataset - dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() + dataset = db.session.get(Dataset, requeried_document.dataset_id) if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete - document_segments = ( - db.session.query(DocumentSegment) - .filter_by(dataset_id=dataset.id, document_id=requeried_document.id) - .all() - ) + document_segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == requeried_document.id, + ) + ).all() documents = [] if document_segments: @@ -289,7 +291,7 @@ class IndexingRunner: embedding_model_instance = None if dataset_id: - dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() + dataset = db.session.get(Dataset, dataset_id) if not dataset: raise ValueError("Dataset not found.") if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}: @@ -652,24 +654,26 @@ class IndexingRunner: @staticmethod def _process_keyword_index(flask_app, dataset_id, document_id, documents): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() + dataset = db.session.get(Dataset, dataset_id) if not dataset: raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: document_ids = [document.metadata["doc_id"] for document in documents] - db.session.query(DocumentSegment).where( - DocumentSegment.document_id == document_id, - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == SegmentStatus.INDEXING, - ).update( - { - DocumentSegment.status: SegmentStatus.COMPLETED, - DocumentSegment.enabled: True, - DocumentSegment.completed_at: naive_utc_now(), - } + db.session.execute( + update(DocumentSegment) + .where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.index_node_id.in_(document_ids), + DocumentSegment.status == SegmentStatus.INDEXING, + ) + .values( + status=SegmentStatus.COMPLETED, + enabled=True, + completed_at=naive_utc_now(), + ) ) db.session.commit() @@ -703,17 +707,19 @@ class IndexingRunner: ) document_ids = [document.metadata["doc_id"] for document in chunk_documents] - db.session.query(DocumentSegment).where( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == SegmentStatus.INDEXING, - ).update( - { - DocumentSegment.status: SegmentStatus.COMPLETED, - DocumentSegment.enabled: True, - DocumentSegment.completed_at: naive_utc_now(), - } + db.session.execute( + update(DocumentSegment) + .where( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(document_ids), + DocumentSegment.status == SegmentStatus.INDEXING, + ) + .values( + status=SegmentStatus.COMPLETED, + enabled=True, + completed_at=naive_utc_now(), + ) ) db.session.commit() @@ -734,10 +740,17 @@ class IndexingRunner: """ Update the document indexing status. """ - count = db.session.query(DatasetDocument).filter_by(id=document_id, is_paused=True).count() + count = ( + db.session.scalar( + select(func.count()) + .select_from(DatasetDocument) + .where(DatasetDocument.id == document_id, DatasetDocument.is_paused == True) + ) + or 0 + ) if count > 0: raise DocumentIsPausedError() - document = db.session.query(DatasetDocument).filter_by(id=document_id).first() + document = db.session.get(DatasetDocument, document_id) if not document: raise DocumentIsDeletedPausedError() @@ -745,7 +758,7 @@ class IndexingRunner: if extra_update_params: update_params.update(extra_update_params) - db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) # type: ignore + db.session.execute(update(DatasetDocument).where(DatasetDocument.id == document_id).values(update_params)) # type: ignore db.session.commit() @staticmethod @@ -753,7 +766,9 @@ class IndexingRunner: """ Update the document segment by document id. """ - db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params) + db.session.execute( + update(DocumentSegment).where(DocumentSegment.document_id == dataset_document_id).values(update_params) + ) db.session.commit() def _transform( diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py index 06face41fe7..0047f6659d5 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py @@ -345,7 +345,7 @@ def test_generate_raises_when_workflow_not_found(generator, mocker): mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) session = MagicMock() - session.query.return_value.where.return_value.first.return_value = None + session.get.return_value = None mocker.patch.object(module.db, "session", session) with pytest.raises(ValueError): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index ab70996f0aa..c8ae288e6fe 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -80,9 +80,7 @@ def test_get_workflow_returns_workflow(mocker, runner): pipeline = MagicMock(tenant_id="tenant", id="pipe") workflow = MagicMock(id="wf") - query = MagicMock() - query.where.return_value.first.return_value = workflow - mocker.patch.object(module.db, "session", MagicMock(query=MagicMock(return_value=query))) + mocker.patch.object(module.db, "session", MagicMock(scalar=MagicMock(return_value=workflow))) result = runner.get_workflow(pipeline=pipeline, workflow_id="wf") @@ -115,11 +113,8 @@ def test_init_rag_pipeline_graph_not_found(mocker, runner): def test_update_document_status_on_failure(mocker, runner): document = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = document - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = document mocker.patch.object(module.db, "session", session) event = GraphRunFailedEvent(error="boom") @@ -189,14 +184,10 @@ def test_run_single_iteration_path(mocker): app_generate_entity.single_iteration_run = MagicMock() pipeline = MagicMock(id="pipe") - query_pipeline = MagicMock() - query_pipeline.where.return_value.first.return_value = pipeline - - query_end_user = MagicMock() - query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess") + end_user = MagicMock(session_id="sess") session = MagicMock() - session.query.side_effect = [query_end_user, query_pipeline] + session.get.side_effect = [end_user, pipeline] mocker.patch.object(module.db, "session", session) runner = PipelineRunner( @@ -241,14 +232,10 @@ def test_run_normal_path_builds_graph(mocker): app_generate_entity = _build_app_generate_entity() pipeline = MagicMock(id="pipe") - query_pipeline = MagicMock() - query_pipeline.where.return_value.first.return_value = pipeline - - query_end_user = MagicMock() - query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess") + end_user = MagicMock(session_id="sess") session = MagicMock() - session.query.side_effect = [query_end_user, query_pipeline] + session.get.side_effect = [end_user, pipeline] mocker.patch.object(module.db, "session", session) workflow = MagicMock( diff --git a/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py index 7cd1fdf06b2..4f39d38831d 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py @@ -287,9 +287,7 @@ class TestDatasourceFileManager: mock_upload_file.key = "some_key" mock_upload_file.mime_type = "image/png" - mock_query = mock_db.session.query.return_value - mock_where = mock_query.where.return_value - mock_where.first.return_value = mock_upload_file + mock_db.session.get.return_value = mock_upload_file mock_storage.load_once.return_value = b"file content" @@ -300,7 +298,7 @@ class TestDatasourceFileManager: assert result == (b"file content", "image/png") # Case: Not found - mock_where.first.return_value = None + mock_db.session.get.return_value = None assert DatasourceFileManager.get_file_binary("unknown") is None @patch("core.datasource.datasource_file_manager.db") @@ -314,16 +312,14 @@ class TestDatasourceFileManager: mock_tool_file.file_key = "tool_key" mock_tool_file.mimetype = "image/png" - # Mock query sequence - def mock_query(model): - m = MagicMock() + def mock_get(model, id): if model == MessageFile: - m.where.return_value.first.return_value = mock_message_file + return mock_message_file elif model == ToolFile: - m.where.return_value.first.return_value = mock_tool_file - return m + return mock_tool_file + return None - mock_db.session.query.side_effect = mock_query + mock_db.session.get.side_effect = mock_get mock_storage.load_once.return_value = b"tool content" # Execute @@ -344,15 +340,12 @@ class TestDatasourceFileManager: mock_tool_file.file_key = "tk" mock_tool_file.mimetype = "image/png" - def mock_query(model): - m = MagicMock() + def mock_get(model, id): if model == MessageFile: - m.where.return_value.first.return_value = mock_message_file - else: - m.where.return_value.first.return_value = mock_tool_file - return m + return mock_message_file + return mock_tool_file - mock_db.session.query.side_effect = mock_query + mock_db.session.get.side_effect = mock_get mock_storage.load_once.return_value = b"bits" result = DatasourceFileManager.get_file_binary_by_message_file_id("m") @@ -361,27 +354,20 @@ class TestDatasourceFileManager: @patch("core.datasource.datasource_file_manager.db") @patch("core.datasource.datasource_file_manager.storage") def test_get_file_binary_by_message_file_id_failures(self, mock_storage, mock_db): - # Setup common mock - mock_query_obj = MagicMock() - mock_db.session.query.return_value = mock_query_obj - mock_query_obj.where.return_value.first.return_value = None - # Case 1: Message file not found + mock_db.session.get.return_value = None assert DatasourceFileManager.get_file_binary_by_message_file_id("none") is None # Case 2: Message file found but tool file not found mock_message_file = MagicMock(spec=MessageFile) mock_message_file.url = None - def mock_query_v2(model): - m = MagicMock() + def mock_get_v2(model, id): if model == MessageFile: - m.where.return_value.first.return_value = mock_message_file - else: - m.where.return_value.first.return_value = None - return m + return mock_message_file + return None - mock_db.session.query.side_effect = mock_query_v2 + mock_db.session.get.side_effect = mock_get_v2 assert DatasourceFileManager.get_file_binary_by_message_file_id("msg_id") is None @patch("core.datasource.datasource_file_manager.db") @@ -392,7 +378,7 @@ class TestDatasourceFileManager: mock_upload_file.key = "upload_key" mock_upload_file.mime_type = "text/plain" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_upload_file + mock_db.session.get.return_value = mock_upload_file mock_storage.load_stream.return_value = iter([b"chunk1", b"chunk2"]) @@ -404,7 +390,7 @@ class TestDatasourceFileManager: assert list(stream) == [b"chunk1", b"chunk2"] # Case: Not found - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("none") assert stream is None assert mimetype is None diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index 450e7166360..641c5d9ba0f 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -795,33 +795,21 @@ class TestIndexingRunnerRun: doc = sample_dataset_documents[0] # Mock database queries - mock_dependencies["db"].session.get.return_value = doc - mock_dataset = Mock(spec=Dataset) mock_dataset.id = doc.dataset_id mock_dataset.tenant_id = doc.tenant_id mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY - mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset + + mock_current_user = MagicMock() + mock_current_user.set_tenant_id = MagicMock() + + get_dispatch = {"Document": doc, "Dataset": mock_dataset, "Account": mock_current_user} + mock_dependencies["db"].session.get.side_effect = lambda model, id: get_dispatch.get(model.__name__) mock_process_rule = Mock(spec=DatasetProcessRule) mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}} mock_dependencies["db"].session.scalar.return_value = mock_process_rule - # Mock current_user (Account) for _transform - mock_current_user = MagicMock() - mock_current_user.set_tenant_id = MagicMock() - - # Setup db.session.query to return different results based on the model - def mock_query_side_effect(model): - mock_query_result = MagicMock() - if model.__name__ == "Dataset": - mock_query_result.filter_by.return_value.first.return_value = mock_dataset - elif model.__name__ == "Account": - mock_query_result.filter_by.return_value.first.return_value = mock_current_user - return mock_query_result - - mock_dependencies["db"].session.query.side_effect = mock_query_side_effect - # Mock processor mock_processor = MagicMock() mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor @@ -891,10 +879,11 @@ class TestIndexingRunnerRun: doc = sample_dataset_documents[0] # Mock database - mock_dependencies["db"].session.get.return_value = doc - mock_dataset = Mock(spec=Dataset) - mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset + mock_dataset.tenant_id = doc.tenant_id + + get_dispatch = {"Document": doc, "Dataset": mock_dataset} + mock_dependencies["db"].session.get.side_effect = lambda model, id: get_dispatch.get(model.__name__) mock_process_rule = Mock(spec=DatasetProcessRule) mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}} @@ -917,11 +906,12 @@ class TestIndexingRunnerRun: runner = IndexingRunner() doc = sample_dataset_documents[0] - # Mock database to raise ObjectDeletedError - mock_dependencies["db"].session.get.return_value = doc - + # Mock database mock_dataset = Mock(spec=Dataset) - mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset + mock_dataset.tenant_id = doc.tenant_id + + get_dispatch = {"Document": doc, "Dataset": mock_dataset} + mock_dependencies["db"].session.get.side_effect = lambda model, id: get_dispatch.get(model.__name__) mock_process_rule = Mock(spec=DatasetProcessRule) mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}} @@ -945,17 +935,21 @@ class TestIndexingRunnerRun: docs = sample_dataset_documents # Mock database - def get_side_effect(model_class, doc_id): - for doc in docs: - if doc.id == doc_id: - return doc - return None - - mock_dependencies["db"].session.get.side_effect = get_side_effect - mock_dataset = Mock(spec=Dataset) mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY - mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset + mock_current_user = MagicMock() + mock_current_user.set_tenant_id = MagicMock() + + doc_map = {doc.id: doc for doc in docs} + model_dispatch = {"Dataset": mock_dataset, "Account": mock_current_user} + + def get_side_effect(model_class, id): + name = model_class.__name__ + if name == "Document": + return doc_map.get(id) + return model_dispatch.get(name) + + mock_dependencies["db"].session.get.side_effect = get_side_effect mock_process_rule = Mock(spec=DatasetProcessRule) mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}} @@ -1035,9 +1029,8 @@ class TestIndexingRunnerRetryLogic: mock_document = Mock(spec=DatasetDocument) mock_document.id = document_id - mock_dependencies["db"].session.query.return_value.filter_by.return_value.count.return_value = 0 - mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_document - mock_dependencies["db"].session.query.return_value.filter_by.return_value.update.return_value = None + mock_dependencies["db"].session.scalar.return_value = 0 + mock_dependencies["db"].session.get.return_value = mock_document # Act IndexingRunner._update_document_index_status( @@ -1053,7 +1046,7 @@ class TestIndexingRunnerRetryLogic: """Test document status update when document is paused.""" # Arrange document_id = str(uuid.uuid4()) - mock_dependencies["db"].session.query.return_value.filter_by.return_value.count.return_value = 1 + mock_dependencies["db"].session.scalar.return_value = 1 # Act & Assert with pytest.raises(DocumentIsPausedError): @@ -1063,8 +1056,8 @@ class TestIndexingRunnerRetryLogic: """Test document status update when document is deleted.""" # Arrange document_id = str(uuid.uuid4()) - mock_dependencies["db"].session.query.return_value.filter_by.return_value.count.return_value = 0 - mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = None + mock_dependencies["db"].session.scalar.return_value = 0 + mock_dependencies["db"].session.get.return_value = None # Act & Assert with pytest.raises(DocumentIsDeletedPausedError): From 42d7623cc6e8b38aa0913d1e006f22205b12a962 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Wed, 1 Apr 2026 10:32:01 +0800 Subject: [PATCH 05/42] fix: Variable Aggregator cannot click group swich (#34361) --- .../__tests__/use-config.spec.tsx | 30 +++++++++++++++ .../variable-assigner/use-config.helpers.ts | 23 +++++++++++- .../nodes/variable-assigner/use-config.ts | 37 +++++++++++++------ 3 files changed, 77 insertions(+), 13 deletions(-) diff --git a/web/app/components/workflow/nodes/variable-assigner/__tests__/use-config.spec.tsx b/web/app/components/workflow/nodes/variable-assigner/__tests__/use-config.spec.tsx index 1137f20a0c5..cb8c2db52f9 100644 --- a/web/app/components/workflow/nodes/variable-assigner/__tests__/use-config.spec.tsx +++ b/web/app/components/workflow/nodes/variable-assigner/__tests__/use-config.spec.tsx @@ -91,6 +91,15 @@ const createPayload = (overrides: Partial = {}): Varia ...overrides, }) +const createPayloadWithoutAdvancedSettings = (): VariableAssignerNodeType => { + const payload = createPayload() as Omit & { + advanced_settings?: VariableAssignerNodeType['advanced_settings'] + } + delete payload.advanced_settings + + return payload as VariableAssignerNodeType +} + describe('useConfig', () => { beforeEach(() => { vi.clearAllMocks() @@ -252,4 +261,25 @@ describe('useConfig', () => { advanced_settings: expect.objectContaining({ group_enabled: false }), })) }) + + it('should not throw when enabling groups with missing advanced settings', () => { + const { result } = renderHook(() => useConfig('assigner-node', createPayloadWithoutAdvancedSettings())) + + expect(() => { + result.current.handleGroupEnabledChange(true) + }).not.toThrow() + + expect(mockHandleOutVarRenameChange).toHaveBeenCalledWith( + 'assigner-node', + ['assigner-node', 'output'], + ['assigner-node', 'Group1', 'output'], + ) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + advanced_settings: expect.objectContaining({ + group_enabled: true, + groups: [expect.objectContaining({ group_name: 'Group1' })], + }), + })) + expect(mockDeleteNodeInspectorVars).toHaveBeenCalledWith('assigner-node') + }) }) diff --git a/web/app/components/workflow/nodes/variable-assigner/use-config.helpers.ts b/web/app/components/workflow/nodes/variable-assigner/use-config.helpers.ts index 31300557b2c..2cc91c65aca 100644 --- a/web/app/components/workflow/nodes/variable-assigner/use-config.helpers.ts +++ b/web/app/components/workflow/nodes/variable-assigner/use-config.helpers.ts @@ -26,7 +26,13 @@ export const updateNestedVarGroupItem = ( groupId: string, payload: VarGroupItem, ) => produce(inputs, (draft) => { + if (!draft.advanced_settings) + return + const index = draft.advanced_settings.groups.findIndex(item => item.groupId === groupId) + if (index < 0) + return + draft.advanced_settings.groups[index] = { ...draft.advanced_settings.groups[index], ...payload, @@ -37,6 +43,11 @@ export const removeGroupByIndex = ( inputs: VariableAssignerNodeType, index: number, ) => produce(inputs, (draft) => { + if (!draft.advanced_settings) + return + if (index < 0 || index >= draft.advanced_settings.groups.length) + return + draft.advanced_settings.groups.splice(index, 1) }) @@ -70,7 +81,8 @@ export const toggleGroupEnabled = ({ export const addGroup = (inputs: VariableAssignerNodeType) => { let maxInGroupName = 1 - inputs.advanced_settings.groups.forEach((item) => { + const groups = inputs.advanced_settings?.groups ?? [] + groups.forEach((item) => { const match = /(\d+)$/.exec(item.group_name) if (match) { const num = Number.parseInt(match[1], 10) @@ -80,6 +92,9 @@ export const addGroup = (inputs: VariableAssignerNodeType) => { }) return produce(inputs, (draft) => { + if (!draft.advanced_settings) + draft.advanced_settings = { group_enabled: false, groups: [] } + draft.advanced_settings.groups.push({ output_type: VarType.any, variables: [], @@ -94,6 +109,12 @@ export const renameGroup = ( groupId: string, name: string, ) => produce(inputs, (draft) => { + if (!draft.advanced_settings) + return + const index = draft.advanced_settings.groups.findIndex(item => item.groupId === groupId) + if (index < 0) + return + draft.advanced_settings.groups[index].group_name = name }) diff --git a/web/app/components/workflow/nodes/variable-assigner/use-config.ts b/web/app/components/workflow/nodes/variable-assigner/use-config.ts index 6d4b27e50b7..cecf185d4f4 100644 --- a/web/app/components/workflow/nodes/variable-assigner/use-config.ts +++ b/web/app/components/workflow/nodes/variable-assigner/use-config.ts @@ -54,10 +54,15 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => { const [removedGroupIndex, setRemovedGroupIndex] = useState(-1) const handleGroupRemoved = useCallback((groupId: string) => { return () => { - const index = inputs.advanced_settings.groups.findIndex(item => item.groupId === groupId) - if (isVarUsedInNodes([id, inputs.advanced_settings.groups[index].group_name, 'output'])) { + const groups = inputs.advanced_settings?.groups ?? [] + const index = groups.findIndex(item => item.groupId === groupId) + if (index < 0) + return + + const groupName = groups[index].group_name + if (isVarUsedInNodes([id, groupName, 'output'])) { showRemoveVarConfirm() - setRemovedVars([[id, inputs.advanced_settings.groups[index].group_name, 'output']]) + setRemovedVars([[id, groupName, 'output']]) setRemoveType('group') setRemovedGroupIndex(index) return @@ -67,13 +72,15 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => { }, [id, inputs, isVarUsedInNodes, setInputs, showRemoveVarConfirm]) const handleGroupEnabledChange = useCallback((enabled: boolean) => { - if (enabled && inputs.advanced_settings.groups.length === 0) { + const groups = inputs.advanced_settings?.groups ?? [] + + if (enabled && groups.length === 0) { handleOutVarRenameChange(id, [id, 'output'], [id, 'Group1', 'output']) } - if (!enabled && inputs.advanced_settings.groups.length > 0) { - if (inputs.advanced_settings.groups.length > 1) { - const useVars = inputs.advanced_settings.groups.filter((item, index) => index > 0 && isVarUsedInNodes([id, item.group_name, 'output'])) + if (!enabled && groups.length > 0) { + if (groups.length > 1) { + const useVars = groups.filter((item, index) => index > 0 && isVarUsedInNodes([id, item.group_name, 'output'])) if (useVars.length > 0) { showRemoveVarConfirm() setRemovedVars(useVars.map(item => [id, item.group_name, 'output'])) @@ -82,7 +89,7 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => { } } - handleOutVarRenameChange(id, [id, inputs.advanced_settings.groups[0].group_name, 'output'], [id, 'output']) + handleOutVarRenameChange(id, [id, groups[0].group_name, 'output'], [id, 'output']) } setInputs(toggleGroupEnabled({ inputs, enabled })) @@ -110,11 +117,16 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => { const handleVarGroupNameChange = useCallback((groupId: string) => { return (name: string) => { - const index = inputs.advanced_settings.groups.findIndex(item => item.groupId === groupId) - handleOutVarRenameChange(id, [id, inputs.advanced_settings.groups[index].group_name, 'output'], [id, name, 'output']) + const groups = inputs.advanced_settings?.groups ?? [] + const index = groups.findIndex(item => item.groupId === groupId) + if (index < 0) + return + + const oldName = groups[index].group_name + handleOutVarRenameChange(id, [id, oldName, 'output'], [id, name, 'output']) setInputs(renameGroup(inputs, groupId, name)) if (!(id in oldNameRef.current)) - oldNameRef.current[id] = inputs.advanced_settings.groups[index].group_name + oldNameRef.current[id] = oldName renameInspectNameWithDebounce(id, name) } }, [handleOutVarRenameChange, id, inputs, renameInspectNameWithDebounce, setInputs]) @@ -125,7 +137,8 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => { }) hideRemoveVarConfirm() if (removeType === 'group') { - setInputs(removeGroupByIndex(inputs, removedGroupIndex)) + if (removedGroupIndex >= 0) + setInputs(removeGroupByIndex(inputs, removedGroupIndex)) } else { // removeType === 'enableChanged' to enabled From beda78e91129874d7a2b5377f71f2cc2ca6dc1bb Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Wed, 1 Apr 2026 06:00:05 +0200 Subject: [PATCH 06/42] refactor: select in 13 small service files (#34371) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/audio_service.py | 2 +- api/services/billing_service.py | 7 +++-- api/services/conversation_service.py | 12 ++++---- api/services/credit_pool_service.py | 14 ++++----- .../enterprise/account_deletion_sync.py | 5 +++- .../rag_pipeline/pipeline_generate_service.py | 2 +- .../customized/customized_retrieval.py | 12 ++++---- .../database/database_retrieval.py | 11 +++---- .../database/database_retrieval.py | 8 ++--- api/services/web_conversation_service.py | 12 ++++---- api/services/webapp_auth_service.py | 5 ++-- api/services/workflow/workflow_converter.py | 7 +++-- api/services/workspace_service.py | 7 +++-- .../unit_tests/services/test_audio_service.py | 21 ++++--------- .../services/test_billing_service.py | 30 ++++--------------- .../services/test_conversation_service.py | 19 ++++-------- 16 files changed, 72 insertions(+), 102 deletions(-) diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 90e72d5f34f..1c7027efb49 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -132,7 +132,7 @@ class AudioService: uuid.UUID(message_id) except ValueError: return None - message = db.session.query(Message).where(Message.id == message_id).first() + message = db.session.get(Message, message_id) if message is None: return None if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}: diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 54c595e0cbd..9970b2e6040 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -6,6 +6,7 @@ from typing import Literal import httpx from pydantic import TypeAdapter +from sqlalchemy import select from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed from typing_extensions import TypedDict from werkzeug.exceptions import InternalServerError @@ -158,10 +159,10 @@ class BillingService: def is_tenant_owner_or_admin(current_user: Account): tenant_id = current_user.current_tenant_id - join: TenantAccountJoin | None = ( - db.session.query(TenantAccountJoin) + join: TenantAccountJoin | None = db.session.scalar( + select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) - .first() + .limit(1) ) if not join: diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index ba1e7bb8266..95482a2235f 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -137,11 +137,11 @@ class ConversationService: @classmethod def auto_generate_name(cls, app_model: App, conversation: Conversation): # get conversation first message - message = ( - db.session.query(Message) + message = db.session.scalar( + select(Message) .where(Message.app_id == app_model.id, Message.conversation_id == conversation.id) .order_by(Message.created_at.asc()) - .first() + .limit(1) ) if not message: @@ -160,8 +160,8 @@ class ConversationService: @classmethod def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): - conversation = ( - db.session.query(Conversation) + conversation = db.session.scalar( + select(Conversation) .where( Conversation.id == conversation_id, Conversation.app_id == app_model.id, @@ -170,7 +170,7 @@ class ConversationService: Conversation.from_account_id == (user.id if isinstance(user, Account) else None), Conversation.is_deleted == False, ) - .first() + .limit(1) ) if not conversation: diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 2894826935d..7826695366b 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -1,6 +1,6 @@ import logging -from sqlalchemy import update +from sqlalchemy import select, update from sqlalchemy.orm import Session from configs import dify_config @@ -29,13 +29,13 @@ class CreditPoolService: @classmethod def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None: """get tenant credit pool""" - return ( - db.session.query(TenantCreditPool) - .filter_by( - tenant_id=tenant_id, - pool_type=pool_type, + return db.session.scalar( + select(TenantCreditPool) + .where( + TenantCreditPool.tenant_id == tenant_id, + TenantCreditPool.pool_type == pool_type, ) - .first() + .limit(1) ) @classmethod diff --git a/api/services/enterprise/account_deletion_sync.py b/api/services/enterprise/account_deletion_sync.py index c7ff42894da..b5107fb0f66 100644 --- a/api/services/enterprise/account_deletion_sync.py +++ b/api/services/enterprise/account_deletion_sync.py @@ -4,6 +4,7 @@ import uuid from datetime import UTC, datetime from redis import RedisError +from sqlalchemy import select from configs import dify_config from extensions.ext_database import db @@ -104,7 +105,9 @@ def sync_account_deletion(account_id: str, *, source: str) -> bool: return True # Fetch all workspaces the account belongs to - workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all() + workspace_joins = db.session.scalars( + select(TenantAccountJoin).where(TenantAccountJoin.account_id == account_id) + ).all() # Queue sync task for each workspace success = True diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 07e1b8f20ed..10e89b1dbab 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -110,7 +110,7 @@ class PipelineGenerateService: Update document status to waiting :param document_id: document id """ - document = db.session.query(Document).where(Document.id == document_id).first() + document = db.session.get(Document, document_id) if document: document.indexing_status = IndexingStatus.WAITING db.session.add(document) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 4ac2e0792bf..2ee871a2663 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,4 +1,5 @@ import yaml +from sqlalchemy import select from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -32,12 +33,11 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param language: language :return: """ - pipeline_customized_templates = ( - db.session.query(PipelineCustomizedTemplate) + pipeline_customized_templates = db.session.scalars( + select(PipelineCustomizedTemplate) .where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc()) - .all() - ) + ).all() recommended_pipelines_results = [] for pipeline_customized_template in pipeline_customized_templates: recommended_pipeline_result = { @@ -59,9 +59,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param template_id: Template ID :return: """ - pipeline_template = ( - db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first() - ) + pipeline_template = db.session.get(PipelineCustomizedTemplate, template_id) if not pipeline_template: return None diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 908f9a26840..43b21a7b320 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,4 +1,5 @@ import yaml +from sqlalchemy import select from extensions.ext_database import db from models.dataset import PipelineBuiltInTemplate @@ -30,8 +31,10 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :return: """ - pipeline_built_in_templates: list[PipelineBuiltInTemplate] = ( - db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all() + pipeline_built_in_templates = list( + db.session.scalars( + select(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language) + ).all() ) recommended_pipelines_results = [] @@ -58,9 +61,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :return: """ # is in public recommended list - pipeline_template = ( - db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first() - ) + pipeline_template = db.session.get(PipelineBuiltInTemplate, template_id) if not pipeline_template: return None diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index d0c49325dc6..6fb90d356d9 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -77,17 +77,15 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): :return: """ # is in public recommended list - recommended_app = ( - db.session.query(RecommendedApp) - .where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) - .first() + recommended_app = db.session.scalar( + select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id).limit(1) ) if not recommended_app: return None # get app detail - app_model = db.session.query(App).where(App.id == app_id).first() + app_model = db.session.get(App, app_id) if not app_model or not app_model.is_public: return None diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index e028e3e5e3b..5ef9e9be613 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -64,15 +64,15 @@ class WebConversationService: def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): if not user: return - pinned_conversation = ( - db.session.query(PinnedConversation) + pinned_conversation = db.session.scalar( + select(PinnedConversation) .where( PinnedConversation.app_id == app_model.id, PinnedConversation.conversation_id == conversation_id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), PinnedConversation.created_by == user.id, ) - .first() + .limit(1) ) if pinned_conversation: @@ -96,15 +96,15 @@ class WebConversationService: def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): if not user: return - pinned_conversation = ( - db.session.query(PinnedConversation) + pinned_conversation = db.session.scalar( + select(PinnedConversation) .where( PinnedConversation.app_id == app_model.id, PinnedConversation.conversation_id == conversation_id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), PinnedConversation.created_by == user.id, ) - .first() + .limit(1) ) if not pinned_conversation: diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 5ca0b630014..eaea79af2f1 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -3,6 +3,7 @@ import secrets from datetime import UTC, datetime, timedelta from typing import Any +from sqlalchemy import select from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config @@ -92,10 +93,10 @@ class WebAppAuthService: @classmethod def create_end_user(cls, app_code, email) -> EndUser: - site = db.session.query(Site).where(Site.code == app_code).first() + site = db.session.scalar(select(Site).where(Site.code == app_code).limit(1)) if not site: raise NotFound("Site not found.") - app_model = db.session.query(App).where(App.id == site.app_id).first() + app_model = db.session.get(App, site.app_id) if not app_model: raise NotFound("App not found.") end_user = EndUser( diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 31367f72fab..399c82849f5 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -6,6 +6,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.nodes import BuiltinNodeTypes from graphon.variables.input_entities import VariableEntity +from sqlalchemy import select from typing_extensions import TypedDict from core.app.app_config.entities import ( @@ -648,10 +649,10 @@ class WorkflowConverter: :param api_based_extension_id: api based extension id :return: """ - api_based_extension = ( - db.session.query(APIBasedExtension) + api_based_extension = db.session.scalar( + select(APIBasedExtension) .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) - .first() + .limit(1) ) if not api_based_extension: diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 84a8b033296..eb4671cfaa3 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,4 +1,5 @@ from flask_login import current_user +from sqlalchemy import select from configs import dify_config from enums.cloud_plan import CloudPlan @@ -24,10 +25,10 @@ class WorkspaceService: } # Get role of user - tenant_account_join = ( - db.session.query(TenantAccountJoin) + tenant_account_join = db.session.scalar( + select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) - .first() + .limit(1) ) assert tenant_account_join is not None, "TenantAccountJoin not found" tenant_info["role"] = tenant_account_join.role diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 175fd3ee016..cede6671cea 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -421,11 +421,8 @@ class TestAudioServiceTTS: answer="Message answer text", ) - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = message + # Mock database lookup + mock_db_session.get.return_value = message # Mock ModelManager mock_model_manager = mock_model_manager_class.return_value @@ -568,11 +565,8 @@ class TestAudioServiceTTS: # Arrange app = factory.create_app_mock() - # Mock database query returning None - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None + # Mock database lookup returning None + mock_db_session.get.return_value = None # Act result = AudioService.transcript_tts( @@ -594,11 +588,8 @@ class TestAudioServiceTTS: status=MessageStatus.NORMAL, ) - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = message + # Mock database lookup + mock_db_session.get.return_value = message # Act result = AudioService.transcript_tts( diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index b3d2e608025..168ab6cf0d9 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -865,16 +865,11 @@ class TestBillingServiceAccountManagement: mock_join = MagicMock(spec=TenantAccountJoin) mock_join.role = TenantAccountRole.OWNER - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_db_session.query.return_value = mock_query + mock_db_session.scalar.return_value = mock_join # Act - should not raise exception BillingService.is_tenant_owner_or_admin(current_user) - # Assert - mock_db_session.query.assert_called_once() - def test_is_tenant_owner_or_admin_admin(self, mock_db_session): """Test tenant owner/admin check for admin role.""" # Arrange @@ -885,16 +880,11 @@ class TestBillingServiceAccountManagement: mock_join = MagicMock(spec=TenantAccountJoin) mock_join.role = TenantAccountRole.ADMIN - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_db_session.query.return_value = mock_query + mock_db_session.scalar.return_value = mock_join # Act - should not raise exception BillingService.is_tenant_owner_or_admin(current_user) - # Assert - mock_db_session.query.assert_called_once() - def test_is_tenant_owner_or_admin_normal_user_raises_error(self, mock_db_session): """Test tenant owner/admin check raises error for normal user.""" # Arrange @@ -905,9 +895,7 @@ class TestBillingServiceAccountManagement: mock_join = MagicMock(spec=TenantAccountJoin) mock_join.role = TenantAccountRole.NORMAL - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_db_session.query.return_value = mock_query + mock_db_session.scalar.return_value = mock_join # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -921,9 +909,7 @@ class TestBillingServiceAccountManagement: current_user.id = "account-123" current_user.current_tenant_id = "tenant-456" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = None - mock_db_session.query.return_value = mock_query + mock_db_session.scalar.return_value = None # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -1135,9 +1121,7 @@ class TestBillingServiceEdgeCases: mock_join.role = TenantAccountRole.EDITOR # Editor is not privileged with patch("services.billing_service.db.session") as mock_session: - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_session.query.return_value = mock_query + mock_session.scalar.return_value = mock_join # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -1155,9 +1139,7 @@ class TestBillingServiceEdgeCases: mock_join.role = TenantAccountRole.DATASET_OPERATOR # Dataset operator is not privileged with patch("services.billing_service.db.session") as mock_session: - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_session.query.return_value = mock_query + mock_session.scalar.return_value = mock_join # Act & Assert with pytest.raises(ValueError) as exc_info: diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index 1bf4c0e1721..a4359f00b8a 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -355,15 +355,13 @@ class TestConversationServiceGetConversation: from_account_id=user.id, from_source=ConversationFromSource.CONSOLE ) - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.first.return_value = conversation + mock_db_session.scalar.return_value = conversation # Act result = ConversationService.get_conversation(app_model, "conv-123", user) # Assert assert result == conversation - mock_db_session.query.assert_called_once_with(Conversation) @patch("services.conversation_service.db.session") def test_get_conversation_success_with_end_user(self, mock_db_session): @@ -379,8 +377,7 @@ class TestConversationServiceGetConversation: from_end_user_id=user.id, from_source=ConversationFromSource.API ) - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.first.return_value = conversation + mock_db_session.scalar.return_value = conversation # Act result = ConversationService.get_conversation(app_model, "conv-123", user) @@ -399,8 +396,7 @@ class TestConversationServiceGetConversation: app_model = ConversationServiceTestDataFactory.create_app_mock() user = ConversationServiceTestDataFactory.create_account_mock() - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.first.return_value = None + mock_db_session.scalar.return_value = None # Act & Assert with pytest.raises(ConversationNotExistsError): @@ -489,8 +485,7 @@ class TestConversationServiceAutoGenerateName: ) # Mock database query to return message - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.order_by.return_value.first.return_value = message + mock_db_session.scalar.return_value = message # Mock LLM generator mock_llm_generator.generate_conversation_name.return_value = "Generated Name" @@ -518,8 +513,7 @@ class TestConversationServiceAutoGenerateName: conversation = ConversationServiceTestDataFactory.create_conversation_mock() # Mock database query to return None - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.order_by.return_value.first.return_value = None + mock_db_session.scalar.return_value = None # Act & Assert with pytest.raises(MessageNotExistsError): @@ -541,8 +535,7 @@ class TestConversationServiceAutoGenerateName: ) # Mock database query to return message - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.order_by.return_value.first.return_value = message + mock_db_session.scalar.return_value = message # Mock LLM generator to raise exception mock_llm_generator.generate_conversation_name.side_effect = Exception("LLM Error") From 09ee8ea1f535fc86a41e8370ef520abbe10ac54f Mon Sep 17 00:00:00 2001 From: Full Stack Engineer <66432853+EndlessLucky@users.noreply.github.com> Date: Wed, 1 Apr 2026 00:22:23 -0400 Subject: [PATCH 07/42] fix: support qa_preview shape in IndexProcessor preview formatting (#34151) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com> --- api/core/rag/index_processor/index_processor.py | 9 ++++++++- .../core/rag/indexing/test_index_processor.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 api/tests/unit_tests/core/rag/indexing/test_index_processor.py diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index a6d1db214b0..825ae012269 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -35,7 +35,10 @@ class IndexProcessor: if "parent_mode" in preview: data.parent_mode = preview["parent_mode"] - for item in preview["preview"]: + # Different index processors return different preview shapes: + # - paragraph/parent-child processors: {"preview": [...]} + # - QA processor: {"qa_preview": [...]} (no "preview" key) + for item in preview.get("preview", []): if "content" in item and "child_chunks" in item: data.preview.append( PreviewItem(content=item["content"], child_chunks=item["child_chunks"], summary=None) @@ -44,6 +47,10 @@ class IndexProcessor: data.qa_preview.append(QaPreview(question=item["question"], answer=item["answer"])) elif "content" in item: data.preview.append(PreviewItem(content=item["content"], child_chunks=None, summary=None)) + + for item in preview.get("qa_preview", []): + if "question" in item and "answer" in item: + data.qa_preview.append(QaPreview(question=item["question"], answer=item["answer"])) return data def index_and_clean( diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor.py new file mode 100644 index 00000000000..a3f284955bc --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor.py @@ -0,0 +1,15 @@ +from core.rag.index_processor.index_processor import IndexProcessor + + +class TestIndexProcessor: + def test_format_preview_supports_qa_preview_shape(self) -> None: + preview = IndexProcessor().format_preview( + "qa_model", + {"qa_chunks": [{"question": "Q1", "answer": "A1"}]}, + ) + + assert preview.chunk_structure == "qa_model" + assert preview.total_segments == 1 + assert len(preview.qa_preview) == 1 + assert preview.qa_preview[0].question == "Q1" + assert preview.qa_preview[0].answer == "A1" From c51cd42cb4e21320664b6d0e9efcf2ecbd1ddec5 Mon Sep 17 00:00:00 2001 From: Dream <42954461+eureka928@users.noreply.github.com> Date: Wed, 1 Apr 2026 01:41:44 -0400 Subject: [PATCH 08/42] refactor(api): replace json.loads with Pydantic validation in controllers and infra layers (#34277) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/workflow.py | 12 ++--- .../rag_pipeline/rag_pipeline_workflow.py | 23 ++-------- .../arize_phoenix_trace.py | 3 +- api/core/ops/mlflow_trace/mlflow_trace.py | 10 ++--- api/core/ops/ops_trace_manager.py | 23 +++++++--- api/core/ops/utils.py | 3 ++ .../alibabacloud_mysql_vector.py | 15 +++---- .../analyticdb/analyticdb_vector_openapi.py | 5 ++- .../rag/datasource/vdb/baidu/baidu_vector.py | 13 ++---- .../vdb/clickzetta/clickzetta_vector.py | 32 ++++++------- api/core/rag/datasource/vdb/field.py | 20 +++++++++ .../vdb/hologres/hologres_vector.py | 7 ++- .../rag/datasource/vdb/iris/iris_vector.py | 5 ++- .../vdb/matrixone/matrixone_vector.py | 7 +-- .../vdb/oceanbase/oceanbase_vector.py | 5 ++- .../vdb/tablestore/tablestore_vector.py | 9 ++-- .../datasource/vdb/tencent/tencent_vector.py | 12 +++-- .../datasource/vdb/tidb_vector/tidb_vector.py | 4 +- .../vdb/vikingdb/vikingdb_vector.py | 7 ++- ...tore_workflow_node_execution_repository.py | 9 ++-- .../clickzetta_volume/file_lifecycle.py | 8 +++- .../storage/google_cloud_storage.py | 7 ++- .../core/rag/datasource/vdb/test_field.py | 45 +++++++++++++++++++ 23 files changed, 170 insertions(+), 114 deletions(-) create mode 100644 api/tests/unit_tests/core/rag/datasource/vdb/test_field.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 6df8f7032ec..dcd24d2200f 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -9,7 +9,7 @@ from graphon.enums import NodeType from graphon.file import File from graphon.graph_engine.manager import GraphEngineManager from graphon.model_runtime.utils.encoders import jsonable_encoder -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, ValidationError, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound @@ -268,22 +268,18 @@ class DraftWorkflowApi(Resource): content_type = request.headers.get("Content-Type", "") - payload_data: dict[str, Any] | None = None if "application/json" in content_type: payload_data = request.get_json(silent=True) if not isinstance(payload_data, dict): return {"message": "Invalid JSON data"}, 400 + args_model = SyncDraftWorkflowPayload.model_validate(payload_data) elif "text/plain" in content_type: try: - payload_data = json.loads(request.data.decode("utf-8")) - except json.JSONDecodeError: - return {"message": "Invalid JSON data"}, 400 - if not isinstance(payload_data, dict): + args_model = SyncDraftWorkflowPayload.model_validate_json(request.data) + except (ValueError, ValidationError): return {"message": "Invalid JSON data"}, 400 else: abort(415) - - args_model = SyncDraftWorkflowPayload.model_validate(payload_data) args = args_model.model_dump() workflow_service = WorkflowService() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index e08cb155b6a..4251e7ebac3 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -5,7 +5,7 @@ from typing import Any, Literal, cast from flask import abort, request from flask_restx import Resource, marshal_with # type: ignore from graphon.model_runtime.utils.encoders import jsonable_encoder -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound @@ -186,29 +186,14 @@ class DraftRagPipelineApi(Resource): if "application/json" in content_type: payload_dict = console_ns.payload or {} + payload = DraftWorkflowSyncPayload.model_validate(payload_dict) elif "text/plain" in content_type: try: - data = json.loads(request.data.decode("utf-8")) - if "graph" not in data or "features" not in data: - raise ValueError("graph or features not found in data") - - if not isinstance(data.get("graph"), dict): - raise ValueError("graph is not a dict") - - payload_dict = { - "graph": data.get("graph"), - "features": data.get("features"), - "hash": data.get("hash"), - "environment_variables": data.get("environment_variables"), - "conversation_variables": data.get("conversation_variables"), - "rag_pipeline_variables": data.get("rag_pipeline_variables"), - } - except json.JSONDecodeError: + payload = DraftWorkflowSyncPayload.model_validate_json(request.data) + except (ValueError, ValidationError): return {"message": "Invalid JSON data"}, 400 else: abort(415) - - payload = DraftWorkflowSyncPayload.model_validate(payload_dict) rag_pipeline_service = RagPipelineService() try: diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 902f58e6b7b..66933cea287 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -38,6 +38,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) +from core.ops.utils import JSON_DICT_ADAPTER from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db from models.model import EndUser, MessageFile @@ -469,7 +470,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider if trace_info.message_data and trace_info.message_data.message_metadata: - metadata_dict = json.loads(trace_info.message_data.message_metadata) + metadata_dict = JSON_DICT_ADAPTER.validate_json(trace_info.message_data.message_metadata) if model_params := metadata_dict.get("model_parameters"): llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params) diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index 946d3cdd479..3d8c1dd038a 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -1,4 +1,3 @@ -import json import logging import os from datetime import datetime, timedelta @@ -25,6 +24,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) +from core.ops.utils import JSON_DICT_ADAPTER from extensions.ext_database import db from models import EndUser from models.workflow import WorkflowNodeExecutionModel @@ -153,7 +153,7 @@ class MLflowDataTrace(BaseTraceInstance): inputs = node.process_data # contains request URL if not inputs: - inputs = json.loads(node.inputs) if node.inputs else {} + inputs = JSON_DICT_ADAPTER.validate_json(node.inputs) if node.inputs else {} node_span = start_span_no_context( name=node.title, @@ -180,7 +180,7 @@ class MLflowDataTrace(BaseTraceInstance): # End node span finished_at = node.created_at + timedelta(seconds=node.elapsed_time) - outputs = json.loads(node.outputs) if node.outputs else {} + outputs = JSON_DICT_ADAPTER.validate_json(node.outputs) if node.outputs else {} if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: outputs = self._parse_knowledge_retrieval_outputs(outputs) elif node.node_type == BuiltinNodeTypes.LLM: @@ -216,8 +216,8 @@ class MLflowDataTrace(BaseTraceInstance): return {}, {} try: - data = json.loads(node.process_data) - except (json.JSONDecodeError, TypeError): + data = JSON_DICT_ADAPTER.validate_json(node.process_data) + except (ValueError, TypeError): return {}, {} inputs = self._parse_prompts(data.get("prompts")) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 9c36d57c6f5..c689a86614c 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -11,8 +11,10 @@ from uuid import UUID, uuid4 from cachetools import LRUCache from flask import current_app +from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from typing_extensions import TypedDict from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token from core.ops.entities.config_entity import ( @@ -33,7 +35,7 @@ from core.ops.entities.trace_entity import ( WorkflowNodeTraceInfo, WorkflowTraceInfo, ) -from core.ops.utils import get_message_data +from core.ops.utils import JSON_DICT_ADAPTER, get_message_data from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Tenant @@ -50,6 +52,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class _AppTracingConfig(TypedDict, total=False): + enabled: bool + tracing_provider: str | None + + +_app_tracing_config_adapter: TypeAdapter[_AppTracingConfig] = TypeAdapter(_AppTracingConfig) + + def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]: """Return (app_name, workspace_name) for the given IDs. Falls back to empty strings.""" app_name = "" @@ -468,7 +478,7 @@ class OpsTraceManager: if app is None: return None - app_ops_trace_config = json.loads(app.tracing) if app.tracing else None + app_ops_trace_config = _app_tracing_config_adapter.validate_json(app.tracing) if app.tracing else None if app_ops_trace_config is None: return None if not app_ops_trace_config.get("enabled"): @@ -560,7 +570,7 @@ class OpsTraceManager: raise ValueError("App not found") if not app.tracing: return {"enabled": False, "tracing_provider": None} - app_trace_config = json.loads(app.tracing) + app_trace_config = _app_tracing_config_adapter.validate_json(app.tracing) return app_trace_config @staticmethod @@ -636,7 +646,6 @@ class TraceTask: carries ``total_tokens``. Projects only the ``outputs`` column to avoid loading large JSON blobs unnecessarily. """ - import json from models.workflow import WorkflowNodeExecutionModel @@ -658,7 +667,7 @@ class TraceTask: if not raw: continue try: - outputs = json.loads(raw) if isinstance(raw, str) else raw + outputs = JSON_DICT_ADAPTER.validate_json(raw) if isinstance(raw, str) else raw except (ValueError, TypeError): continue if not isinstance(outputs, dict): @@ -1420,7 +1429,7 @@ class TraceTask: return {} try: - metadata = json.loads(message_data.message_metadata) + metadata = JSON_DICT_ADAPTER.validate_json(message_data.message_metadata) usage = metadata.get("usage", {}) time_to_first_token = usage.get("time_to_first_token") time_to_generate = usage.get("time_to_generate") @@ -1430,7 +1439,7 @@ class TraceTask: "llm_streaming_time_to_generate": time_to_generate, "is_streaming_request": time_to_first_token is not None, } - except (json.JSONDecodeError, AttributeError): + except (ValueError, AttributeError): return {} diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 8b9a2e424a3..a6f10c09acc 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -3,11 +3,14 @@ from datetime import datetime from typing import Any, Union from urllib.parse import urlparse +from pydantic import TypeAdapter from sqlalchemy import select from models.engine import db from models.model import Message +JSON_DICT_ADAPTER: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any]) + def filter_none_values(data: dict[str, Any]) -> dict[str, Any]: new_data = {} diff --git a/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py b/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py index fdb5ffebfcb..6e76827a422 100644 --- a/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py +++ b/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py @@ -10,6 +10,7 @@ from mysql.connector import Error as MySQLError from pydantic import BaseModel, model_validator from configs import dify_config +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -178,9 +179,7 @@ class AlibabaCloudMySQLVector(BaseVector): cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids) docs = [] for record in cur: - metadata = record["meta"] - if isinstance(metadata, str): - metadata = json.loads(metadata) + metadata = parse_metadata_json(record["meta"]) docs.append(Document(page_content=record["text"], metadata=metadata)) return docs @@ -263,15 +262,13 @@ class AlibabaCloudMySQLVector(BaseVector): # similarity = 1 / (1 + distance) similarity = 1.0 / (1.0 + distance) - metadata = record["meta"] - if isinstance(metadata, str): - metadata = json.loads(metadata) + metadata = parse_metadata_json(record["meta"]) metadata["score"] = similarity metadata["distance"] = distance if similarity >= score_threshold: docs.append(Document(page_content=record["text"], metadata=metadata)) - except (ValueError, json.JSONDecodeError) as e: + except (ValueError, TypeError) as e: logger.warning("Error processing search result: %s", e) continue @@ -306,9 +303,7 @@ class AlibabaCloudMySQLVector(BaseVector): ) docs = [] for record in cur: - metadata = record["meta"] - if isinstance(metadata, str): - metadata = json.loads(metadata) + metadata = parse_metadata_json(record["meta"]) metadata["score"] = float(record["score"]) docs.append(Document(page_content=record["text"], metadata=metadata)) return docs diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index 702200e0ac0..ce626bbd7e1 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -8,6 +8,7 @@ _import_err_msg = ( "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`" ) +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.models.document import Document from extensions.ext_redis import redis_client @@ -257,7 +258,7 @@ class AnalyticdbVectorOpenAPI: documents = [] for match in response.body.matches.match: if match.score >= score_threshold: - metadata = json.loads(match.metadata.get("metadata_")) + metadata = parse_metadata_json(match.metadata.get("metadata_")) metadata["score"] = match.score doc = Document( page_content=match.metadata.get("page_content"), @@ -294,7 +295,7 @@ class AnalyticdbVectorOpenAPI: documents = [] for match in response.body.matches.match: if match.score >= score_threshold: - metadata = json.loads(match.metadata.get("metadata_")) + metadata = parse_metadata_json(match.metadata.get("metadata_")) metadata["score"] = match.score doc = Document( page_content=match.metadata.get("page_content"), diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 9f5842e4493..3173920c9c5 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -29,6 +29,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams, from configs import dify_config from core.rag.datasource.vdb.field import Field as VDBField +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -173,15 +174,9 @@ class BaiduVector(BaseVector): score = row.get("score", 0.0) meta = row_data.get(VDBField.METADATA_KEY, {}) - # Handle both JSON string and dict formats for backward compatibility - if isinstance(meta, str): - try: - import json - - meta = json.loads(meta) - except (json.JSONDecodeError, TypeError): - meta = {} - elif not isinstance(meta, dict): + try: + meta = parse_metadata_json(meta) + except (ValueError, TypeError): meta = {} if score >= score_threshold: diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index 8e8120fc107..a4dddc68f07 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from clickzetta.connector.v0.connection import Connection # type: ignore from configs import dify_config -from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.field import Field, parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.embedding.embedding_base import Embeddings @@ -357,18 +357,19 @@ class ClickzettaVector(BaseVector): """ try: if raw_metadata: - metadata = json.loads(raw_metadata) + # First parse may yield a string (double-encoded JSON) so use json.loads + first_pass = json.loads(raw_metadata) # Handle double-encoded JSON - if isinstance(metadata, str): - metadata = json.loads(metadata) - - # Ensure we have a dict - if not isinstance(metadata, dict): + if isinstance(first_pass, str): + metadata = parse_metadata_json(first_pass) + elif isinstance(first_pass, dict): + metadata = first_pass + else: metadata = {} else: metadata = {} - except (json.JSONDecodeError, TypeError): + except (json.JSONDecodeError, ValueError, TypeError): logger.exception("JSON parsing failed for metadata") # Fallback: extract document_id with regex doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', raw_metadata or "") @@ -930,17 +931,18 @@ class ClickzettaVector(BaseVector): # Parse metadata from JSON string (may be double-encoded) try: if row[2]: - metadata = json.loads(row[2]) + # First parse may yield a string (double-encoded JSON) + first_pass = json.loads(row[2]) - # If result is a string, it's double-encoded JSON - parse again - if isinstance(metadata, str): - metadata = json.loads(metadata) - - if not isinstance(metadata, dict): + if isinstance(first_pass, str): + metadata = parse_metadata_json(first_pass) + elif isinstance(first_pass, dict): + metadata = first_pass + else: metadata = {} else: metadata = {} - except (json.JSONDecodeError, TypeError): + except (json.JSONDecodeError, ValueError, TypeError): logger.exception("JSON parsing failed") # Fallback: extract document_id with regex diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py index 8fc94be3603..5a0fabc5726 100644 --- a/api/core/rag/datasource/vdb/field.py +++ b/api/core/rag/datasource/vdb/field.py @@ -1,4 +1,24 @@ from enum import StrEnum, auto +from typing import Any + +from pydantic import TypeAdapter + +_metadata_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any]) + + +def parse_metadata_json(raw: Any) -> dict[str, Any]: + """Parse metadata from a JSON string or pass through an existing dict. + + Many VDB drivers return metadata as either a JSON string or an already- + decoded dict depending on the column type and driver version. + """ + if raw is None or raw in ("", b""): + return {} + if isinstance(raw, dict): + return raw + if not isinstance(raw, (str, bytes, bytearray)): + return {} + return _metadata_adapter.validate_json(raw) class Field(StrEnum): diff --git a/api/core/rag/datasource/vdb/hologres/hologres_vector.py b/api/core/rag/datasource/vdb/hologres/hologres_vector.py index 36b259e494c..13d48b5668d 100644 --- a/api/core/rag/datasource/vdb/hologres/hologres_vector.py +++ b/api/core/rag/datasource/vdb/hologres/hologres_vector.py @@ -9,6 +9,7 @@ from psycopg import sql as psql from pydantic import BaseModel, model_validator from configs import dify_config +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -217,8 +218,7 @@ class HologresVector(BaseVector): text = row[2] meta = row[3] - if isinstance(meta, str): - meta = json.loads(meta) + meta = parse_metadata_json(meta) # Convert distance to similarity score (consistent with pgvector) score = 1 - distance @@ -265,8 +265,7 @@ class HologresVector(BaseVector): meta = row[2] score = row[-1] # score is the last column from return_score - if isinstance(meta, str): - meta = json.loads(meta) + meta = parse_metadata_json(meta) meta["score"] = score docs.append(Document(page_content=text, metadata=meta)) diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py index 50bb2429ec9..aae445e6ff4 100644 --- a/api/core/rag/datasource/vdb/iris/iris_vector.py +++ b/api/core/rag/datasource/vdb/iris/iris_vector.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any from configs import dify_config from configs.middleware.vdb.iris_config import IrisVectorConfig +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -269,7 +270,7 @@ class IrisVector(BaseVector): if len(row) >= 4: text, meta_str, score = row[1], row[2], float(row[3]) if score >= score_threshold: - metadata = json.loads(meta_str) if meta_str else {} + metadata = parse_metadata_json(meta_str) metadata["score"] = score docs.append(Document(page_content=text, metadata=metadata)) return docs @@ -384,7 +385,7 @@ class IrisVector(BaseVector): meta_str = row[2] score_value = row[3] - metadata = json.loads(meta_str) if meta_str else {} + metadata = parse_metadata_json(meta_str) # Add score to metadata for hybrid search compatibility score = float(score_value) if score_value is not None else 0.0 metadata["score"] = score diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py index 14955c8d7ca..09ef4987156 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -9,6 +9,7 @@ from mo_vector.client import MoVectorClient # type: ignore from pydantic import BaseModel, model_validator from configs import dify_config +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -196,11 +197,7 @@ class MatrixoneVector(BaseVector): docs = [] for result in results: - metadata = result.metadata - if isinstance(metadata, str): - import json - - metadata = json.loads(metadata) + metadata = parse_metadata_json(result.metadata) score = 1 - result.distance if score >= score_threshold: metadata["score"] = score diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index 86c1e65f47e..82f419871c6 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -10,6 +10,7 @@ from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.exc import SQLAlchemyError from configs import dify_config +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -366,8 +367,8 @@ class OceanBaseVector(BaseVector): # Parse metadata JSON try: - metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str - except json.JSONDecodeError: + metadata = parse_metadata_json(metadata_str) + except (ValueError, TypeError): logger.warning("Invalid JSON metadata: %s", metadata_str) metadata = {} diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index f2156afa59e..4a734232ec1 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator from tablestore import BatchGetRowRequest, TableInBatchGetRowItem from configs import dify_config -from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.field import Field, parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -73,7 +73,8 @@ class TableStoreVector(BaseVector): for item in table_result: if item.is_ok and item.row: kv = {k: v for k, v, _ in item.row.attribute_columns} - docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY]))) + metadata = parse_metadata_json(kv[Field.METADATA_KEY]) + docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=metadata)) return docs def get_type(self) -> str: @@ -311,7 +312,7 @@ class TableStoreVector(BaseVector): metadata_str = ots_column_map.get(Field.METADATA_KEY) vector = json.loads(vector_str) if vector_str else None - metadata = json.loads(metadata_str) if metadata_str else {} + metadata = parse_metadata_json(metadata_str) metadata["score"] = search_hit.score @@ -371,7 +372,7 @@ class TableStoreVector(BaseVector): ots_column_map[col[0]] = col[1] metadata_str = ots_column_map.get(Field.METADATA_KEY) - metadata = json.loads(metadata_str) if metadata_str else {} + metadata = parse_metadata_json(metadata_str) vector_str = ots_column_map.get(Field.VECTOR) vector = json.loads(vector_str) if vector_str else None diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 291d047c046..829db9db20a 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -11,6 +11,7 @@ from tcvectordb.model import index as vdb_index # type: ignore from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, WeightedRerank # type: ignore from configs import dify_config +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -286,13 +287,10 @@ class TencentVector(BaseVector): return docs for result in res[0]: - meta = result.get(self.field_metadata) - if isinstance(meta, str): - # Compatible with version 1.1.3 and below. - meta = json.loads(meta) - score = 1 - result.get("score", 0.0) - else: - score = result.get("score", 0.0) + raw_meta = result.get(self.field_metadata) + # Compatible with version 1.1.3 and below: str means old driver. + score = (1 - result.get("score", 0.0)) if isinstance(raw_meta, str) else result.get("score", 0.0) + meta = parse_metadata_json(raw_meta) if score >= score_threshold: meta["score"] = score doc = Document(page_content=result.get(self.field_text), metadata=meta) diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 27ae038a064..c9489173741 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -9,7 +9,7 @@ from sqlalchemy import text as sql_text from sqlalchemy.orm import Session, declarative_base from configs import dify_config -from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.field import Field, parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -228,7 +228,7 @@ class TiDBVector(BaseVector): ) results = [(row[0], row[1], row[2]) for row in res] for meta, text, distance in results: - metadata = json.loads(meta) + metadata = parse_metadata_json(meta) metadata["score"] = 1 - distance docs.append(Document(page_content=text, metadata=metadata)) return docs diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index e5feecf2bc8..83fd3626d96 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -15,6 +15,7 @@ from volcengine.viking_db import ( # type: ignore from configs import dify_config from core.rag.datasource.vdb.field import Field as vdb_Field +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -163,7 +164,7 @@ class VikingDBVector(BaseVector): for result in results: metadata = result.fields.get(vdb_Field.METADATA_KEY) if metadata is not None: - metadata = json.loads(metadata) + metadata = parse_metadata_json(metadata) if metadata.get(key) == value: ids.append(result.id) return ids @@ -189,9 +190,7 @@ class VikingDBVector(BaseVector): docs = [] for result in results: - metadata = result.fields.get(vdb_Field.METADATA_KEY) - if metadata is not None: - metadata = json.loads(metadata) + metadata = parse_metadata_json(result.fields.get(vdb_Field.METADATA_KEY)) if result.score >= score_threshold: metadata["score"] = result.score doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY), metadata=metadata) diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index b7254366817..0e9a19b8214 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -20,6 +20,7 @@ from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.ops.utils import JSON_DICT_ADAPTER from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore @@ -48,10 +49,10 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut """ logger.debug("_dict_to_workflow_node_execution: data keys=%s", list(data.keys())[:5]) # Parse JSON fields - inputs = json.loads(data.get("inputs", "{}")) - process_data = json.loads(data.get("process_data", "{}")) - outputs = json.loads(data.get("outputs", "{}")) - metadata = json.loads(data.get("execution_metadata", "{}")) + inputs = JSON_DICT_ADAPTER.validate_json(data.get("inputs") or "{}") + process_data = JSON_DICT_ADAPTER.validate_json(data.get("process_data") or "{}") + outputs = JSON_DICT_ADAPTER.validate_json(data.get("outputs") or "{}") + metadata = JSON_DICT_ADAPTER.validate_json(data.get("execution_metadata") or "{}") # Convert metadata to domain enum keys domain_metadata = {} diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index 1d9911465be..483bd6bbf69 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -15,8 +15,12 @@ from datetime import datetime from enum import StrEnum, auto from typing import Any +from pydantic import TypeAdapter + logger = logging.getLogger(__name__) +_metadata_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any]) + class FileStatus(StrEnum): """File status enumeration""" @@ -455,8 +459,8 @@ class FileLifecycleManager: try: if self._storage.exists(self._metadata_file): metadata_content = self._storage.load_once(self._metadata_file) - result = json.loads(metadata_content.decode("utf-8")) - return dict(result) if result else {} + result = _metadata_adapter.validate_json(metadata_content) + return result or {} else: return {} except Exception as e: diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 4ad7e2d159d..00f7289aa4f 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -1,13 +1,16 @@ import base64 import io -import json from collections.abc import Generator +from typing import Any from google.cloud import storage as google_cloud_storage # type: ignore +from pydantic import TypeAdapter from configs import dify_config from extensions.storage.base_storage import BaseStorage +_service_account_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any]) + class GoogleCloudStorage(BaseStorage): """Implementation for Google Cloud storage.""" @@ -21,7 +24,7 @@ class GoogleCloudStorage(BaseStorage): if service_account_json_str: service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") # convert str to object - service_account_obj = json.loads(service_account_json) + service_account_obj = _service_account_adapter.validate_json(service_account_json) self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj) else: self.client = google_cloud_storage.Client() diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_field.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_field.py new file mode 100644 index 00000000000..d68c93b021b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_field.py @@ -0,0 +1,45 @@ +import pytest + +from core.rag.datasource.vdb.field import parse_metadata_json + + +class TestParseMetadataJson: + def test_none_returns_empty_dict(self): + assert parse_metadata_json(None) == {} + + def test_empty_string_returns_empty_dict(self): + assert parse_metadata_json("") == {} + + def test_valid_json_string(self): + result = parse_metadata_json('{"doc_id": "abc", "score": 0.9}') + assert result == {"doc_id": "abc", "score": 0.9} + + def test_dict_passthrough(self): + original = {"doc_id": "abc", "document_id": "123"} + result = parse_metadata_json(original) + assert result == original + + def test_empty_json_object(self): + assert parse_metadata_json("{}") == {} + + def test_invalid_json_raises_value_error(self): + with pytest.raises(ValueError): + parse_metadata_json("{invalid json") + + def test_nested_metadata(self): + result = parse_metadata_json('{"doc_id": "1", "extra": {"nested": true}}') + assert result["extra"]["nested"] is True + + def test_non_str_non_dict_returns_empty_dict(self): + assert parse_metadata_json(123) == {} + assert parse_metadata_json([1, 2]) == {} + + def test_bytes_input(self): + result = parse_metadata_json(b'{"key": "value"}') + assert result == {"key": "value"} + + def test_empty_bytes_returns_empty_dict(self): + assert parse_metadata_json(b"") == {} + + def test_empty_bytearray_returns_empty_dict(self): + assert parse_metadata_json(bytearray(b"")) == {} From b23ea0397a756d7b6f267c5789a292eabbb1c502 Mon Sep 17 00:00:00 2001 From: jimmyzhuu Date: Wed, 1 Apr 2026 14:16:09 +0800 Subject: [PATCH 09/42] fix: apply Baidu Vector DB connection timeout when initializing Mochow client (#34328) --- api/core/rag/datasource/vdb/baidu/baidu_vector.py | 6 +++++- .../rag/datasource/vdb/baidu/test_baidu_vector.py | 13 +++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 3173920c9c5..2b220fc04dc 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -195,7 +195,11 @@ class BaiduVector(BaseVector): raise def _init_client(self, config) -> MochowClient: - config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint) + config = Configuration( + credentials=BceCredentials(config.account, config.api_key), + endpoint=config.endpoint, + connection_timeout_in_mills=config.connection_timeout_in_mills, + ) client = MochowClient(config) return client diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py index c46c3d5e4bd..487d0216970 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py @@ -381,13 +381,22 @@ def test_init_client_constructs_configuration_and_client(baidu_module, monkeypat monkeypatch.setattr(baidu_module, "MochowClient", client_cls) vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) - config = SimpleNamespace(account="account", api_key="key", endpoint="https://endpoint") + config = SimpleNamespace( + account="account", + api_key="key", + endpoint="https://endpoint", + connection_timeout_in_mills=12_345, + ) client = vector._init_client(config) assert client == "client" credentials.assert_called_once_with("account", "key") - configuration.assert_called_once_with(credentials="credentials", endpoint="https://endpoint") + configuration.assert_called_once_with( + credentials="credentials", + endpoint="https://endpoint", + connection_timeout_in_mills=12_345, + ) client_cls.assert_called_once_with("configuration") From 31f7752ba9479e69753867cab8e3feafe7c101eb Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Wed, 1 Apr 2026 10:03:49 +0200 Subject: [PATCH 10/42] refactor: select in 10 service files (#34373) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato --- api/services/agent_service.py | 24 +++---- api/services/api_based_extension_service.py | 48 ++++++++------ api/services/app_service.py | 9 ++- api/services/feedback_service.py | 22 +++---- .../rag_pipeline_transform_service.py | 7 +- api/services/recommended_app_service.py | 12 ++-- api/services/saved_message_service.py | 21 +++--- .../tools/builtin_tools_manage_service.py | 9 ++- api/services/vector_service.py | 15 ++--- api/services/workflow_service.py | 24 +++---- .../services/test_feedback_service.py | 21 +++--- .../services/test_vector_service.py | 47 ++++++-------- .../services/test_workflow_service.py | 64 +++++++------------ .../test_builtin_tools_manage_service.py | 4 +- 14 files changed, 147 insertions(+), 180 deletions(-) diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 2b8a3ee5949..d8f4e11e758 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -2,6 +2,7 @@ import threading from typing import Any import pytz +from sqlalchemy import select import contexts from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager @@ -23,25 +24,25 @@ class AgentService: contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - conversation: Conversation | None = ( - db.session.query(Conversation) + conversation: Conversation | None = db.session.scalar( + select(Conversation) .where( Conversation.id == conversation_id, Conversation.app_id == app_model.id, ) - .first() + .limit(1) ) if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Message | None = ( - db.session.query(Message) + message: Message | None = db.session.scalar( + select(Message) .where( Message.id == message_id, Message.conversation_id == conversation_id, ) - .first() + .limit(1) ) if not message: @@ -51,16 +52,11 @@ class AgentService: if conversation.from_end_user_id: # only select name field - executor = ( - db.session.query(EndUser, EndUser.name).where(EndUser.id == conversation.from_end_user_id).first() - ) + executor_name = db.session.scalar(select(EndUser.name).where(EndUser.id == conversation.from_end_user_id)) else: - executor = db.session.query(Account, Account.name).where(Account.id == conversation.from_account_id).first() + executor_name = db.session.scalar(select(Account.name).where(Account.id == conversation.from_account_id)) - if executor: - executor = executor.name - else: - executor = "Unknown" + executor = executor_name or "Unknown" assert isinstance(current_user, Account) assert current_user.timezone is not None timezone = pytz.timezone(current_user.timezone) diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 3a0ed41be04..fdb377694bb 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor from core.helper.encrypter import decrypt_token, encrypt_token from extensions.ext_database import db @@ -7,11 +9,12 @@ from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint class APIBasedExtensionService: @staticmethod def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: - extension_list = ( - db.session.query(APIBasedExtension) - .filter_by(tenant_id=tenant_id) - .order_by(APIBasedExtension.created_at.desc()) - .all() + extension_list = list( + db.session.scalars( + select(APIBasedExtension) + .where(APIBasedExtension.tenant_id == tenant_id) + .order_by(APIBasedExtension.created_at.desc()) + ).all() ) for extension in extension_list: @@ -36,11 +39,10 @@ class APIBasedExtensionService: @staticmethod def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = ( - db.session.query(APIBasedExtension) - .filter_by(tenant_id=tenant_id) - .filter_by(id=api_based_extension_id) - .first() + extension = db.session.scalar( + select(APIBasedExtension) + .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .limit(1) ) if not extension: @@ -58,23 +60,27 @@ class APIBasedExtensionService: if not extension_data.id: # case one: check new data, name must be unique - is_name_existed = ( - db.session.query(APIBasedExtension) - .filter_by(tenant_id=extension_data.tenant_id) - .filter_by(name=extension_data.name) - .first() + is_name_existed = db.session.scalar( + select(APIBasedExtension) + .where( + APIBasedExtension.tenant_id == extension_data.tenant_id, + APIBasedExtension.name == extension_data.name, + ) + .limit(1) ) if is_name_existed: raise ValueError("name must be unique, it is already existed") else: # case two: check existing data, name must be unique - is_name_existed = ( - db.session.query(APIBasedExtension) - .filter_by(tenant_id=extension_data.tenant_id) - .filter_by(name=extension_data.name) - .where(APIBasedExtension.id != extension_data.id) - .first() + is_name_existed = db.session.scalar( + select(APIBasedExtension) + .where( + APIBasedExtension.tenant_id == extension_data.tenant_id, + APIBasedExtension.name == extension_data.name, + APIBasedExtension.id != extension_data.id, + ) + .limit(1) ) if is_name_existed: diff --git a/api/services/app_service.py b/api/services/app_service.py index e9aeb6c43d0..87d52a3159c 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -6,6 +6,7 @@ import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from sqlalchemy import select from configs import dify_config from constants.model_template import default_app_templates @@ -433,9 +434,7 @@ class AppService: meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" elif provider_type == "api": try: - provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first() - ) + provider: ApiToolProvider | None = db.session.get(ApiToolProvider, provider_id) if provider is None: raise ValueError(f"provider not found for tool {tool_name}") meta["tool_icons"][tool_name] = json.loads(provider.icon) @@ -451,7 +450,7 @@ class AppService: :param app_id: app id :return: app code """ - site = db.session.query(Site).where(Site.app_id == app_id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_id).limit(1)) if not site: raise ValueError(f"App with id {app_id} not found") return str(site.code) @@ -463,7 +462,7 @@ class AppService: :param app_code: app code :return: app id """ - site = db.session.query(Site).where(Site.code == app_code).first() + site = db.session.scalar(select(Site).where(Site.code == app_code).limit(1)) if not site: raise ValueError(f"App with code {app_code} not found") return str(site.app_id) diff --git a/api/services/feedback_service.py b/api/services/feedback_service.py index e7473d371b9..d6c338a830d 100644 --- a/api/services/feedback_service.py +++ b/api/services/feedback_service.py @@ -4,7 +4,7 @@ import json from datetime import datetime from flask import Response -from sqlalchemy import or_ +from sqlalchemy import or_, select from extensions.ext_database import db from models.enums import FeedbackRating @@ -41,8 +41,8 @@ class FeedbackService: raise ValueError(f"Unsupported format: {format_type}") # Build base query - query = ( - db.session.query(MessageFeedback, Message, Conversation, App, Account) + stmt = ( + select(MessageFeedback, Message, Conversation, App, Account) .join(Message, MessageFeedback.message_id == Message.id) .join(Conversation, MessageFeedback.conversation_id == Conversation.id) .join(App, MessageFeedback.app_id == App.id) @@ -52,36 +52,36 @@ class FeedbackService: # Apply filters if from_source: - query = query.filter(MessageFeedback.from_source == from_source) + stmt = stmt.where(MessageFeedback.from_source == from_source) if rating: - query = query.filter(MessageFeedback.rating == rating) + stmt = stmt.where(MessageFeedback.rating == rating) if has_comment is not None: if has_comment: - query = query.filter(MessageFeedback.content.isnot(None), MessageFeedback.content != "") + stmt = stmt.where(MessageFeedback.content.isnot(None), MessageFeedback.content != "") else: - query = query.filter(or_(MessageFeedback.content.is_(None), MessageFeedback.content == "")) + stmt = stmt.where(or_(MessageFeedback.content.is_(None), MessageFeedback.content == "")) if start_date: try: start_dt = datetime.strptime(start_date, "%Y-%m-%d") - query = query.filter(MessageFeedback.created_at >= start_dt) + stmt = stmt.where(MessageFeedback.created_at >= start_dt) except ValueError: raise ValueError(f"Invalid start_date format: {start_date}. Use YYYY-MM-DD") if end_date: try: end_dt = datetime.strptime(end_date, "%Y-%m-%d") - query = query.filter(MessageFeedback.created_at <= end_dt) + stmt = stmt.where(MessageFeedback.created_at <= end_dt) except ValueError: raise ValueError(f"Invalid end_date format: {end_date}. Use YYYY-MM-DD") # Order by creation date (newest first) - query = query.order_by(MessageFeedback.created_at.desc()) + stmt = stmt.order_by(MessageFeedback.created_at.desc()) # Execute query - results = query.all() + results = db.session.execute(stmt).all() # Prepare data for export export_data = [] diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 215a8c85285..c3b00fe1094 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -6,6 +6,7 @@ from uuid import uuid4 import yaml from flask_login import current_user +from sqlalchemy import select from constants import DOCUMENT_EXTENSIONS from core.plugin.impl.plugin import PluginInstaller @@ -26,7 +27,7 @@ logger = logging.getLogger(__name__) class RagPipelineTransformService: def transform_dataset(self, dataset_id: str): - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = db.session.get(Dataset, dataset_id) if not dataset: raise ValueError("Dataset not found") if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE: @@ -306,7 +307,7 @@ class RagPipelineTransformService: jina_node_id = "1752491761974" firecrawl_node_id = "1752565402678" - documents = db.session.query(Document).where(Document.dataset_id == dataset.id).all() + documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset.id)).all() for document in documents: data_source_info_dict = document.data_source_info_dict @@ -316,7 +317,7 @@ class RagPipelineTransformService: document.data_source_type = DataSourceType.LOCAL_FILE file_id = data_source_info_dict.get("upload_file_id") if file_id: - file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + file = db.session.get(UploadFile, file_id) if file: data_source_info = json.dumps( { diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 6b211a5632b..9819822103b 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from configs import dify_config from extensions.ext_database import db from models.model import AccountTrialAppRecord, TrialApp @@ -27,7 +29,7 @@ class RecommendedAppService: apps = result["recommended_apps"] for app in apps: app_id = app["app_id"] - trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first() + trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1)) if trial_app_model: app["can_trial"] = True else: @@ -46,7 +48,7 @@ class RecommendedAppService: result: dict = retrieval_instance.get_recommend_app_detail(app_id) if FeatureService.get_system_features().enable_trial_app: app_id = result["id"] - trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first() + trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1)) if trial_app_model: result["can_trial"] = True else: @@ -60,10 +62,10 @@ class RecommendedAppService: :param app_id: app id :return: """ - account_trial_app_record = ( - db.session.query(AccountTrialAppRecord) + account_trial_app_record = db.session.scalar( + select(AccountTrialAppRecord) .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id) - .first() + .limit(1) ) if account_trial_app_record: account_trial_app_record.count += 1 diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index d0f4f279683..77d1767c4be 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -1,5 +1,7 @@ from typing import Union +from sqlalchemy import select + from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account @@ -16,16 +18,15 @@ class SavedMessageService: ) -> InfiniteScrollPagination: if not user: raise ValueError("User is required") - saved_messages = ( - db.session.query(SavedMessage) + saved_messages = db.session.scalars( + select(SavedMessage) .where( SavedMessage.app_id == app_model.id, SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), SavedMessage.created_by == user.id, ) .order_by(SavedMessage.created_at.desc()) - .all() - ) + ).all() message_ids = [sm.message_id for sm in saved_messages] return MessageService.pagination_by_last_id( @@ -36,15 +37,15 @@ class SavedMessageService: def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): if not user: return - saved_message = ( - db.session.query(SavedMessage) + saved_message = db.session.scalar( + select(SavedMessage) .where( SavedMessage.app_id == app_model.id, SavedMessage.message_id == message_id, SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), SavedMessage.created_by == user.id, ) - .first() + .limit(1) ) if saved_message: @@ -66,15 +67,15 @@ class SavedMessageService: def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): if not user: return - saved_message = ( - db.session.query(SavedMessage) + saved_message = db.session.scalar( + select(SavedMessage) .where( SavedMessage.app_id == app_model.id, SavedMessage.message_id == message_id, SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), SavedMessage.created_by == user.id, ) - .first() + .limit(1) ) if not saved_message: diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 8e3c36e0998..f7447d3c104 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -332,12 +332,11 @@ class BuiltinToolManageService: get builtin tool provider credentials """ with db.session.no_autoflush: - providers = ( - db.session.query(BuiltinToolProvider) - .filter_by(tenant_id=tenant_id, provider=provider_name) + providers = db.session.scalars( + select(BuiltinToolProvider) + .where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .all() - ) + ).all() if len(providers) == 0: return [] diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 3f78b823a63..e7266cb8e94 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,6 +1,7 @@ import logging from graphon.model_runtime.entities.model_entities import ModelType +from sqlalchemy import delete, select from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword @@ -29,7 +30,7 @@ class VectorService: for segment in segments: if doc_form == IndexStructureType.PARENT_CHILD_INDEX: - dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first() + dataset_document = db.session.get(DatasetDocument, segment.document_id) if not dataset_document: logger.warning( "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s", @@ -38,11 +39,7 @@ class VectorService: ) continue # get the process rule - processing_rule = ( - db.session.query(DatasetProcessRule) - .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) - .first() - ) + processing_rule = db.session.get(DatasetProcessRule, dataset_document.dataset_process_rule_id) if not processing_rule: raise ValueError("No processing rule found.") # get embedding model instance @@ -271,8 +268,8 @@ class VectorService: vector.delete_by_ids(old_attachment_ids) # Delete existing segment attachment bindings in one operation - db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete( - synchronize_session=False + db.session.execute( + delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id) ) if not attachment_ids: @@ -280,7 +277,7 @@ class VectorService: return # Bulk fetch upload files - only fetch needed fields - upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all() + upload_file_list = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all() if not upload_file_list: db.session.commit() diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 3b3ee6dd92e..8f365c7c51f 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -138,14 +138,14 @@ class WorkflowService: if workflow_id: return self.get_published_workflow_by_id(app_model, workflow_id) # fetch draft workflow by app_model - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == Workflow.VERSION_DRAFT, ) - .first() + .limit(1) ) # return draft workflow @@ -155,14 +155,14 @@ class WorkflowService: """ fetch published workflow by workflow_id """ - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id, ) - .first() + .limit(1) ) if not workflow: return None @@ -182,14 +182,14 @@ class WorkflowService: return None # fetch published workflow by workflow_id - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == app_model.workflow_id, ) - .first() + .limit(1) ) return workflow @@ -544,14 +544,14 @@ class WorkflowService: # Use the same fallback logic as runtime: get the first available credential # ordered by is_default DESC, created_at ASC (same as tool_manager.py) - default_provider = ( - db.session.query(BuiltinToolProvider) + default_provider = db.session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider, ) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .first() + .limit(1) ) if not default_provider: diff --git a/api/tests/test_containers_integration_tests/services/test_feedback_service.py b/api/tests/test_containers_integration_tests/services/test_feedback_service.py index 771f4067753..d82933ccb90 100644 --- a/api/tests/test_containers_integration_tests/services/test_feedback_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feedback_service.py @@ -99,7 +99,7 @@ class TestFeedbackService: ) ] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query # Test CSV export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") @@ -138,7 +138,7 @@ class TestFeedbackService: ) ] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query # Test JSON export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") @@ -175,7 +175,7 @@ class TestFeedbackService: ) ] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query # Test with filters result = FeedbackService.export_feedbacks( @@ -188,11 +188,8 @@ class TestFeedbackService: format_type="csv", ) - # Verify filters were applied - assert mock_query.filter.called - filter_calls = mock_query.filter.call_args_list - # At least three filter invocations are expected (source, rating, comment) - assert len(filter_calls) >= 3 + # Verify query was executed (filters are baked into the select statement) + assert mock_db_session.execute.called def test_export_feedbacks_no_data(self, mock_db_session, sample_data): """Test exporting feedback when no data exists.""" @@ -206,7 +203,7 @@ class TestFeedbackService: mock_query.order_by.return_value = mock_query mock_query.all.return_value = [] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") @@ -271,7 +268,7 @@ class TestFeedbackService: ) ] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query # Test export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") @@ -329,7 +326,7 @@ class TestFeedbackService: ) ] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query # Test export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") @@ -367,7 +364,7 @@ class TestFeedbackService: ), ] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query # Test export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") diff --git a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py index 598ff3fc3a4..a78a033f4d3 100644 --- a/api/tests/unit_tests/services/test_vector_service.py +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -77,22 +77,12 @@ def _make_segment( def _mock_db_session_for_update_multimodel(*, upload_files: list[_UploadFileStub] | None) -> MagicMock: session = MagicMock(name="session") - binding_query = MagicMock(name="binding_query") - binding_query.where.return_value = binding_query - binding_query.delete.return_value = 1 + # db.session.execute() is used for delete(SegmentAttachmentBinding).where(...) + session.execute = MagicMock(name="execute") - upload_query = MagicMock(name="upload_query") - upload_query.where.return_value = upload_query - upload_query.all.return_value = upload_files or [] + # db.session.scalars(select(UploadFile).where(...)).all() returns upload files + session.scalars.return_value.all.return_value = upload_files or [] - def query_side_effect(model: object) -> MagicMock: - if model is vector_service_module.SegmentAttachmentBinding: - return binding_query - if model is vector_service_module.UploadFile: - return upload_query - return MagicMock(name=f"query({model})") - - session.query.side_effect = query_side_effect db_mock = MagicMock(name="db") db_mock.session = session return db_mock @@ -165,22 +155,15 @@ def _mock_parent_child_queries( ) -> MagicMock: session = MagicMock(name="session") - doc_query = MagicMock(name="doc_query") - doc_query.filter_by.return_value = doc_query - doc_query.first.return_value = dataset_document + get_dispatch: dict[object, object | None] = { + vector_service_module.DatasetDocument: dataset_document, + vector_service_module.DatasetProcessRule: processing_rule, + } - rule_query = MagicMock(name="rule_query") - rule_query.where.return_value = rule_query - rule_query.first.return_value = processing_rule + def get_side_effect(model: object, pk: object) -> object | None: + return get_dispatch.get(model) - def query_side_effect(model: object) -> MagicMock: - if model is vector_service_module.DatasetDocument: - return doc_query - if model is vector_service_module.DatasetProcessRule: - return rule_query - return MagicMock(name=f"query({model})") - - session.query.side_effect = query_side_effect + session.get.side_effect = get_side_effect db_mock = MagicMock(name="db") db_mock.session = session return db_mock @@ -609,7 +592,7 @@ def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( vector_cls.assert_called_once_with(dataset=dataset) vector_instance.delete_by_ids.assert_called_once_with(["old-1", "old-2"]) - db_mock.session.query.assert_called_once_with(vector_service_module.SegmentAttachmentBinding) + db_mock.session.execute.assert_called_once() db_mock.session.commit.assert_called_once() db_mock.session.add_all.assert_not_called() vector_instance.add_texts.assert_not_called() @@ -644,6 +627,8 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up binding_ctor = MagicMock(side_effect=lambda **kwargs: kwargs) monkeypatch.setattr(vector_service_module, "SegmentAttachmentBinding", binding_ctor) + monkeypatch.setattr(vector_service_module, "delete", MagicMock()) + monkeypatch.setattr(vector_service_module, "select", MagicMock()) logger_mock = MagicMock() monkeypatch.setattr(vector_service_module, "logger", logger_mock) @@ -677,6 +662,8 @@ def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops monkeypatch.setattr( vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs) ) + monkeypatch.setattr(vector_service_module, "delete", MagicMock()) + monkeypatch.setattr(vector_service_module, "select", MagicMock()) VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset) @@ -698,6 +685,8 @@ def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: monkeypatch.setattr( vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs) ) + monkeypatch.setattr(vector_service_module, "delete", MagicMock()) + monkeypatch.setattr(vector_service_module, "select", MagicMock()) logger_mock = MagicMock() monkeypatch.setattr(vector_service_module, "logger", logger_mock) diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index cd71981bcf1..1b253eb2f1f 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -268,7 +268,7 @@ class TestWorkflowService: Provides mock implementations of: - session.add(): Adding new records - session.commit(): Committing transactions - - session.query(): Querying database + - session.scalar(): Scalar queries - session.execute(): Executing SQL statements """ with patch("services.workflow_service.db") as mock_db: @@ -276,7 +276,7 @@ class TestWorkflowService: mock_db.session = mock_session mock_session.add = MagicMock() mock_session.commit = MagicMock() - mock_session.query = MagicMock() + mock_session.scalar = MagicMock() mock_session.execute = MagicMock() yield mock_db @@ -338,10 +338,8 @@ class TestWorkflowService: app = TestWorkflowAssociatedDataFactory.create_app_mock() mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock() - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + # Mock db.session.scalar() used by get_draft_workflow + mock_db_session.session.scalar.return_value = mock_workflow result = workflow_service.get_draft_workflow(app) @@ -351,10 +349,8 @@ class TestWorkflowService: """Test get_draft_workflow returns None when no draft exists.""" app = TestWorkflowAssociatedDataFactory.create_app_mock() - # Mock database query to return None - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = None + # Mock db.session.scalar() to return None + mock_db_session.session.scalar.return_value = None result = workflow_service.get_draft_workflow(app) @@ -366,10 +362,8 @@ class TestWorkflowService: workflow_id = "workflow-123" mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1") - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + # Mock db.session.scalar() used by get_published_workflow_by_id + mock_db_session.session.scalar.return_value = mock_workflow result = workflow_service.get_draft_workflow(app, workflow_id=workflow_id) @@ -384,10 +378,8 @@ class TestWorkflowService: workflow_id = "workflow-123" mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1") - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + # Mock db.session.scalar() used by get_published_workflow_by_id + mock_db_session.session.scalar.return_value = mock_workflow result = workflow_service.get_published_workflow_by_id(app, workflow_id) @@ -406,10 +398,8 @@ class TestWorkflowService: workflow_id=workflow_id, version=Workflow.VERSION_DRAFT ) - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + # Mock db.session.scalar() used by get_published_workflow_by_id + mock_db_session.session.scalar.return_value = mock_workflow with pytest.raises(IsDraftWorkflowError): workflow_service.get_published_workflow_by_id(app, workflow_id) @@ -419,10 +409,8 @@ class TestWorkflowService: app = TestWorkflowAssociatedDataFactory.create_app_mock() workflow_id = "nonexistent-workflow" - # Mock database query to return None - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = None + # Mock db.session.scalar() to return None + mock_db_session.session.scalar.return_value = None result = workflow_service.get_published_workflow_by_id(app, workflow_id) @@ -434,10 +422,8 @@ class TestWorkflowService: app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=workflow_id) mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1") - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + # Mock db.session.scalar() used by get_published_workflow + mock_db_session.session.scalar.return_value = mock_workflow result = workflow_service.get_published_workflow(app) @@ -466,11 +452,9 @@ class TestWorkflowService: graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph() features = {"file_upload": {"enabled": False}} - # Mock get_draft_workflow to return None (no existing draft) + # Mock db.session.scalar() to return None (no existing draft) # This simulates the first time a workflow is created for an app - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = None + mock_db_session.session.scalar.return_value = None with ( patch.object(workflow_service, "validate_features_structure"), @@ -504,12 +488,10 @@ class TestWorkflowService: features = {"file_upload": {"enabled": False}} unique_hash = "test-hash-123" - # Mock existing draft workflow + # Mock existing draft workflow via db.session.scalar() mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash=unique_hash) - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + mock_db_session.session.scalar.return_value = mock_workflow with ( patch.object(workflow_service, "validate_features_structure"), @@ -545,12 +527,10 @@ class TestWorkflowService: graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph() features = {} - # Mock existing draft workflow with different hash + # Mock existing draft workflow with different hash via db.session.scalar() mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash="old-hash") - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + mock_db_session.session.scalar.return_value = mock_workflow with pytest.raises(WorkflowHashNotEqualError): workflow_service.sync_draft_workflow( diff --git a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py index 439d203c58d..175900071b1 100644 --- a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py +++ b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py @@ -347,7 +347,7 @@ class TestGetBuiltinToolProviderCredentials: def test_returns_empty_when_no_providers(self, mock_db): mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None) mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False) - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google") @@ -362,7 +362,7 @@ class TestGetBuiltinToolProviderCredentials: mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False) provider = MagicMock(provider="google", is_default=False) - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [provider] + mock_db.session.scalars.return_value.all.return_value = [provider] mock_encrypter = MagicMock() mock_encrypter.decrypt.return_value = {"key": "decrypted"} From 2b9eb065552e4bae8ff14b00640d30e2d2257d78 Mon Sep 17 00:00:00 2001 From: Stephen Zhou Date: Wed, 1 Apr 2026 19:02:53 +0800 Subject: [PATCH 11/42] chore: move commit hook to root (#34404) --- .gitignore | 1 + {web/.husky => .vite-hooks}/pre-commit | 2 +- package.json | 12 +- pnpm-lock.yaml | 155 +------------------------ pnpm-workspace.yaml | 2 - vite.config.ts | 5 + web/Dockerfile | 2 +- web/Dockerfile.dockerignore | 1 - web/package.json | 6 - web/vite.config.ts | 3 + 10 files changed, 22 insertions(+), 167 deletions(-) rename {web/.husky => .vite-hooks}/pre-commit (99%) mode change 100644 => 100755 create mode 100644 vite.config.ts diff --git a/.gitignore b/.gitignore index d7698fe3fd9..f703fc02e9b 100644 --- a/.gitignore +++ b/.gitignore @@ -213,6 +213,7 @@ api/.vscode # pnpm /.pnpm-store /node_modules +.vite-hooks/_ # plugin migrate plugins.jsonl diff --git a/web/.husky/pre-commit b/.vite-hooks/pre-commit old mode 100644 new mode 100755 similarity index 99% rename from web/.husky/pre-commit rename to .vite-hooks/pre-commit index 3f25de256fb..54e09f80d6d --- a/web/.husky/pre-commit +++ b/.vite-hooks/pre-commit @@ -77,7 +77,7 @@ if $web_modified; then fi cd ./web || exit 1 - lint-staged + vp staged if $web_ts_modified; then echo "Running TypeScript type-check:tsgo" diff --git a/package.json b/package.json index 07f1e16153f..48c3acef021 100644 --- a/package.json +++ b/package.json @@ -1,11 +1,15 @@ { "name": "dify", "private": true, + "scripts": { + "prepare": "vp config" + }, + "devDependencies": { + "taze": "catalog:", + "vite-plus": "catalog:" + }, "engines": { "node": "^22.22.1" }, - "packageManager": "pnpm@10.33.0", - "devDependencies": { - "taze": "catalog:" - } + "packageManager": "pnpm@10.33.0" } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index eb45ea0ef85..baa4ed6c34c 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -345,9 +345,6 @@ catalogs: html-to-image: specifier: 1.11.13 version: 1.11.13 - husky: - specifier: 9.1.7 - version: 9.1.7 i18next: specifier: 25.10.10 version: 25.10.10 @@ -390,9 +387,6 @@ catalogs: lexical: specifier: 0.42.0 version: 0.42.0 - lint-staged: - specifier: 16.4.0 - version: 16.4.0 mermaid: specifier: 11.13.0 version: 11.13.0 @@ -624,6 +618,9 @@ importers: taze: specifier: 'catalog:' version: 19.10.0 + vite-plus: + specifier: 'catalog:' + version: 0.1.14(@types/node@25.5.0)(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.0)(esbuild@0.27.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3) e2e: devDependencies: @@ -1165,18 +1162,12 @@ importers: hono: specifier: 'catalog:' version: 4.12.9 - husky: - specifier: 'catalog:' - version: 9.1.7 iconify-import-svg: specifier: 'catalog:' version: 0.1.2 knip: specifier: 'catalog:' version: 6.1.0(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1) - lint-staged: - specifier: 'catalog:' - version: 16.4.0 postcss: specifier: 'catalog:' version: 8.5.8 @@ -4751,10 +4742,6 @@ packages: ajv@8.18.0: resolution: {integrity: sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==} - ansi-escapes@7.3.0: - resolution: {integrity: sha512-BvU8nYgGQBxcmMuEeUEmNTvrMVjJNSH7RgW24vXexN4Ven6qCvy4TntnvlnwnMLTVlcRQQdbRY8NKnaIoeWDNg==} - engines: {node: '>=18'} - ansi-regex@4.1.1: resolution: {integrity: sha512-ILlv4k/3f6vfQ4OoP2AGvirOktlQ98ZEL1k9FaQjxa3L1abBgbuTDAdPOpvbGncC0BTVQrl+OM8xZGK6tWXt7g==} engines: {node: '>=6'} @@ -4775,10 +4762,6 @@ packages: resolution: {integrity: sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==} engines: {node: '>=10'} - ansi-styles@6.2.3: - resolution: {integrity: sha512-4Dj6M28JB+oAH8kFkTLUo+a2jwOFkuqb3yucU0CANcRRUbxS0cP0nZYCGjcc3BNXwRIsUVmDGgzawme7zvJHvg==} - engines: {node: '>=12'} - ansis@4.2.0: resolution: {integrity: sha512-HqZ5rWlFjGiV0tDm3UxxgNRqsOTniqoKZu0pIAfh7TZQMGuZK+hH0drySty0si0QXj1ieop4+SkSfPZBPPkHig==} engines: {node: '>=14'} @@ -5066,18 +5049,10 @@ packages: resolution: {integrity: sha512-GfisEZEJvzKrmGWkvfhgzcz/BllN1USeqD2V6tg14OAOgaCD2Z/PUEuxnAZ/nPvmaHRG7a8y77p1T/IRQ4D1Hw==} engines: {node: '>=4'} - cli-cursor@5.0.0: - resolution: {integrity: sha512-aCj4O5wKyszjMmDT4tZj93kxyydN/K5zPWSCe6/0AV/AA1pqe5ZBIw0a2ZfPQV7lL5/yb5HsUreJ6UFAF1tEQw==} - engines: {node: '>=18'} - cli-table3@0.6.5: resolution: {integrity: sha512-+W/5efTR7y5HRD7gACw9yQjqMVvEMLBHmboM/kPWam+H+Hmyrgjh6YncVKK122YZkXrLudzTuAukUw9FnMf7IQ==} engines: {node: 10.* || >= 12.*} - cli-truncate@5.2.0: - resolution: {integrity: sha512-xRwvIOMGrfOAnM1JYtqQImuaNtDEv9v6oIYAs4LIHwTiKee8uwvIi363igssOC0O5U04i4AlENs79LQLu9tEMw==} - engines: {node: '>=20'} - client-only@0.0.1: resolution: {integrity: sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA==} @@ -5104,9 +5079,6 @@ packages: color-name@1.1.4: resolution: {integrity: sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==} - colorette@2.0.20: - resolution: {integrity: sha512-IfEDxwoWIjkeXL1eXcDiow4UbKjhLdq6/EuSVR9GMN7KVH3r9gQ83e73hsz1Nd1T3ijd5xv1wcWRYO+D6kCI2w==} - comma-separated-tokens@1.0.8: resolution: {integrity: sha512-GHuDRO12Sypu2cV70d1dkA2EUmXHgntrzbpvOB+Qy+49ypNfGgFQIC2fhhXbnyrJRynDCAARsT7Ou0M6hirpfw==} @@ -5572,10 +5544,6 @@ packages: resolution: {integrity: sha512-TWrgLOFUQTH994YUyl1yT4uyavY5nNB5muff+RtWaqNVCAK408b5ZnnbNAUEWLTCpum9w6arT70i1XdQ4UeOPA==} engines: {node: '>=0.12'} - environment@1.1.0: - resolution: {integrity: sha512-xUtoPkMggbz0MPyPiIWr1Kp4aeWJjDZ6SMvURhimjdZgsRuDplF5/s9hcgGhyXMhs+6vpnuoiZ2kFiu3FMnS8Q==} - engines: {node: '>=18'} - error-stack-parser-es@1.0.5: resolution: {integrity: sha512-5qucVt2XcuGMcEGgWI7i+yZpmpByQ8J1lHhcL7PwqCwu9FPP3VUXzT4ltHe5i2z9dePwEHcDVOAfSnHsOlCXRA==} @@ -5965,9 +5933,6 @@ packages: event-target-bus@1.0.0: resolution: {integrity: sha512-uPcWKbj/BJU3Tbw9XqhHqET4/LBOhvv3/SJWr7NksxA6TC5YqBpaZgawE9R+WpYFCBFSAE4Vun+xQS6w4ABdlA==} - eventemitter3@5.0.4: - resolution: {integrity: sha512-mlsTRyGaPBjPedk6Bvw+aqbsXDtoAyAzm5MO7JgU+yVRyMQ5O8bD4Kcci7BS85f93veegeCPkL8R4GLClnjLFw==} - events@3.3.0: resolution: {integrity: sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==} engines: {node: '>=0.8.x'} @@ -6289,11 +6254,6 @@ packages: htmlparser2@10.1.0: resolution: {integrity: sha512-VTZkM9GWRAtEpveh7MSF6SjjrpNVNNVJfFup7xTY3UpFtm67foy9HDVXneLtFVt4pMz5kZtgNcvCniNFb1hlEQ==} - husky@9.1.7: - resolution: {integrity: sha512-5gs5ytaNjBrh5Ow3zrvdUUY+0VxIuWVL4i9irt6friV+BqdCfmV11CQTWMiBYWHbXhco+J1kHfTOUkePhCDvMA==} - engines: {node: '>=18'} - hasBin: true - i18next-resources-to-backend@1.2.1: resolution: {integrity: sha512-okHbVA+HZ7n1/76MsfhPqDou0fptl2dAlhRDu2ideXloRRduzHsqDOznJBef+R3DFZnbvWoBW+KxJ7fnFjd6Yw==} @@ -6419,10 +6379,6 @@ packages: resolution: {integrity: sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==} engines: {node: '>=0.10.0'} - is-fullwidth-code-point@5.1.0: - resolution: {integrity: sha512-5XHYaSyiqADb4RnZ1Bdad6cPp8Toise4TzEjcOYDHZkTCbKgiUl7WTUCpNWHuxmDt91wnsZBc9xinNzopv3JMQ==} - engines: {node: '>=18'} - is-glob@4.0.3: resolution: {integrity: sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==} engines: {node: '>=0.10.0'} @@ -6730,15 +6686,6 @@ packages: lines-and-columns@1.2.4: resolution: {integrity: sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==} - lint-staged@16.4.0: - resolution: {integrity: sha512-lBWt8hujh/Cjysw5GYVmZpFHXDCgZzhrOm8vbcUdobADZNOK/bRshr2kM3DfgrrtR1DQhfupW9gnIXOfiFi+bw==} - engines: {node: '>=20.17'} - hasBin: true - - listr2@9.0.5: - resolution: {integrity: sha512-ME4Fb83LgEgwNw96RKNvKV4VTLuXfoKudAmm2lP8Kk87KaMK0/Xrx/aAkMWmT8mDb+3MlFDspfbCs7adjRxA2g==} - engines: {node: '>=20.0.0'} - load-tsconfig@0.2.5: resolution: {integrity: sha512-IXO6OCs9yg8tMKzfPZ1YmheJbZCiEsnBdcB03l0OcfK9prKnJb96siuHCr5Fl37/yo9DnKU+TLpxzTUspw9shg==} engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} @@ -6770,10 +6717,6 @@ packages: lodash@4.17.23: resolution: {integrity: sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==} - log-update@6.1.0: - resolution: {integrity: sha512-9ie8ItPR6tjY5uYJh8K/Zrv/RMZ5VOlOWvtZdEHYSTFKZfIBPQa9tOAEeAWhd+AnIneLJ22w5fjOYtoutpWq5w==} - engines: {node: '>=18'} - longest-streak@3.1.0: resolution: {integrity: sha512-9Ri+o0JYgehTaVBBDoMqIl8GXtbWg711O3srftcHhZ0dqnETqLaoIK0x17fUw9rFSlK/0NlsKe0Ahhyl5pXE2g==} @@ -7920,9 +7863,6 @@ packages: resolution: {integrity: sha512-g6QUff04oZpHs0eG5p83rFLhHeV00ug/Yf9nZM6fLeUrPguBTkTQOdpAWWspMh55TZfVQDPaN3NQJfbVRAxdIw==} engines: {iojs: '>=1.0.0', node: '>=0.10.0'} - rfdc@1.4.1: - resolution: {integrity: sha512-q1b3N5QkRUWUl7iyylaaj3kOpIT0N2i9MqIEQXP73GVsN9cw3fdx8X63cEmWhJGi2PPCF23Ijp7ktmd39rawIA==} - robust-predicates@3.0.3: resolution: {integrity: sha512-NS3levdsRIUOmiJ8FZWCP7LG3QpJyrs/TE0Zpf1yvZu8cAJJ6QMW92H1c7kWpdIHo8RvmLxN/o2JXTKHp74lUA==} @@ -8043,14 +7983,6 @@ packages: size-sensor@1.0.3: resolution: {integrity: sha512-+k9mJ2/rQMiRmQUcjn+qznch260leIXY8r4FyYKKyRBO/s5UoeMAHGkCJyE1R/4wrIhTJONfyloY55SkE7ve3A==} - slice-ansi@7.1.2: - resolution: {integrity: sha512-iOBWFgUX7caIZiuutICxVgX1SdxwAVFFKwt1EvMYYec/NWO5meOJ6K5uQxhrYBdQJne4KxiqZc+KptFOWFSI9w==} - engines: {node: '>=18'} - - slice-ansi@8.0.0: - resolution: {integrity: sha512-stxByr12oeeOyY2BlviTNQlYV5xOj47GirPr4yA1hE9JCtxfQN0+tVbkxwCtYDQWhEKWFHsEK48ORg5jrouCAg==} - engines: {node: '>=20'} - smol-toml@1.6.1: resolution: {integrity: sha512-dWUG8F5sIIARXih1DTaQAX4SsiTXhInKf1buxdY9DIg4ZYPZK5nGM1VRIYmEbDbsHt7USo99xSLFu5Q1IqTmsg==} engines: {node: '>= 18'} @@ -8134,10 +8066,6 @@ packages: resolution: {integrity: sha512-a1uQGz7IyVy9YwhqjZIZu1c8JO8dNIe20xBmSS6qu9kv++k3JGzCVmprbNN5Kn+BgzD5E7YYwg1CcjuJMRNsvg==} engines: {node: '>=0.6.19'} - string-argv@0.3.2: - resolution: {integrity: sha512-aqD2Q0144Z+/RqG52NeHEkZauTAUWJO8c6yTftGJKO3Tja5tUgIfmIl6kExvhtxSDP7fXB6DvzkfMpCd/F3G+Q==} - engines: {node: '>=0.6.19'} - string-ts@2.3.1: resolution: {integrity: sha512-xSJq+BS52SaFFAVxuStmx6n5aYZU571uYUnUrPXkPFCfdHyZMMlbP2v2Wx5sNBnAVzq/2+0+mcBLBa3Xa5ubYw==} @@ -8874,10 +8802,6 @@ packages: resolution: {integrity: sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA==} engines: {node: '>=0.10.0'} - wrap-ansi@9.0.2: - resolution: {integrity: sha512-42AtmgqjV+X1VpdOfyTGOYRi0/zsoLqtXQckTmqTeybT+BDIbM/Guxo7x3pE2vtpr1ok6xRqM9OpBe+Jyoqyww==} - engines: {node: '>=18'} - wrappy@1.0.2: resolution: {integrity: sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==} @@ -12658,10 +12582,6 @@ snapshots: json-schema-traverse: 1.0.0 require-from-string: 2.0.2 - ansi-escapes@7.3.0: - dependencies: - environment: 1.1.0 - ansi-regex@4.1.1: {} ansi-regex@5.0.1: {} @@ -12674,8 +12594,6 @@ snapshots: ansi-styles@5.2.0: {} - ansi-styles@6.2.3: {} - ansis@4.2.0: {} any-promise@1.3.0: {} @@ -12947,21 +12865,12 @@ snapshots: dependencies: escape-string-regexp: 1.0.5 - cli-cursor@5.0.0: - dependencies: - restore-cursor: 5.1.0 - cli-table3@0.6.5: dependencies: string-width: 8.2.0 optionalDependencies: '@colors/colors': 1.5.0 - cli-truncate@5.2.0: - dependencies: - slice-ansi: 8.0.0 - string-width: 8.2.0 - client-only@0.0.1: {} clsx@2.1.1: {} @@ -12998,8 +12907,6 @@ snapshots: color-name@1.1.4: {} - colorette@2.0.20: {} - comma-separated-tokens@1.0.8: {} comma-separated-tokens@2.0.3: {} @@ -13443,8 +13350,6 @@ snapshots: entities@7.0.1: {} - environment@1.1.0: {} - error-stack-parser-es@1.0.5: {} error-stack-parser@2.1.4: @@ -14106,8 +14011,6 @@ snapshots: event-target-bus@1.0.0: {} - eventemitter3@5.0.4: {} - events@3.3.0: {} expand-template@2.0.3: @@ -14496,8 +14399,6 @@ snapshots: domutils: 3.2.2 entities: 7.0.1 - husky@9.1.7: {} - i18next-resources-to-backend@1.2.1: dependencies: '@babel/runtime': 7.29.2 @@ -14596,10 +14497,6 @@ snapshots: is-extglob@2.1.1: {} - is-fullwidth-code-point@5.1.0: - dependencies: - get-east-asian-width: 1.5.0 - is-glob@4.0.3: dependencies: is-extglob: 2.1.1 @@ -14853,24 +14750,6 @@ snapshots: lines-and-columns@1.2.4: {} - lint-staged@16.4.0: - dependencies: - commander: 14.0.3 - listr2: 9.0.5 - picomatch: 4.0.4 - string-argv: 0.3.2 - tinyexec: 1.0.4 - yaml: 2.8.3 - - listr2@9.0.5: - dependencies: - cli-truncate: 5.2.0 - colorette: 2.0.20 - eventemitter3: 5.0.4 - log-update: 6.1.0 - rfdc: 1.4.1 - wrap-ansi: 9.0.2 - load-tsconfig@0.2.5: {} loader-runner@4.3.1: {} @@ -14895,14 +14774,6 @@ snapshots: lodash@4.17.23: {} - log-update@6.1.0: - dependencies: - ansi-escapes: 7.3.0 - cli-cursor: 5.0.0 - slice-ansi: 7.1.2 - strip-ansi: 7.2.0 - wrap-ansi: 9.0.2 - longest-streak@3.1.0: {} loose-envify@1.4.0: @@ -16539,8 +16410,6 @@ snapshots: reusify@1.1.0: {} - rfdc@1.4.1: {} - robust-predicates@3.0.3: {} rolldown@1.0.0-rc.12(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1): @@ -16734,16 +16603,6 @@ snapshots: size-sensor@1.0.3: {} - slice-ansi@7.1.2: - dependencies: - ansi-styles: 6.2.3 - is-fullwidth-code-point: 5.1.0 - - slice-ansi@8.0.0: - dependencies: - ansi-styles: 6.2.3 - is-fullwidth-code-point: 5.1.0 - smol-toml@1.6.1: {} solid-js@1.9.11: @@ -16844,8 +16703,6 @@ snapshots: string-argv@0.3.1: {} - string-argv@0.3.2: {} - string-ts@2.3.1: {} string-width@8.2.0: @@ -17690,12 +17547,6 @@ snapshots: word-wrap@1.2.5: {} - wrap-ansi@9.0.2: - dependencies: - ansi-styles: 6.2.3 - string-width: 8.2.0 - strip-ansi: 7.2.0 - wrappy@1.0.2: {} ws@8.20.0: {} diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index b11cca66421..77451f6dfc4 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -183,7 +183,6 @@ catalog: hono: 4.12.9 html-entities: 2.6.0 html-to-image: 1.11.13 - husky: 9.1.7 i18next: 25.10.10 i18next-resources-to-backend: 1.2.1 iconify-import-svg: 0.1.2 @@ -198,7 +197,6 @@ catalog: ky: 1.14.3 lamejs: 1.2.1 lexical: 0.42.0 - lint-staged: 16.4.0 mermaid: 11.13.0 mime: 4.1.0 mitt: 3.0.1 diff --git a/vite.config.ts b/vite.config.ts new file mode 100644 index 00000000000..a34932a4ef1 --- /dev/null +++ b/vite.config.ts @@ -0,0 +1,5 @@ +import { defineConfig } from 'vite-plus' + +export default defineConfig({ + staged: {}, +}) diff --git a/web/Dockerfile b/web/Dockerfile index 75024db4f3c..dc234168423 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -31,7 +31,7 @@ RUN corepack install # Install only the web workspace to keep image builds from pulling in # unrelated workspace dependencies such as e2e tooling. -RUN pnpm install --filter ./web... --frozen-lockfile +RUN VITE_GIT_HOOKS=0 pnpm install --filter ./web... --frozen-lockfile # build resources FROM base AS builder diff --git a/web/Dockerfile.dockerignore b/web/Dockerfile.dockerignore index 9801003d892..b572bd863e6 100644 --- a/web/Dockerfile.dockerignore +++ b/web/Dockerfile.dockerignore @@ -22,7 +22,6 @@ web/node_modules web/dist web/build web/coverage -web/.husky web/.next web/.pnpm-store web/.vscode diff --git a/web/package.json b/web/package.json index 9ed21fdb22b..08c10b12adc 100644 --- a/web/package.json +++ b/web/package.json @@ -40,7 +40,6 @@ "lint:quiet": "vp run lint --quiet", "lint:tss": "tsslint --project tsconfig.json", "preinstall": "npx only-allow pnpm", - "prepare": "cd ../ && node -e \"if (process.env.NODE_ENV !== 'production'){process.exit(1)} \" || husky ./web/.husky", "refactor-component": "node ./scripts/refactor-component.js", "start": "node ./scripts/copy-and-start.mjs", "start:vinext": "vinext start", @@ -218,10 +217,8 @@ "eslint-plugin-storybook": "catalog:", "happy-dom": "catalog:", "hono": "catalog:", - "husky": "catalog:", "iconify-import-svg": "catalog:", "knip": "catalog:", - "lint-staged": "catalog:", "postcss": "catalog:", "postcss-js": "catalog:", "react-server-dom-webpack": "catalog:", @@ -237,8 +234,5 @@ "vite-plus": "catalog:", "vitest": "catalog:", "vitest-canvas-mock": "catalog:" - }, - "lint-staged": { - "*": "eslint --fix --pass-on-unpruned-suppressions" } } diff --git a/web/vite.config.ts b/web/vite.config.ts index 28746f81ca3..92762676d18 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -18,6 +18,9 @@ export default defineConfig(({ mode }) => { || process.argv.some(arg => arg.toLowerCase().includes('storybook')) return { + staged: { + '*': 'eslint --fix --pass-on-unpruned-suppressions', + }, plugins: isTest ? [ nextStaticImageTestPlugin({ projectRoot }), From e41965061ceb78bb574d1eac60fa18fb7b3e0e98 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Wed, 1 Apr 2026 21:15:36 +0800 Subject: [PATCH 12/42] =?UTF-8?q?fix:=20sqlalchemy.exc.InvalidRequestError?= =?UTF-8?q?:=20Can't=20operate=20on=20closed=20tran=E2=80=A6=20(#34407)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../rag_pipeline/rag_pipeline_workflow.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 4251e7ebac3..70dfe47d7f8 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -593,17 +593,15 @@ class PublishedRagPipelineApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() rag_pipeline_service = RagPipelineService() - with sessionmaker(db.engine).begin() as session: - pipeline = session.merge(pipeline) - workflow = rag_pipeline_service.publish_workflow( - session=session, - pipeline=pipeline, - account=current_user, - ) - pipeline.is_published = True - pipeline.workflow_id = workflow.id - session.add(pipeline) - workflow_created_at = TimestampField().format(workflow.created_at) + workflow = rag_pipeline_service.publish_workflow( + session=db.session, # type: ignore[reportArgumentType,arg-type] + pipeline=pipeline, + account=current_user, + ) + pipeline.is_published = True + pipeline.workflow_id = workflow.id + db.session.commit() + workflow_created_at = TimestampField().format(workflow.created_at) return { "result": "success", From 391007d02e60e69962df2cc0ebcbac740cd1a1de Mon Sep 17 00:00:00 2001 From: Tim Ren <137012659+xr843@users.noreply.github.com> Date: Wed, 1 Apr 2026 22:53:41 +0800 Subject: [PATCH 13/42] refactor: migrate service_api and inner_api to sessionmaker pattern (#34379) Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/inner_api/plugin/wraps.py | 6 ++-- .../service_api/app/conversation.py | 4 +-- api/controllers/service_api/app/workflow.py | 4 +-- .../inner_api/plugin/test_plugin_wraps.py | 33 +++++++++---------- .../service_api/app/test_conversation.py | 11 +++++-- .../service_api/app/test_workflow.py | 22 +++++++++---- 6 files changed, 47 insertions(+), 33 deletions(-) diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index d6e3ebfbcd1..ed0d490aad0 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -6,7 +6,7 @@ from flask import current_app, request from flask_login import user_logged_in from pydantic import BaseModel from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from extensions.ext_database import db from libs.login import current_user @@ -33,7 +33,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID try: - with Session(db.engine) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: user_model = None if is_anonymous: @@ -56,7 +56,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: session_id=user_id, ) session.add(user_model) - session.commit() + session.flush() session.refresh(user_model) except Exception: diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index edbf0116566..8c9a3eb5e97 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -3,7 +3,7 @@ from typing import Any, Literal from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, NotFound import services @@ -116,7 +116,7 @@ class ConversationApi(Resource): last_id = str(query_args.last_id) if query_args.last_id else None try: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: pagination = ConversationService.pagination_by_last_id( session=session, app_model=app_model, diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 17590751395..d7992a2a3a1 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -8,7 +8,7 @@ from graphon.enums import WorkflowExecutionStatus from graphon.graph_engine.manager import GraphEngineManager from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from controllers.common.schema import register_schema_models @@ -314,7 +314,7 @@ class WorkflowAppLogApi(Resource): # get paginate workflow app logs workflow_app_service = WorkflowAppService() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( session=session, app_model=app_model, diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py index eac57fe4b76..957d7fbd9be 100644 --- a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py @@ -41,15 +41,15 @@ class TestGetUser: """Test get_user function""" @patch("controllers.inner_api.plugin.wraps.EndUser") - @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_return_existing_user_by_id(self, mock_db, mock_session_class, mock_enduser_class, app: Flask): + def test_should_return_existing_user_by_id(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): """Test returning existing user when found by ID""" # Arrange mock_user = MagicMock() mock_user.id = "user123" mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.get.return_value = mock_user # Act @@ -61,17 +61,17 @@ class TestGetUser: mock_session.get.assert_called_once() @patch("controllers.inner_api.plugin.wraps.EndUser") - @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") def test_should_return_existing_anonymous_user_by_session_id( - self, mock_db, mock_session_class, mock_enduser_class, app: Flask + self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask ): """Test returning existing anonymous user by session_id""" # Arrange mock_user = MagicMock() mock_user.session_id = "anonymous_session" mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session # non-anonymous path uses session.get(); anonymous uses session.scalar() mock_session.get.return_value = mock_user @@ -83,13 +83,13 @@ class TestGetUser: assert result == mock_user @patch("controllers.inner_api.plugin.wraps.EndUser") - @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_create_new_user_when_not_found(self, mock_db, mock_session_class, mock_enduser_class, app: Flask): + def test_should_create_new_user_when_not_found(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): """Test creating new user when not found in database""" # Arrange mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.get.return_value = None mock_new_user = MagicMock() mock_enduser_class.return_value = mock_new_user @@ -101,21 +101,20 @@ class TestGetUser: # Assert assert result == mock_new_user mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() mock_session.refresh.assert_called_once() @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") - @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") def test_should_use_default_session_id_when_user_id_none( - self, mock_db, mock_session_class, mock_enduser_class, mock_select, app: Flask + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask ): """Test using default session ID when user_id is None""" # Arrange mock_user = MagicMock() mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session # When user_id is None, is_anonymous=True, so session.scalar() is used mock_session.scalar.return_value = mock_user @@ -127,15 +126,13 @@ class TestGetUser: assert result == mock_user @patch("controllers.inner_api.plugin.wraps.EndUser") - @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_raise_error_on_database_exception( - self, mock_db, mock_session_class, mock_enduser_class, app: Flask - ): + def test_should_raise_error_on_database_exception(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): """Test raising ValueError when database operation fails""" # Arrange mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.get.side_effect = Exception("Database error") # Act & Assert diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index 81c45dcdb70..dbd06677d8a 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -433,13 +433,20 @@ class TestConversationApiController: handler(api, app_model=app_model, end_user=end_user) def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: - class _SessionStub: + class _BeginStub: def __enter__(self): return SimpleNamespace() def __exit__(self, exc_type, exc, tb): return False + class _SessionMakerStub: + def __init__(self, *args, **kwargs): + pass + + def begin(self): + return _BeginStub() + monkeypatch.setattr( ConversationService, "pagination_by_last_id", @@ -447,7 +454,7 @@ class TestConversationApiController: ) conversation_module = sys.modules["controllers.service_api.app.conversation"] monkeypatch.setattr(conversation_module, "db", SimpleNamespace(engine=object())) - monkeypatch.setattr(conversation_module, "Session", lambda *_args, **_kwargs: _SessionStub()) + monkeypatch.setattr(conversation_module, "sessionmaker", _SessionMakerStub) api = ConversationApi() handler = _unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index b1f036c6f36..cfa21bf2dd6 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -470,16 +470,23 @@ class TestWorkflowTaskStopApi: class TestWorkflowAppLogApi: def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: - class _SessionStub: + class _BeginStub: def __enter__(self): return SimpleNamespace() def __exit__(self, exc_type, exc, tb): return False + class _SessionMakerStub: + def __init__(self, *args, **kwargs): + pass + + def begin(self): + return _BeginStub() + workflow_module = sys.modules["controllers.service_api.app.workflow"] monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) - monkeypatch.setattr(workflow_module, "Session", lambda *_args, **_kwargs: _SessionStub()) + monkeypatch.setattr(workflow_module, "sessionmaker", _SessionMakerStub) monkeypatch.setattr( WorkflowAppService, "get_paginate_workflow_app_logs", @@ -635,11 +642,14 @@ class TestWorkflowAppLogApiGet: mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination mock_wf_svc_cls.return_value = mock_svc_instance - # Mock Session context manager + # Mock sessionmaker(...).begin() context manager mock_session = Mock() mock_db.engine = Mock() - mock_session.__enter__ = Mock(return_value=mock_session) - mock_session.__exit__ = Mock(return_value=False) + mock_begin = Mock() + mock_begin.__enter__ = Mock(return_value=mock_session) + mock_begin.__exit__ = Mock(return_value=False) + mock_session_factory = Mock() + mock_session_factory.begin.return_value = mock_begin from controllers.service_api.app.workflow import WorkflowAppLogApi @@ -647,7 +657,7 @@ class TestWorkflowAppLogApiGet: "/workflows/logs?page=1&limit=20", method="GET", ): - with patch("controllers.service_api.app.workflow.Session", return_value=mock_session): + with patch("controllers.service_api.app.workflow.sessionmaker", return_value=mock_session_factory): api = WorkflowAppLogApi() result = _unwrap(api.get)(api, app_model=mock_workflow_app) From 4e1d0604391e2df11c6df7b3864b9121e9304fe8 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Wed, 1 Apr 2026 18:37:27 +0200 Subject: [PATCH 14/42] refactor: select in message_service and ops_service (#34414) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/message_service.py | 57 ++++--- api/services/ops_service.py | 32 ++-- .../services/test_message_service.py | 147 +++--------------- .../unit_tests/services/test_ops_service.py | 53 ++++--- 4 files changed, 94 insertions(+), 195 deletions(-) diff --git a/api/services/message_service.py b/api/services/message_service.py index a04f9cbe012..5c2978db21b 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -3,6 +3,7 @@ from typing import Union from graphon.model_runtime.entities.model_entities import ModelType from pydantic import TypeAdapter +from sqlalchemy import select from sqlalchemy.orm import sessionmaker from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -75,17 +76,15 @@ class MessageService: fetch_limit = limit + 1 if first_id: - first_message = ( - db.session.query(Message) - .where(Message.conversation_id == conversation.id, Message.id == first_id) - .first() + first_message = db.session.scalar( + select(Message).where(Message.conversation_id == conversation.id, Message.id == first_id).limit(1) ) if not first_message: raise FirstMessageNotExistsError() - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where( Message.conversation_id == conversation.id, Message.created_at < first_message.created_at, @@ -93,16 +92,14 @@ class MessageService: ) .order_by(Message.created_at.desc()) .limit(fetch_limit) - .all() - ) + ).all() else: - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where(Message.conversation_id == conversation.id) .order_by(Message.created_at.desc()) .limit(fetch_limit) - .all() - ) + ).all() has_more = False if len(history_messages) > limit: @@ -129,7 +126,7 @@ class MessageService: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) - base_query = db.session.query(Message) + stmt = select(Message) fetch_limit = limit + 1 @@ -138,28 +135,27 @@ class MessageService: app_model=app_model, user=user, conversation_id=conversation_id ) - base_query = base_query.where(Message.conversation_id == conversation.id) + stmt = stmt.where(Message.conversation_id == conversation.id) # Check if include_ids is not None and not empty to avoid WHERE false condition if include_ids is not None: if len(include_ids) == 0: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) - base_query = base_query.where(Message.id.in_(include_ids)) + stmt = stmt.where(Message.id.in_(include_ids)) if last_id: - last_message = base_query.where(Message.id == last_id).first() + last_message = db.session.scalar(stmt.where(Message.id == last_id).limit(1)) if not last_message: raise LastMessageNotExistsError() - history_messages = ( - base_query.where(Message.created_at < last_message.created_at, Message.id != last_message.id) + history_messages = db.session.scalars( + stmt.where(Message.created_at < last_message.created_at, Message.id != last_message.id) .order_by(Message.created_at.desc()) .limit(fetch_limit) - .all() - ) + ).all() else: - history_messages = base_query.order_by(Message.created_at.desc()).limit(fetch_limit).all() + history_messages = db.session.scalars(stmt.order_by(Message.created_at.desc()).limit(fetch_limit)).all() has_more = False if len(history_messages) > limit: @@ -214,21 +210,20 @@ class MessageService: def get_all_messages_feedbacks(cls, app_model: App, page: int, limit: int): """Get all feedbacks of an app""" offset = (page - 1) * limit - feedbacks = ( - db.session.query(MessageFeedback) + feedbacks = db.session.scalars( + select(MessageFeedback) .where(MessageFeedback.app_id == app_model.id) .order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc()) .limit(limit) .offset(offset) - .all() - ) + ).all() return [record.to_dict() for record in feedbacks] @classmethod def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): - message = ( - db.session.query(Message) + message = db.session.scalar( + select(Message) .where( Message.id == message_id, Message.app_id == app_model.id, @@ -236,7 +231,7 @@ class MessageService: Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), Message.from_account_id == (user.id if isinstance(user, Account) else None), ) - .first() + .limit(1) ) if not message: @@ -282,10 +277,10 @@ class MessageService: ) else: if not conversation.override_model_configs: - app_model_config = ( - db.session.query(AppModelConfig) + app_model_config = db.session.scalar( + select(AppModelConfig) .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) - .first() + .limit(1) ) else: conversation_override_model_configs = _app_model_config_adapter.validate_json( diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 50ea832085a..2a64088dd66 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,5 +1,7 @@ from typing import Any +from sqlalchemy import select + from core.ops.entities.config_entity import BaseTracingConfig from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map from extensions.ext_database import db @@ -15,17 +17,17 @@ class OpsService: :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig | None = ( - db.session.query(TraceAppConfig) + trace_config_data: TraceAppConfig | None = db.session.scalar( + select(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) - .first() + .limit(1) ) if not trace_config_data: return None # decrypt_token and obfuscated_token - app = db.session.query(App).where(App.id == app_id).first() + app = db.session.get(App, app_id) if not app: return None tenant_id = app.tenant_id @@ -182,17 +184,17 @@ class OpsService: project_url = None # check if trace config already exists - trace_config_data: TraceAppConfig | None = ( - db.session.query(TraceAppConfig) + trace_config_data: TraceAppConfig | None = db.session.scalar( + select(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) - .first() + .limit(1) ) if trace_config_data: return None # get tenant id - app = db.session.query(App).where(App.id == app_id).first() + app = db.session.get(App, app_id) if not app: return None tenant_id = app.tenant_id @@ -224,17 +226,17 @@ class OpsService: raise ValueError(f"Invalid tracing provider: {tracing_provider}") # check if trace config already exists - current_trace_config = ( - db.session.query(TraceAppConfig) + current_trace_config = db.session.scalar( + select(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) - .first() + .limit(1) ) if not current_trace_config: return None # get tenant id - app = db.session.query(App).where(App.id == app_id).first() + app = db.session.get(App, app_id) if not app: return None tenant_id = app.tenant_id @@ -261,10 +263,10 @@ class OpsService: :param tracing_provider: tracing provider :return: """ - trace_config = ( - db.session.query(TraceAppConfig) + trace_config = db.session.scalar( + select(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) - .first() + .limit(1) ) if not trace_config: diff --git a/api/tests/unit_tests/services/test_message_service.py b/api/tests/unit_tests/services/test_message_service.py index 101b9bff24d..b6e990ebe0f 100644 --- a/api/tests/unit_tests/services/test_message_service.py +++ b/api/tests/unit_tests/services/test_message_service.py @@ -151,12 +151,7 @@ class TestMessageServicePaginationByFirstId: for i in range(5) ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_first_id( @@ -196,12 +191,7 @@ class TestMessageServicePaginationByFirstId: for i in range(5) ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_first_id( @@ -246,31 +236,8 @@ class TestMessageServicePaginationByFirstId: for i in range(5) ] - # Setup query mocks - mock_query_first = MagicMock() - mock_query_history = MagicMock() - - query_calls = [] - - def query_side_effect(*args): - if args[0] == Message: - query_calls.append(args) - if len(query_calls) == 1: - return mock_query_first - else: - return mock_query_history - - mock_db.session.query.side_effect = [mock_query_first, mock_query_history] - - # Setup first message query - mock_query_first.where.return_value = mock_query_first - mock_query_first.first.return_value = first_message - - # Setup history messages query - mock_query_history.where.return_value = mock_query_history - mock_query_history.order_by.return_value = mock_query_history - mock_query_history.limit.return_value = mock_query_history - mock_query_history.all.return_value = history_messages + mock_db.session.scalar.return_value = first_message + mock_db.session.scalars.return_value.all.return_value = history_messages # Act result = MessageService.pagination_by_first_id( @@ -285,8 +252,6 @@ class TestMessageServicePaginationByFirstId: # Assert assert len(result.data) == 5 assert result.has_more is False - mock_query_first.where.assert_called_once() - mock_query_history.where.assert_called_once() # Test 06: First message not found @patch("services.message_service.db") @@ -300,10 +265,7 @@ class TestMessageServicePaginationByFirstId: mock_conversation_service.get_conversation.return_value = conversation - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # Message not found + mock_db.session.scalar.return_value = None # Message not found # Act & Assert with pytest.raises(FirstMessageNotExistsError): @@ -336,12 +298,7 @@ class TestMessageServicePaginationByFirstId: for i in range(11) ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_first_id( @@ -369,12 +326,7 @@ class TestMessageServicePaginationByFirstId: mock_conversation_service.get_conversation.return_value = conversation - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] # Act result = MessageService.pagination_by_first_id( @@ -443,12 +395,7 @@ class TestMessageServicePaginationByLastId: for i in range(5) ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_last_id( @@ -485,22 +432,8 @@ class TestMessageServicePaginationByLastId: for i in range(6, 10) ] - # Setup base query mock that returns itself for chaining - mock_base_query = MagicMock() - mock_db.session.query.return_value = mock_base_query - - # First where() call for last_id lookup - mock_query_last = MagicMock() - mock_query_last.first.return_value = last_message - - # Second where() call for history messages - mock_query_history = MagicMock() - mock_query_history.order_by.return_value = mock_query_history - mock_query_history.limit.return_value = mock_query_history - mock_query_history.all.return_value = new_messages - - # Setup where() to return different mocks on consecutive calls - mock_base_query.where.side_effect = [mock_query_last, mock_query_history] + mock_db.session.scalar.return_value = last_message + mock_db.session.scalars.return_value.all.return_value = new_messages # Act result = MessageService.pagination_by_last_id( @@ -522,10 +455,7 @@ class TestMessageServicePaginationByLastId: app = factory.create_app_mock() user = factory.create_end_user_mock() - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # Message not found + mock_db.session.scalar.return_value = None # Message not found # Act & Assert with pytest.raises(LastMessageNotExistsError): @@ -557,12 +487,7 @@ class TestMessageServicePaginationByLastId: for i in range(5) ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_last_id( @@ -576,8 +501,6 @@ class TestMessageServicePaginationByLastId: # Assert assert len(result.data) == 5 assert result.has_more is False - # Verify conversation_id was used in query - mock_query.where.assert_called() mock_conversation_service.get_conversation.assert_called_once() # Test 14: Pagination with include_ids filter @@ -594,12 +517,7 @@ class TestMessageServicePaginationByLastId: factory.create_message_mock(message_id="msg-003"), ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_last_id( @@ -632,12 +550,7 @@ class TestMessageServicePaginationByLastId: for i in range(11) ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_last_id( @@ -743,17 +656,13 @@ class TestMessageServiceGetMessage: user = factory.create_end_user_mock(user_id="end-user-123") message = factory.create_message_mock() - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = message + mock_db.session.scalar.return_value = message # Act result = MessageService.get_message(app_model=app, user=user, message_id="msg-123") # Assert assert result == message - mock_query.where.assert_called_once() # Test 21: get_message success for Account (Admin) @patch("services.message_service.db") @@ -767,10 +676,7 @@ class TestMessageServiceGetMessage: user.id = "account-123" message = factory.create_message_mock() - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = message + mock_db.session.scalar.return_value = message # Act result = MessageService.get_message(app_model=app, user=user, message_id="msg-123") @@ -786,10 +692,7 @@ class TestMessageServiceGetMessage: app = factory.create_app_mock() user = factory.create_end_user_mock() - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(MessageNotExistsError): @@ -899,21 +802,13 @@ class TestMessageServiceFeedback: feedback = MagicMock() feedback.to_dict.return_value = {"id": "fb-1"} - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.offset.return_value = mock_query - mock_query.all.return_value = [feedback] + mock_db.session.scalars.return_value.all.return_value = [feedback] # Act result = MessageService.get_all_messages_feedbacks(app_model=app, page=1, limit=10) # Assert assert result == [{"id": "fb-1"}] - mock_query.limit.assert_called_with(10) - mock_query.offset.assert_called_with(0) class TestMessageServiceSuggestedQuestions: @@ -1015,10 +910,7 @@ class TestMessageServiceSuggestedQuestions: app_model_config.suggested_questions_after_answer_dict = {"enabled": True} app_model_config.model_dict = {"provider": "openai", "name": "gpt-4"} - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = app_model_config + mock_db.session.scalar.return_value = app_model_config mock_llm_gen.generate_suggested_questions_after_answer.return_value = ["Q1?"] @@ -1029,7 +921,6 @@ class TestMessageServiceSuggestedQuestions: # Assert assert result == ["Q1?"] - mock_query.first.assert_called_once() mock_llm_gen.generate_suggested_questions_after_answer.assert_called_once() # Test 30: get_suggested_questions_after_answer - Disabled Error diff --git a/api/tests/unit_tests/services/test_ops_service.py b/api/tests/unit_tests/services/test_ops_service.py index ab7b473790a..7067e3b3dd4 100644 --- a/api/tests/unit_tests/services/test_ops_service.py +++ b/api/tests/unit_tests/services/test_ops_service.py @@ -12,28 +12,27 @@ class TestOpsService: @patch("services.ops_service.OpsTraceManager") def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act result = OpsService.get_tracing_app_config("app_id", "arize") # Assert assert result is None - mock_db.session.query.assert_called_with(TraceAppConfig) @patch("services.ops_service.db") @patch("services.ops_service.OpsTraceManager") def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): # Arrange trace_config = MagicMock(spec=TraceAppConfig) - mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, None] + mock_db.session.scalar.return_value = trace_config + mock_db.session.get.return_value = None # Act result = OpsService.get_tracing_app_config("app_id", "arize") # Assert assert result is None - assert mock_db.session.query.call_count == 2 @patch("services.ops_service.db") @patch("services.ops_service.OpsTraceManager") @@ -43,7 +42,8 @@ class TestOpsService: trace_config.tracing_config = None app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + mock_db.session.scalar.return_value = trace_config + mock_db.session.get.return_value = app # Act & Assert with pytest.raises(ValueError, match="Tracing config cannot be None."): @@ -72,7 +72,8 @@ class TestOpsService: trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}} app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + mock_db.session.scalar.return_value = trace_config + mock_db.session.get.return_value = app mock_ops_trace_manager.decrypt_tracing_config.return_value = {} mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} @@ -97,7 +98,8 @@ class TestOpsService: trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}} app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + mock_db.session.scalar.return_value = trace_config + mock_db.session.get.return_value = app mock_ops_trace_manager.decrypt_tracing_config.return_value = {} mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} @@ -118,7 +120,8 @@ class TestOpsService: trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}} app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + mock_db.session.scalar.return_value = trace_config + mock_db.session.get.return_value = app mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} @@ -139,7 +142,8 @@ class TestOpsService: trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}} app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + mock_db.session.scalar.return_value = trace_config + mock_db.session.get.return_value = app mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} @@ -189,7 +193,7 @@ class TestOpsService: mock_ops_trace_manager.check_trace_config_is_effective.return_value = True mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig) + mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig) # Act result = OpsService.create_tracing_app_config("app_id", provider, config) @@ -206,7 +210,8 @@ class TestOpsService: mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_db.session.scalar.return_value = None + mock_db.session.get.return_value = app mock_ops_trace_manager.encrypt_tracing_config.return_value = {} # Act @@ -223,7 +228,7 @@ class TestOpsService: # Arrange provider = TracingProviderEnum.ARIZE mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig) + mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig) # Act result = OpsService.create_tracing_app_config("app_id", provider, {}) @@ -237,7 +242,8 @@ class TestOpsService: # Arrange provider = TracingProviderEnum.ARIZE mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_db.session.query.return_value.where.return_value.first.side_effect = [None, None] + mock_db.session.scalar.return_value = None + mock_db.session.get.return_value = None # Act result = OpsService.create_tracing_app_config("app_id", provider, {}) @@ -253,7 +259,8 @@ class TestOpsService: mock_ops_trace_manager.check_trace_config_is_effective.return_value = True app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_db.session.scalar.return_value = None + mock_db.session.get.return_value = app mock_ops_trace_manager.encrypt_tracing_config.return_value = {} # Act @@ -274,7 +281,8 @@ class TestOpsService: mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url" app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_db.session.scalar.return_value = None + mock_db.session.get.return_value = app mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"} # Act @@ -297,7 +305,7 @@ class TestOpsService: def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): # Arrange provider = TracingProviderEnum.ARIZE - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act result = OpsService.update_tracing_app_config("app_id", provider, {}) @@ -311,7 +319,8 @@ class TestOpsService: # Arrange provider = TracingProviderEnum.ARIZE current_config = MagicMock(spec=TraceAppConfig) - mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, None] + mock_db.session.scalar.return_value = current_config + mock_db.session.get.return_value = None # Act result = OpsService.update_tracing_app_config("app_id", provider, {}) @@ -327,7 +336,8 @@ class TestOpsService: current_config = MagicMock(spec=TraceAppConfig) app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app] + mock_db.session.scalar.return_value = current_config + mock_db.session.get.return_value = app mock_ops_trace_manager.decrypt_tracing_config.return_value = {} mock_ops_trace_manager.check_trace_config_is_effective.return_value = False @@ -344,7 +354,8 @@ class TestOpsService: current_config.to_dict.return_value = {"some": "data"} app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app] + mock_db.session.scalar.return_value = current_config + mock_db.session.get.return_value = app mock_ops_trace_manager.decrypt_tracing_config.return_value = {} mock_ops_trace_manager.check_trace_config_is_effective.return_value = True @@ -358,7 +369,7 @@ class TestOpsService: @patch("services.ops_service.db") def test_delete_tracing_app_config_no_config(self, mock_db): # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act result = OpsService.delete_tracing_app_config("app_id", "arize") @@ -370,7 +381,7 @@ class TestOpsService: def test_delete_tracing_app_config_success(self, mock_db): # Arrange trace_config = MagicMock(spec=TraceAppConfig) - mock_db.session.query.return_value.where.return_value.first.return_value = trace_config + mock_db.session.scalar.return_value = trace_config # Act result = OpsService.delete_tracing_app_config("app_id", "arize") From 725f9e3dc4d38b04cfa6493108fc074fb84b8ab0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 2 Apr 2026 09:33:09 +0900 Subject: [PATCH 15/42] chore(deps): bump aiohttp from 3.13.3 to 3.13.4 in /api (#34425) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- api/uv.lock | 72 ++++++++++++++++++++++++++--------------------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/api/uv.lock b/api/uv.lock index 39c362eda03..9ec408d3809 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -60,7 +60,7 @@ wheels = [ [[package]] name = "aiohttp" -version = "3.13.3" +version = "3.13.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohappyeyeballs" }, @@ -71,42 +71,42 @@ dependencies = [ { name = "propcache" }, { name = "yarl" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/50/42/32cf8e7704ceb4481406eb87161349abb46a57fee3f008ba9cb610968646/aiohttp-3.13.3.tar.gz", hash = "sha256:a949eee43d3782f2daae4f4a2819b2cb9b0c5d3b7f7a927067cc84dafdbb9f88", size = 7844556, upload-time = "2026-01-03T17:33:05.204Z" } +sdist = { url = "https://files.pythonhosted.org/packages/45/4a/064321452809dae953c1ed6e017504e72551a26b6f5708a5a80e4bf556ff/aiohttp-3.13.4.tar.gz", hash = "sha256:d97a6d09c66087890c2ab5d49069e1e570583f7ac0314ecf98294c1b6aaebd38", size = 7859748, upload-time = "2026-03-28T17:19:40.6Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f1/4c/a164164834f03924d9a29dc3acd9e7ee58f95857e0b467f6d04298594ebb/aiohttp-3.13.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5b6073099fb654e0a068ae678b10feff95c5cae95bbfcbfa7af669d361a8aa6b", size = 746051, upload-time = "2026-01-03T17:29:43.287Z" }, - { url = "https://files.pythonhosted.org/packages/82/71/d5c31390d18d4f58115037c432b7e0348c60f6f53b727cad33172144a112/aiohttp-3.13.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1cb93e166e6c28716c8c6aeb5f99dfb6d5ccf482d29fe9bf9a794110e6d0ab64", size = 499234, upload-time = "2026-01-03T17:29:44.822Z" }, - { url = "https://files.pythonhosted.org/packages/0e/c9/741f8ac91e14b1d2e7100690425a5b2b919a87a5075406582991fb7de920/aiohttp-3.13.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:28e027cf2f6b641693a09f631759b4d9ce9165099d2b5d92af9bd4e197690eea", size = 494979, upload-time = "2026-01-03T17:29:46.405Z" }, - { url = "https://files.pythonhosted.org/packages/75/b5/31d4d2e802dfd59f74ed47eba48869c1c21552c586d5e81a9d0d5c2ad640/aiohttp-3.13.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3b61b7169ababd7802f9568ed96142616a9118dd2be0d1866e920e77ec8fa92a", size = 1748297, upload-time = "2026-01-03T17:29:48.083Z" }, - { url = "https://files.pythonhosted.org/packages/1a/3e/eefad0ad42959f226bb79664826883f2687d602a9ae2941a18e0484a74d3/aiohttp-3.13.3-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:80dd4c21b0f6237676449c6baaa1039abae86b91636b6c91a7f8e61c87f89540", size = 1707172, upload-time = "2026-01-03T17:29:49.648Z" }, - { url = "https://files.pythonhosted.org/packages/c5/3a/54a64299fac2891c346cdcf2aa6803f994a2e4beeaf2e5a09dcc54acc842/aiohttp-3.13.3-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:65d2ccb7eabee90ce0503c17716fc77226be026dcc3e65cce859a30db715025b", size = 1805405, upload-time = "2026-01-03T17:29:51.244Z" }, - { url = "https://files.pythonhosted.org/packages/6c/70/ddc1b7169cf64075e864f64595a14b147a895a868394a48f6a8031979038/aiohttp-3.13.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5b179331a481cb5529fca8b432d8d3c7001cb217513c94cd72d668d1248688a3", size = 1899449, upload-time = "2026-01-03T17:29:53.938Z" }, - { url = "https://files.pythonhosted.org/packages/a1/7e/6815aab7d3a56610891c76ef79095677b8b5be6646aaf00f69b221765021/aiohttp-3.13.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d4c940f02f49483b18b079d1c27ab948721852b281f8b015c058100e9421dd1", size = 1748444, upload-time = "2026-01-03T17:29:55.484Z" }, - { url = "https://files.pythonhosted.org/packages/6b/f2/073b145c4100da5511f457dc0f7558e99b2987cf72600d42b559db856fbc/aiohttp-3.13.3-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f9444f105664c4ce47a2a7171a2418bce5b7bae45fb610f4e2c36045d85911d3", size = 1606038, upload-time = "2026-01-03T17:29:57.179Z" }, - { url = "https://files.pythonhosted.org/packages/0a/c1/778d011920cae03ae01424ec202c513dc69243cf2db303965615b81deeea/aiohttp-3.13.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:694976222c711d1d00ba131904beb60534f93966562f64440d0c9d41b8cdb440", size = 1724156, upload-time = "2026-01-03T17:29:58.914Z" }, - { url = "https://files.pythonhosted.org/packages/0e/cb/3419eabf4ec1e9ec6f242c32b689248365a1cf621891f6f0386632525494/aiohttp-3.13.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:f33ed1a2bf1997a36661874b017f5c4b760f41266341af36febaf271d179f6d7", size = 1722340, upload-time = "2026-01-03T17:30:01.962Z" }, - { url = "https://files.pythonhosted.org/packages/7a/e5/76cf77bdbc435bf233c1f114edad39ed4177ccbfab7c329482b179cff4f4/aiohttp-3.13.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e636b3c5f61da31a92bf0d91da83e58fdfa96f178ba682f11d24f31944cdd28c", size = 1783041, upload-time = "2026-01-03T17:30:03.609Z" }, - { url = "https://files.pythonhosted.org/packages/9d/d4/dd1ca234c794fd29c057ce8c0566b8ef7fd6a51069de5f06fa84b9a1971c/aiohttp-3.13.3-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:5d2d94f1f5fcbe40838ac51a6ab5704a6f9ea42e72ceda48de5e6b898521da51", size = 1596024, upload-time = "2026-01-03T17:30:05.132Z" }, - { url = "https://files.pythonhosted.org/packages/55/58/4345b5f26661a6180afa686c473620c30a66afdf120ed3dd545bbc809e85/aiohttp-3.13.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:2be0e9ccf23e8a94f6f0650ce06042cefc6ac703d0d7ab6c7a917289f2539ad4", size = 1804590, upload-time = "2026-01-03T17:30:07.135Z" }, - { url = "https://files.pythonhosted.org/packages/7b/06/05950619af6c2df7e0a431d889ba2813c9f0129cec76f663e547a5ad56f2/aiohttp-3.13.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9af5e68ee47d6534d36791bbe9b646d2a7c7deb6fc24d7943628edfbb3581f29", size = 1740355, upload-time = "2026-01-03T17:30:09.083Z" }, - { url = "https://files.pythonhosted.org/packages/3e/80/958f16de79ba0422d7c1e284b2abd0c84bc03394fbe631d0a39ffa10e1eb/aiohttp-3.13.3-cp311-cp311-win32.whl", hash = "sha256:a2212ad43c0833a873d0fb3c63fa1bacedd4cf6af2fee62bf4b739ceec3ab239", size = 433701, upload-time = "2026-01-03T17:30:10.869Z" }, - { url = "https://files.pythonhosted.org/packages/dc/f2/27cdf04c9851712d6c1b99df6821a6623c3c9e55956d4b1e318c337b5a48/aiohttp-3.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:642f752c3eb117b105acbd87e2c143de710987e09860d674e068c4c2c441034f", size = 457678, upload-time = "2026-01-03T17:30:12.719Z" }, - { url = "https://files.pythonhosted.org/packages/a0/be/4fc11f202955a69e0db803a12a062b8379c970c7c84f4882b6da17337cc1/aiohttp-3.13.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:b903a4dfee7d347e2d87697d0713be59e0b87925be030c9178c5faa58ea58d5c", size = 739732, upload-time = "2026-01-03T17:30:14.23Z" }, - { url = "https://files.pythonhosted.org/packages/97/2c/621d5b851f94fa0bb7430d6089b3aa970a9d9b75196bc93bb624b0db237a/aiohttp-3.13.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a45530014d7a1e09f4a55f4f43097ba0fd155089372e105e4bff4ca76cb1b168", size = 494293, upload-time = "2026-01-03T17:30:15.96Z" }, - { url = "https://files.pythonhosted.org/packages/5d/43/4be01406b78e1be8320bb8316dc9c42dbab553d281c40364e0f862d5661c/aiohttp-3.13.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:27234ef6d85c914f9efeb77ff616dbf4ad2380be0cda40b4db086ffc7ddd1b7d", size = 493533, upload-time = "2026-01-03T17:30:17.431Z" }, - { url = "https://files.pythonhosted.org/packages/8d/a8/5a35dc56a06a2c90d4742cbf35294396907027f80eea696637945a106f25/aiohttp-3.13.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d32764c6c9aafb7fb55366a224756387cd50bfa720f32b88e0e6fa45b27dcf29", size = 1737839, upload-time = "2026-01-03T17:30:19.422Z" }, - { url = "https://files.pythonhosted.org/packages/bf/62/4b9eeb331da56530bf2e198a297e5303e1c1ebdceeb00fe9b568a65c5a0c/aiohttp-3.13.3-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b1a6102b4d3ebc07dad44fbf07b45bb600300f15b552ddf1851b5390202ea2e3", size = 1703932, upload-time = "2026-01-03T17:30:21.756Z" }, - { url = "https://files.pythonhosted.org/packages/7c/f6/af16887b5d419e6a367095994c0b1332d154f647e7dc2bd50e61876e8e3d/aiohttp-3.13.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c014c7ea7fb775dd015b2d3137378b7be0249a448a1612268b5a90c2d81de04d", size = 1771906, upload-time = "2026-01-03T17:30:23.932Z" }, - { url = "https://files.pythonhosted.org/packages/ce/83/397c634b1bcc24292fa1e0c7822800f9f6569e32934bdeef09dae7992dfb/aiohttp-3.13.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2b8d8ddba8f95ba17582226f80e2de99c7a7948e66490ef8d947e272a93e9463", size = 1871020, upload-time = "2026-01-03T17:30:26Z" }, - { url = "https://files.pythonhosted.org/packages/86/f6/a62cbbf13f0ac80a70f71b1672feba90fdb21fd7abd8dbf25c0105fb6fa3/aiohttp-3.13.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ae8dd55c8e6c4257eae3a20fd2c8f41edaea5992ed67156642493b8daf3cecc", size = 1755181, upload-time = "2026-01-03T17:30:27.554Z" }, - { url = "https://files.pythonhosted.org/packages/0a/87/20a35ad487efdd3fba93d5843efdfaa62d2f1479eaafa7453398a44faf13/aiohttp-3.13.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:01ad2529d4b5035578f5081606a465f3b814c542882804e2e8cda61adf5c71bf", size = 1561794, upload-time = "2026-01-03T17:30:29.254Z" }, - { url = "https://files.pythonhosted.org/packages/de/95/8fd69a66682012f6716e1bc09ef8a1a2a91922c5725cb904689f112309c4/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bb4f7475e359992b580559e008c598091c45b5088f28614e855e42d39c2f1033", size = 1697900, upload-time = "2026-01-03T17:30:31.033Z" }, - { url = "https://files.pythonhosted.org/packages/e5/66/7b94b3b5ba70e955ff597672dad1691333080e37f50280178967aff68657/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:c19b90316ad3b24c69cd78d5c9b4f3aa4497643685901185b65166293d36a00f", size = 1728239, upload-time = "2026-01-03T17:30:32.703Z" }, - { url = "https://files.pythonhosted.org/packages/47/71/6f72f77f9f7d74719692ab65a2a0252584bf8d5f301e2ecb4c0da734530a/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:96d604498a7c782cb15a51c406acaea70d8c027ee6b90c569baa6e7b93073679", size = 1740527, upload-time = "2026-01-03T17:30:34.695Z" }, - { url = "https://files.pythonhosted.org/packages/fa/b4/75ec16cbbd5c01bdaf4a05b19e103e78d7ce1ef7c80867eb0ace42ff4488/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:084911a532763e9d3dd95adf78a78f4096cd5f58cdc18e6fdbc1b58417a45423", size = 1554489, upload-time = "2026-01-03T17:30:36.864Z" }, - { url = "https://files.pythonhosted.org/packages/52/8f/bc518c0eea29f8406dcf7ed1f96c9b48e3bc3995a96159b3fc11f9e08321/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:7a4a94eb787e606d0a09404b9c38c113d3b099d508021faa615d70a0131907ce", size = 1767852, upload-time = "2026-01-03T17:30:39.433Z" }, - { url = "https://files.pythonhosted.org/packages/9d/f2/a07a75173124f31f11ea6f863dc44e6f09afe2bca45dd4e64979490deab1/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:87797e645d9d8e222e04160ee32aa06bc5c163e8499f24db719e7852ec23093a", size = 1722379, upload-time = "2026-01-03T17:30:41.081Z" }, - { url = "https://files.pythonhosted.org/packages/3c/4a/1a3fee7c21350cac78e5c5cef711bac1b94feca07399f3d406972e2d8fcd/aiohttp-3.13.3-cp312-cp312-win32.whl", hash = "sha256:b04be762396457bef43f3597c991e192ee7da460a4953d7e647ee4b1c28e7046", size = 428253, upload-time = "2026-01-03T17:30:42.644Z" }, - { url = "https://files.pythonhosted.org/packages/d9/b7/76175c7cb4eb73d91ad63c34e29fc4f77c9386bba4a65b53ba8e05ee3c39/aiohttp-3.13.3-cp312-cp312-win_amd64.whl", hash = "sha256:e3531d63d3bdfa7e3ac5e9b27b2dd7ec9df3206a98e0b3445fa906f233264c57", size = 455407, upload-time = "2026-01-03T17:30:44.195Z" }, + { url = "https://files.pythonhosted.org/packages/d4/7e/cb94129302d78c46662b47f9897d642fd0b33bdfef4b73b20c6ced35aa4c/aiohttp-3.13.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8ea0c64d1bcbf201b285c2246c51a0c035ba3bbd306640007bc5844a3b4658c1", size = 760027, upload-time = "2026-03-28T17:15:33.022Z" }, + { url = "https://files.pythonhosted.org/packages/5e/cd/2db3c9397c3bd24216b203dd739945b04f8b87bb036c640da7ddb63c75ef/aiohttp-3.13.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6f742e1fa45c0ed522b00ede565e18f97e4cf8d1883a712ac42d0339dfb0cce7", size = 508325, upload-time = "2026-03-28T17:15:34.714Z" }, + { url = "https://files.pythonhosted.org/packages/36/a3/d28b2722ec13107f2e37a86b8a169897308bab6a3b9e071ecead9d67bd9b/aiohttp-3.13.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6dcfb50ee25b3b7a1222a9123be1f9f89e56e67636b561441f0b304e25aaef8f", size = 502402, upload-time = "2026-03-28T17:15:36.409Z" }, + { url = "https://files.pythonhosted.org/packages/fa/d6/acd47b5f17c4430e555590990a4746efbcb2079909bb865516892bf85f37/aiohttp-3.13.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3262386c4ff370849863ea93b9ea60fd59c6cf56bf8f93beac625cf4d677c04d", size = 1771224, upload-time = "2026-03-28T17:15:38.223Z" }, + { url = "https://files.pythonhosted.org/packages/98/af/af6e20113ba6a48fd1cd9e5832c4851e7613ef50c7619acdaee6ec5f1aff/aiohttp-3.13.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:473bb5aa4218dd254e9ae4834f20e31f5a0083064ac0136a01a62ddbae2eaa42", size = 1731530, upload-time = "2026-03-28T17:15:39.988Z" }, + { url = "https://files.pythonhosted.org/packages/81/16/78a2f5d9c124ad05d5ce59a9af94214b6466c3491a25fb70760e98e9f762/aiohttp-3.13.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e56423766399b4c77b965f6aaab6c9546617b8994a956821cc507d00b91d978c", size = 1827925, upload-time = "2026-03-28T17:15:41.944Z" }, + { url = "https://files.pythonhosted.org/packages/2a/1f/79acf0974ced805e0e70027389fccbb7d728e6f30fcac725fb1071e63075/aiohttp-3.13.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8af249343fafd5ad90366a16d230fc265cf1149f26075dc9fe93cfd7c7173942", size = 1923579, upload-time = "2026-03-28T17:15:44.071Z" }, + { url = "https://files.pythonhosted.org/packages/af/53/29f9e2054ea6900413f3b4c3eb9d8331f60678ec855f13ba8714c47fd48d/aiohttp-3.13.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bc0a5cf4f10ef5a2c94fdde488734b582a3a7a000b131263e27c9295bd682d9", size = 1767655, upload-time = "2026-03-28T17:15:45.911Z" }, + { url = "https://files.pythonhosted.org/packages/f3/57/462fe1d3da08109ba4aa8590e7aed57c059af2a7e80ec21f4bac5cfe1094/aiohttp-3.13.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:5c7ff1028e3c9fc5123a865ce17df1cb6424d180c503b8517afbe89aa566e6be", size = 1630439, upload-time = "2026-03-28T17:15:48.11Z" }, + { url = "https://files.pythonhosted.org/packages/d7/4b/4813344aacdb8127263e3eec343d24e973421143826364fa9fc847f6283f/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ba5cf98b5dcb9bddd857da6713a503fa6d341043258ca823f0f5ab7ab4a94ee8", size = 1745557, upload-time = "2026-03-28T17:15:50.13Z" }, + { url = "https://files.pythonhosted.org/packages/d4/01/1ef1adae1454341ec50a789f03cfafe4c4ac9c003f6a64515ecd32fe4210/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:d85965d3ba21ee4999e83e992fecb86c4614d6920e40705501c0a1f80a583c12", size = 1741796, upload-time = "2026-03-28T17:15:52.351Z" }, + { url = "https://files.pythonhosted.org/packages/22/04/8cdd99af988d2aa6922714d957d21383c559835cbd43fbf5a47ddf2e0f05/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:49f0b18a9b05d79f6f37ddd567695943fcefb834ef480f17a4211987302b2dc7", size = 1805312, upload-time = "2026-03-28T17:15:54.407Z" }, + { url = "https://files.pythonhosted.org/packages/fb/7f/b48d5577338d4b25bbdbae35c75dbfd0493cb8886dc586fbfb2e90862239/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7f78cb080c86fbf765920e5f1ef35af3f24ec4314d6675d0a21eaf41f6f2679c", size = 1621751, upload-time = "2026-03-28T17:15:56.564Z" }, + { url = "https://files.pythonhosted.org/packages/bc/89/4eecad8c1858e6d0893c05929e22343e0ebe3aec29a8a399c65c3cc38311/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:67a3ec705534a614b68bbf1c70efa777a21c3da3895d1c44510a41f5a7ae0453", size = 1826073, upload-time = "2026-03-28T17:15:58.489Z" }, + { url = "https://files.pythonhosted.org/packages/f5/5c/9dc8293ed31b46c39c9c513ac7ca152b3c3d38e0ea111a530ad12001b827/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d6630ec917e85c5356b2295744c8a97d40f007f96a1c76bf1928dc2e27465393", size = 1760083, upload-time = "2026-03-28T17:16:00.677Z" }, + { url = "https://files.pythonhosted.org/packages/1e/19/8bbf6a4994205d96831f97b7d21a0feed120136e6267b5b22d229c6dc4dc/aiohttp-3.13.4-cp311-cp311-win32.whl", hash = "sha256:54049021bc626f53a5394c29e8c444f726ee5a14b6e89e0ad118315b1f90f5e3", size = 439690, upload-time = "2026-03-28T17:16:02.902Z" }, + { url = "https://files.pythonhosted.org/packages/0c/f5/ac409ecd1007528d15c3e8c3a57d34f334c70d76cfb7128a28cffdebd4c1/aiohttp-3.13.4-cp311-cp311-win_amd64.whl", hash = "sha256:c033f2bc964156030772d31cbf7e5defea181238ce1f87b9455b786de7d30145", size = 463824, upload-time = "2026-03-28T17:16:05.058Z" }, + { url = "https://files.pythonhosted.org/packages/1e/bd/ede278648914cabbabfdf95e436679b5d4156e417896a9b9f4587169e376/aiohttp-3.13.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ee62d4471ce86b108b19c3364db4b91180d13fe3510144872d6bad5401957360", size = 752158, upload-time = "2026-03-28T17:16:06.901Z" }, + { url = "https://files.pythonhosted.org/packages/90/de/581c053253c07b480b03785196ca5335e3c606a37dc73e95f6527f1591fe/aiohttp-3.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c0fd8f41b54b58636402eb493afd512c23580456f022c1ba2db0f810c959ed0d", size = 501037, upload-time = "2026-03-28T17:16:08.82Z" }, + { url = "https://files.pythonhosted.org/packages/fa/f9/a5ede193c08f13cc42c0a5b50d1e246ecee9115e4cf6e900d8dbd8fd6acb/aiohttp-3.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4baa48ce49efd82d6b1a0be12d6a36b35e5594d1dd42f8bfba96ea9f8678b88c", size = 501556, upload-time = "2026-03-28T17:16:10.63Z" }, + { url = "https://files.pythonhosted.org/packages/d6/10/88ff67cd48a6ec36335b63a640abe86135791544863e0cfe1f065d6cef7a/aiohttp-3.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d738ebab9f71ee652d9dbd0211057690022201b11197f9a7324fd4dba128aa97", size = 1757314, upload-time = "2026-03-28T17:16:12.498Z" }, + { url = "https://files.pythonhosted.org/packages/8b/15/fdb90a5cf5a1f52845c276e76298c75fbbcc0ac2b4a86551906d54529965/aiohttp-3.13.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0ce692c3468fa831af7dceed52edf51ac348cebfc8d3feb935927b63bd3e8576", size = 1731819, upload-time = "2026-03-28T17:16:14.558Z" }, + { url = "https://files.pythonhosted.org/packages/ec/df/28146785a007f7820416be05d4f28cc207493efd1e8c6c1068e9bdc29198/aiohttp-3.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8e08abcfe752a454d2cb89ff0c08f2d1ecd057ae3e8cc6d84638de853530ebab", size = 1793279, upload-time = "2026-03-28T17:16:16.594Z" }, + { url = "https://files.pythonhosted.org/packages/10/47/689c743abf62ea7a77774d5722f220e2c912a77d65d368b884d9779ef41b/aiohttp-3.13.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5977f701b3fff36367a11087f30ea73c212e686d41cd363c50c022d48b011d8d", size = 1891082, upload-time = "2026-03-28T17:16:18.71Z" }, + { url = "https://files.pythonhosted.org/packages/b0/b6/f7f4f318c7e58c23b761c9b13b9a3c9b394e0f9d5d76fbc6622fa98509f6/aiohttp-3.13.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:54203e10405c06f8b6020bd1e076ae0fe6c194adcee12a5a78af3ffa3c57025e", size = 1773938, upload-time = "2026-03-28T17:16:21.125Z" }, + { url = "https://files.pythonhosted.org/packages/aa/06/f207cb3121852c989586a6fc16ff854c4fcc8651b86c5d3bd1fc83057650/aiohttp-3.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:358a6af0145bc4dda037f13167bef3cce54b132087acc4c295c739d05d16b1c3", size = 1579548, upload-time = "2026-03-28T17:16:23.588Z" }, + { url = "https://files.pythonhosted.org/packages/6c/58/e1289661a32161e24c1fe479711d783067210d266842523752869cc1d9c2/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:898ea1850656d7d61832ef06aa9846ab3ddb1621b74f46de78fbc5e1a586ba83", size = 1714669, upload-time = "2026-03-28T17:16:25.713Z" }, + { url = "https://files.pythonhosted.org/packages/96/0a/3e86d039438a74a86e6a948a9119b22540bae037d6ba317a042ae3c22711/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7bc30cceb710cf6a44e9617e43eebb6e3e43ad855a34da7b4b6a73537d8a6763", size = 1754175, upload-time = "2026-03-28T17:16:28.18Z" }, + { url = "https://files.pythonhosted.org/packages/f4/30/e717fc5df83133ba467a560b6d8ef20197037b4bb5d7075b90037de1018e/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4a31c0c587a8a038f19a4c7e60654a6c899c9de9174593a13e7cc6e15ff271f9", size = 1762049, upload-time = "2026-03-28T17:16:30.941Z" }, + { url = "https://files.pythonhosted.org/packages/e4/28/8f7a2d4492e336e40005151bdd94baf344880a4707573378579f833a64c1/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:2062f675f3fe6e06d6113eb74a157fb9df58953ffed0cdb4182554b116545758", size = 1570861, upload-time = "2026-03-28T17:16:32.953Z" }, + { url = "https://files.pythonhosted.org/packages/78/45/12e1a3d0645968b1c38de4b23fdf270b8637735ea057d4f84482ff918ad9/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3d1ba8afb847ff80626d5e408c1fdc99f942acc877d0702fe137015903a220a9", size = 1790003, upload-time = "2026-03-28T17:16:35.468Z" }, + { url = "https://files.pythonhosted.org/packages/eb/0f/60374e18d590de16dcb39d6ff62f39c096c1b958e6f37727b5870026ea30/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b08149419994cdd4d5eecf7fd4bc5986b5a9380285bcd01ab4c0d6bfca47b79d", size = 1737289, upload-time = "2026-03-28T17:16:38.187Z" }, + { url = "https://files.pythonhosted.org/packages/02/bf/535e58d886cfbc40a8b0013c974afad24ef7632d645bca0b678b70033a60/aiohttp-3.13.4-cp312-cp312-win32.whl", hash = "sha256:fc432f6a2c4f720180959bc19aa37259651c1a4ed8af8afc84dd41c60f15f791", size = 434185, upload-time = "2026-03-28T17:16:40.735Z" }, + { url = "https://files.pythonhosted.org/packages/1e/1a/d92e3325134ebfff6f4069f270d3aac770d63320bd1fcd0eca023e74d9a8/aiohttp-3.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:6148c9ae97a3e8bff9a1fc9c757fa164116f86c100468339730e717590a3fb77", size = 461285, upload-time = "2026-03-28T17:16:42.713Z" }, ] [[package]] From 2d29345f2631c33b6ca55d00af4ac668378ff691 Mon Sep 17 00:00:00 2001 From: YBoy Date: Thu, 2 Apr 2026 03:47:08 +0200 Subject: [PATCH 16/42] =?UTF-8?q?refactor(api):=20type=20OpsTraceProviderC?= =?UTF-8?q?onfigMap=20with=20TracingProviderCon=E2=80=A6=20(#34424)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/core/ops/ops_trace_manager.py | 24 ++++++++++++++++-------- api/services/ops_service.py | 6 ++---- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index c689a86614c..aa39e6b6819 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -19,6 +19,7 @@ from typing_extensions import TypedDict from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token from core.ops.entities.config_entity import ( OPS_FILE_PATH, + BaseTracingConfig, TracingProviderEnum, ) from core.ops.entities.trace_entity import ( @@ -195,8 +196,15 @@ def _lookup_llm_credential_info( return None, "" -class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): - def __getitem__(self, provider: str) -> dict[str, Any]: +class TracingProviderConfigEntry(TypedDict): + config_class: type[BaseTracingConfig] + secret_keys: list[str] + other_keys: list[str] + trace_instance: type[Any] + + +class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]): + def __getitem__(self, provider: str) -> TracingProviderConfigEntry: match provider: case TracingProviderEnum.LANGFUSE: from core.ops.entities.config_entity import LangfuseConfig @@ -585,8 +593,8 @@ class OpsTraceManager: provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["trace_instance"], ) - tracing_config = config_type(**tracing_config) - return trace_instance(tracing_config).api_check() + config = config_type(**tracing_config) + return trace_instance(config).api_check() @staticmethod def get_trace_config_project_key(tracing_config: dict, tracing_provider: str): @@ -600,8 +608,8 @@ class OpsTraceManager: provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["trace_instance"], ) - tracing_config = config_type(**tracing_config) - return trace_instance(tracing_config).get_project_key() + config = config_type(**tracing_config) + return trace_instance(config).get_project_key() @staticmethod def get_trace_config_project_url(tracing_config: dict, tracing_provider: str): @@ -615,8 +623,8 @@ class OpsTraceManager: provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["trace_instance"], ) - tracing_config = config_type(**tracing_config) - return trace_instance(tracing_config).get_project_url() + config = config_type(**tracing_config) + return trace_instance(config).get_project_url() class TraceTask: diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 2a64088dd66..0db3d3efec4 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,9 +1,7 @@ -from typing import Any - from sqlalchemy import select from core.ops.entities.config_entity import BaseTracingConfig -from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map +from core.ops.ops_trace_manager import OpsTraceManager, TracingProviderConfigEntry, provider_config_map from extensions.ext_database import db from models.model import App, TraceAppConfig @@ -150,7 +148,7 @@ class OpsService: except KeyError: return {"error": f"Invalid tracing provider: {tracing_provider}"} - provider_config: dict[str, Any] = provider_config_map[tracing_provider] + provider_config: TracingProviderConfigEntry = provider_config_map[tracing_provider] config_class: type[BaseTracingConfig] = provider_config["config_class"] other_keys: list[str] = provider_config["other_keys"] From f9d9ad7a3817653f5eeb84880a07175242ebdfbe Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Thu, 2 Apr 2026 12:16:50 +0800 Subject: [PATCH 17/42] refactor(web): migrate remaining toast usage (#34433) --- .../references/runtime-rules.md | 7 +- web/.storybook/preview.tsx | 7 +- .../apps/app-card-operations-flow.test.tsx | 40 +- .../datasets/create-dataset-flow.test.tsx | 8 +- .../dsl-export-import-flow.test.ts | 17 +- .../tools/tool-provider-detail-flow.test.tsx | 8 +- .../[appId]/overview/card-view.tsx | 9 +- .../[appId]/overview/tracing/panel.tsx | 7 +- .../tracing/provider-config-modal.tsx | 17 +- .../account-page/AvatarWithEdit.tsx | 80 ++- .../account-page/email-change-modal.tsx | 382 +++++++-------- .../(commonLayout)/account-page/index.tsx | 216 ++++----- .../delete-account/components/feed-back.tsx | 4 +- .../__tests__/use-app-info-actions.spec.ts | 43 +- .../app-info/use-app-info-actions.ts | 33 +- .../app-sidebar/dataset-info/dropdown.tsx | 8 +- .../add-annotation-modal/index.spec.tsx | 8 +- .../annotation/add-annotation-modal/index.tsx | 7 +- .../csv-uploader.spec.tsx | 28 +- .../csv-uploader.tsx | 6 +- .../batch-add-annotation-modal/index.spec.tsx | 20 +- .../batch-add-annotation-modal/index.tsx | 11 +- .../edit-annotation-modal/index.spec.tsx | 37 +- .../edit-annotation-modal/index.tsx | 12 +- .../components/app/annotation/index.spec.tsx | 24 +- web/app/components/app/annotation/index.tsx | 22 +- .../access-control.spec.tsx | 6 +- .../app/app-access-control/index.tsx | 4 +- .../components/app/app-publisher/index.tsx | 4 +- .../app/app-publisher/version-info-modal.tsx | 12 +- .../config-prompt/advanced-prompt-input.tsx | 24 +- .../config-prompt/simple-prompt-input.tsx | 24 +- .../config-var/config-modal/index.tsx | 26 +- .../configuration/config-var/index.spec.tsx | 18 +- .../app/configuration/config-var/index.tsx | 9 +- .../config/agent/prompt-editor.tsx | 8 +- .../config/automatic/get-automatic-res.tsx | 25 +- .../configuration/config/automatic/result.tsx | 4 +- .../code-generator/get-code-generator-res.tsx | 18 +- .../params-config/config-content.spec.tsx | 18 +- .../params-config/config-content.tsx | 6 +- .../settings-modal/index.spec.tsx | 41 +- .../dataset-config/settings-modal/index.tsx | 11 +- .../debug-with-single-model/index.spec.tsx | 7 - .../app/configuration/debug/index.spec.tsx | 54 ++- .../app/configuration/debug/index.tsx | 50 +- .../components/app/configuration/index.tsx | 19 +- .../tools/external-data-tool-modal.tsx | 248 +++++----- .../app/configuration/tools/index.tsx | 52 +- .../app/create-app-modal/index.spec.tsx | 23 +- .../components/app/create-app-modal/index.tsx | 17 +- .../app/create-from-dsl-modal/index.tsx | 24 +- .../app/create-from-dsl-modal/uploader.tsx | 6 +- .../app/duplicate-modal/index.spec.tsx | 6 +- .../components/app/duplicate-modal/index.tsx | 4 +- web/app/components/app/log/list.tsx | 27 +- .../app/overview/settings/index.spec.tsx | 34 +- .../app/overview/settings/index.tsx | 9 +- .../app/switch-app-modal/index.spec.tsx | 40 +- .../components/app/switch-app-modal/index.tsx | 8 +- .../app/text-generate/item/index.tsx | 6 +- .../text-generate/saved-items/index.spec.tsx | 8 +- .../app/text-generate/saved-items/index.tsx | 4 +- .../apps/__tests__/app-card.spec.tsx | 38 +- web/app/components/apps/app-card.tsx | 41 +- .../agent-log-modal/__tests__/detail.spec.tsx | 35 +- .../agent-log-modal/__tests__/index.spec.tsx | 42 +- .../base/agent-log-modal/detail.tsx | 66 +-- .../base/agent-log-modal/index.stories.tsx | 7 +- .../base/audio-btn/__tests__/audio.spec.ts | 6 +- web/app/components/base/audio-btn/audio.ts | 20 +- .../base/audio-gallery/AudioPlayer.tsx | 88 +--- .../__tests__/AudioPlayer.spec.tsx | 20 +- .../base/block-input/__tests__/index.spec.tsx | 6 +- web/app/components/base/block-input/index.tsx | 47 +- .../__tests__/hooks.spec.tsx | 5 +- .../base/chat/chat-with-history/hooks.tsx | 166 ++----- .../check-input-forms-hooks.spec.tsx | 12 +- .../base/chat/chat/__tests__/hooks.spec.tsx | 10 +- .../chat/chat/__tests__/question.spec.tsx | 4 +- .../chat/answer/__tests__/operation.spec.tsx | 2 +- .../base/chat/chat/answer/operation.tsx | 4 +- .../chat-input-area/__tests__/index.spec.tsx | 12 +- .../base/chat/chat/chat-input-area/index.tsx | 164 +------ .../base/chat/chat/check-input-forms-hooks.ts | 18 +- web/app/components/base/chat/chat/hooks.ts | 6 +- .../components/base/chat/chat/question.tsx | 119 +---- .../embedded-chatbot/__tests__/hooks.spec.tsx | 5 +- .../base/chat/embedded-chatbot/hooks.tsx | 115 +---- .../inputs-form/__tests__/content.spec.tsx | 4 +- .../__tests__/annotation-ctrl-button.spec.tsx | 8 +- .../__tests__/config-param-modal.spec.tsx | 13 +- .../annotation-ctrl-button.tsx | 32 +- .../annotation-reply/config-param-modal.tsx | 47 +- .../moderation-setting-modal.spec.tsx | 6 +- .../moderation/moderation-setting-modal.tsx | 18 +- .../file-uploader/__tests__/hooks.spec.ts | 9 +- .../index.stories.tsx | 7 +- .../index.stories.tsx | 7 +- .../components/base/file-uploader/hooks.ts | 77 ++- .../__tests__/use-check-validated.spec.ts | 14 +- .../base/form/hooks/use-check-validated.ts | 17 +- .../image-uploader/__tests__/hooks.spec.ts | 6 +- .../__tests__/image-preview.spec.tsx | 8 +- .../components/base/image-uploader/hooks.ts | 112 ++--- .../base/image-uploader/image-preview.tsx | 24 +- .../base/tag-input/__tests__/index.spec.tsx | 14 +- .../base/tag-input/__tests__/interop.spec.tsx | 32 +- web/app/components/base/tag-input/index.tsx | 112 ++--- .../tag-management/__tests__/index.spec.tsx | 28 +- .../tag-management/__tests__/panel.spec.tsx | 42 +- .../__tests__/selector.spec.tsx | 31 +- .../__tests__/tag-item-editor.spec.tsx | 71 +-- .../base/tag-management/index.stories.tsx | 7 +- .../components/base/tag-management/index.tsx | 43 +- .../components/base/tag-management/panel.tsx | 85 +--- .../base/tag-management/tag-item-editor.tsx | 45 +- .../text-generation/__tests__/hooks.spec.ts | 14 +- .../components/base/text-generation/hooks.ts | 57 +-- .../base/toast/__tests__/index.spec.tsx | 349 -------------- web/app/components/base/toast/context.ts | 33 -- .../components/base/toast/index.stories.tsx | 105 ---- web/app/components/base/toast/index.tsx | 173 ------- .../components/base/toast/style.module.css | 44 -- .../custom-page/__tests__/index.spec.tsx | 22 +- .../__tests__/use-web-app-brand.spec.tsx | 22 +- .../hooks/use-web-app-brand.ts | 32 +- .../common/retrieval-param-config/index.tsx | 2 +- .../__tests__/index.spec.tsx | 43 +- .../__tests__/uploader.spec.tsx | 24 +- .../hooks/__tests__/use-dsl-import.spec.tsx | 39 +- .../hooks/use-dsl-import.ts | 67 +-- .../create-from-dsl-modal/uploader.tsx | 44 +- .../__tests__/index.spec.tsx | 28 +- .../empty-dataset-creation-modal/index.tsx | 26 +- .../hooks/__tests__/use-file-upload.spec.tsx | 58 ++- .../file-uploader/hooks/use-file-upload.ts | 18 +- .../__tests__/use-document-creation.spec.ts | 8 +- .../step-two/hooks/use-document-creation.ts | 107 +--- .../datasets/create/step-two/index.tsx | 4 +- .../create/website/firecrawl/index.tsx | 67 +-- .../create/website/jina-reader/index.tsx | 64 +-- .../create/website/watercrawl/index.tsx | 78 +-- .../components/__tests__/operations.spec.tsx | 39 +- .../documents/components/operations.tsx | 121 +---- .../__tests__/use-local-file-upload.spec.tsx | 57 +-- .../documents/detail/__tests__/index.spec.tsx | 2 +- .../__tests__/csv-uploader.spec.tsx | 42 +- .../detail/batch-modal/csv-uploader.tsx | 55 +-- .../detail/completed/__tests__/index.spec.tsx | 17 +- .../__tests__/use-child-segment-data.spec.ts | 7 +- .../__tests__/use-segment-list-data.spec.ts | 7 +- .../completed/hooks/use-child-segment-data.ts | 160 ++---- .../completed/hooks/use-segment-list-data.ts | 120 +---- .../detail/embedding/__tests__/index.spec.tsx | 33 +- .../documents/detail/embedding/index.tsx | 74 +-- .../datasets/documents/detail/index.tsx | 8 +- .../detail/metadata/__tests__/index.spec.tsx | 27 +- .../__tests__/use-metadata-state.spec.ts | 23 +- .../metadata/hooks/use-metadata-state.ts | 29 +- .../status-item/__tests__/index.spec.tsx | 27 +- .../datasets/documents/status-item/index.tsx | 70 +-- .../__tests__/index.spec.tsx | 14 +- .../external-api/external-api-modal/index.tsx | 85 +--- .../__tests__/modify-retrieval-modal.spec.tsx | 21 +- .../hit-testing/modify-retrieval-modal.tsx | 58 +-- .../__tests__/modal.spec.tsx | 8 +- .../metadata/edit-metadata-batch/modal.tsx | 81 +--- .../use-batch-edit-document-metadata.spec.ts | 7 +- .../use-edit-dataset-metadata.spec.ts | 8 +- .../__tests__/use-metadata-document.spec.ts | 8 +- .../hooks/use-batch-edit-document-metadata.ts | 38 +- .../hooks/use-edit-dataset-metadata.ts | 38 +- .../metadata/hooks/use-metadata-document.ts | 41 +- .../dataset-metadata-drawer.spec.tsx | 8 +- .../dataset-metadata-drawer.tsx | 17 +- .../rename-modal/__tests__/index.spec.tsx | 68 ++- .../datasets/rename-modal/index.tsx | 75 +-- .../create-app-modal/__tests__/index.spec.tsx | 2 - .../explore/create-app-modal/index.tsx | 4 +- .../__tests__/compliance.spec.tsx | 18 +- .../header/account-dropdown/compliance.tsx | 97 +--- .../__tests__/index.spec.tsx | 22 +- .../workplace-selector/index.tsx | 99 ++-- .../__tests__/modal.spec.tsx | 35 +- .../api-based-extension-page/modal.tsx | 79 +-- .../language-page/__tests__/index.spec.tsx | 7 +- .../account-setting/language-page/index.tsx | 37 +- .../__tests__/dialog.spec.tsx | 5 +- .../__tests__/index.spec.tsx | 22 +- .../edit-workspace-modal/index.tsx | 55 +-- .../operation/__tests__/index.spec.tsx | 5 +- .../members-page/operation/index.tsx | 58 +-- .../__tests__/index.spec.tsx | 22 +- .../transfer-ownership-modal/index.tsx | 146 +----- .../hooks/__tests__/use-auth.spec.tsx | 19 +- .../model-auth/hooks/use-auth.ts | 98 ++-- .../__tests__/credential-panel.spec.tsx | 8 +- .../model-load-balancing-modal.spec.tsx | 21 +- .../use-change-provider-priority.spec.ts | 8 +- .../use-activate-credential.spec.tsx | 18 +- .../use-activate-credential.ts | 35 +- .../model-load-balancing-modal.tsx | 269 +++-------- .../use-change-provider-priority.ts | 44 +- .../plugin-page/SerpapiPlugin.tsx | 16 +- .../__tests__/SerpapiPlugin.spec.tsx | 27 +- .../plugin-page/__tests__/index.spec.tsx | 8 +- .../__tests__/authorized-in-node.spec.tsx | 4 - .../__tests__/plugin-auth-in-agent.spec.tsx | 4 - .../__tests__/api-key-modal.spec.tsx | 19 +- .../__tests__/authorize-components.spec.tsx | 16 +- .../__tests__/oauth-client-settings.spec.tsx | 19 +- .../plugin-auth/authorize/api-key-modal.tsx | 10 +- .../authorize/oauth-client-settings.tsx | 17 +- .../authorized/__tests__/index.spec.tsx | 40 +- .../plugins/plugin-auth/authorized/index.tsx | 24 +- .../__tests__/use-plugin-auth-action.spec.ts | 28 +- .../hooks/use-plugin-auth-action.ts | 28 +- .../datasource-action-list.tsx | 5 - .../edit/__tests__/apikey-edit-modal.spec.tsx | 29 +- .../edit/__tests__/manual-edit-modal.spec.tsx | 29 +- .../edit/__tests__/oauth-edit-modal.spec.tsx | 29 +- .../__tests__/use-reference-setting.spec.ts | 14 +- .../plugin-page/use-reference-setting.ts | 7 +- .../update-plugin/__tests__/index.spec.tsx | 25 +- .../update-plugin/from-market-place.tsx | 9 +- .../components/__tests__/conversion.spec.tsx | 34 +- .../__tests__/update-dsl-modal.spec.tsx | 45 +- .../rag-pipeline/components/conversion.tsx | 40 +- .../editor/form/__tests__/index.spec.tsx | 13 +- .../panel/input-field/editor/form/index.tsx | 35 +- .../field-list/__tests__/hooks.spec.ts | 8 +- .../field-list/__tests__/index.spec.tsx | 10 +- .../panel/input-field/field-list/hooks.ts | 41 +- .../__tests__/index.spec.tsx | 6 +- .../document-processing/options.tsx | 22 +- .../__tests__/index.spec.tsx | 21 +- .../publisher/__tests__/index.spec.tsx | 38 +- .../publisher/__tests__/popup.spec.tsx | 23 +- .../rag-pipeline-header/publisher/popup.tsx | 215 ++------- .../hooks/__tests__/index.spec.ts | 12 +- .../hooks/__tests__/use-DSL.spec.ts | 23 +- .../__tests__/use-update-dsl-modal.spec.ts | 49 +- .../components/rag-pipeline/hooks/use-DSL.ts | 25 +- .../hooks/use-update-dsl-modal.ts | 79 +-- .../use-text-generation-app-state.spec.ts | 8 +- .../hooks/use-text-generation-app-state.ts | 45 +- .../share/text-generation/index.tsx | 123 +---- .../result/__tests__/index.spec.tsx | 2 +- .../share/text-generation/result/index.tsx | 106 ++-- .../__tests__/index.spec.tsx | 14 +- .../edit-custom-collection-modal/index.tsx | 7 +- .../tools/mcp/hooks/use-mcp-modal-form.ts | 22 +- .../__tests__/config-credentials.spec.tsx | 2 +- .../setting/build-in/config-credentials.tsx | 4 +- .../__tests__/features-trigger.spec.tsx | 33 +- .../workflow-header/features-trigger.tsx | 9 +- .../hooks/__tests__/use-DSL.spec.ts | 27 +- .../components/workflow-app/hooks/use-DSL.ts | 11 +- .../nodes/_base/hooks/use-one-step-run.ts | 2 +- .../workflow/nodes/trigger-webhook/panel.tsx | 2 +- .../plugins/link-editor-plugin/hooks.ts | 121 ++--- .../__tests__/value-content-sections.spec.tsx | 36 +- .../education-apply/education-apply-page.tsx | 8 +- web/app/init/InitPasswordPopup.tsx | 8 +- web/app/layout.tsx | 13 +- web/app/signin/invite-settings/page.tsx | 4 +- web/app/signin/one-more-step.tsx | 4 +- web/docs/overlay-migration.md | 17 +- web/eslint-suppressions.json | 456 ++---------------- web/eslint.constants.mjs | 9 - web/hooks/use-import-dsl.ts | 34 +- web/service/base.ts | 18 +- 273 files changed, 3491 insertions(+), 6996 deletions(-) delete mode 100644 web/app/components/base/toast/__tests__/index.spec.tsx delete mode 100644 web/app/components/base/toast/context.ts delete mode 100644 web/app/components/base/toast/index.stories.tsx delete mode 100644 web/app/components/base/toast/index.tsx delete mode 100644 web/app/components/base/toast/style.module.css diff --git a/.agents/skills/frontend-query-mutation/references/runtime-rules.md b/.agents/skills/frontend-query-mutation/references/runtime-rules.md index 02e8b9c2b62..73d6fbddedb 100644 --- a/.agents/skills/frontend-query-mutation/references/runtime-rules.md +++ b/.agents/skills/frontend-query-mutation/references/runtime-rules.md @@ -64,7 +64,7 @@ export const useUpdateAccessMode = () => { // Component only adds UI behavior. updateAccessMode({ appId, mode }, { - onSuccess: () => Toast.notify({ type: 'success', message: '...' }), + onSuccess: () => toast.success('...'), }) // Avoid putting invalidation knowledge in the component. @@ -114,10 +114,7 @@ try { router.push(`/orders/${order.id}`) } catch (error) { - Toast.notify({ - type: 'error', - message: error instanceof Error ? error.message : 'Unknown error', - }) + toast.error(error instanceof Error ? error.message : 'Unknown error') } ``` diff --git a/web/.storybook/preview.tsx b/web/.storybook/preview.tsx index 072244c33f9..a9144e71280 100644 --- a/web/.storybook/preview.tsx +++ b/web/.storybook/preview.tsx @@ -2,7 +2,7 @@ import type { Preview } from '@storybook/react' import type { Resource } from 'i18next' import { withThemeByDataAttribute } from '@storybook/addon-themes' import { QueryClient, QueryClientProvider } from '@tanstack/react-query' -import { ToastProvider } from '../app/components/base/toast' +import { ToastHost } from '../app/components/base/ui/toast' import { I18nClientProvider as I18N } from '../app/components/provider/i18n' import commonEnUS from '../i18n/en-US/common.json' @@ -39,9 +39,10 @@ export const decorators = [ return ( - + <> + - + ) diff --git a/web/__tests__/apps/app-card-operations-flow.test.tsx b/web/__tests__/apps/app-card-operations-flow.test.tsx index c5766878a15..765c7045e56 100644 --- a/web/__tests__/apps/app-card-operations-flow.test.tsx +++ b/web/__tests__/apps/app-card-operations-flow.test.tsx @@ -23,8 +23,25 @@ let mockSystemFeatures = { webapp_auth: { enabled: false }, } +const toastMocks = vi.hoisted(() => ({ + mockNotify: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), +})) const mockRouterPush = vi.fn() -const mockNotify = vi.fn() + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: (message: string, options?: Record) => toastMocks.mockNotify({ type: 'success', message, ...options }), + error: (message: string, options?: Record) => toastMocks.mockNotify({ type: 'error', message, ...options }), + warning: (message: string, options?: Record) => toastMocks.mockNotify({ type: 'warning', message, ...options }), + info: (message: string, options?: Record) => toastMocks.mockNotify({ type: 'info', message, ...options }), + dismiss: toastMocks.dismiss, + update: toastMocks.update, + promise: toastMocks.promise, + }, +})) const mockOnPlanInfoChanged = vi.fn() const mockDeleteAppMutation = vi.fn().mockResolvedValue(undefined) let mockDeleteMutationPending = false @@ -94,27 +111,6 @@ vi.mock('@/context/provider-context', () => ({ }), })) -// Mock the ToastContext used via useContext from use-context-selector -vi.mock('use-context-selector', async () => { - const actual = await vi.importActual('use-context-selector') - return { - ...actual, - useContext: () => ({ notify: mockNotify }), - } -}) - -vi.mock('@/app/components/base/tag-management/store', () => ({ - useStore: (selector: (state: Record) => unknown) => { - const state = { - tagList: [], - showTagManagementModal: false, - setTagList: vi.fn(), - setShowTagManagementModal: vi.fn(), - } - return selector(state) - }, -})) - vi.mock('@/service/tag', () => ({ fetchTagList: vi.fn().mockResolvedValue([]), })) diff --git a/web/__tests__/datasets/create-dataset-flow.test.tsx b/web/__tests__/datasets/create-dataset-flow.test.tsx index e3a59edde64..34d64d8c439 100644 --- a/web/__tests__/datasets/create-dataset-flow.test.tsx +++ b/web/__tests__/datasets/create-dataset-flow.test.tsx @@ -33,8 +33,14 @@ vi.mock('@/service/knowledge/use-dataset', () => ({ useInvalidDatasetList: () => vi.fn(), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/ui/toast', () => ({ default: { notify: vi.fn() }, + toast: { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + }, })) vi.mock('@/app/components/base/amplitude', () => ({ diff --git a/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts b/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts index dc5ab3fc86b..cdf7aba4f6a 100644 --- a/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts +++ b/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts @@ -10,6 +10,19 @@ import { describe, expect, it, vi } from 'vitest' const mockDoSyncWorkflowDraft = vi.fn().mockResolvedValue(undefined) const mockExportPipelineConfig = vi.fn().mockResolvedValue({ data: 'yaml-content' }) const mockNotify = vi.fn() +const mockToast = { + success: (message: string, options?: Record) => mockNotify({ type: 'success', message, ...options }), + error: (message: string, options?: Record) => mockNotify({ type: 'error', message, ...options }), + warning: (message: string, options?: Record) => mockNotify({ type: 'warning', message, ...options }), + info: (message: string, options?: Record) => mockNotify({ type: 'info', message, ...options }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), +} + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: mockToast, +})) const mockEventEmitter = { emit: vi.fn() } const mockDownloadBlob = vi.fn() @@ -19,10 +32,6 @@ vi.mock('react-i18next', () => ({ }), })) -vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: () => ({ notify: mockNotify }), -})) - vi.mock('@/app/components/workflow/constants', () => ({ DSL_EXPORT_CHECK: 'DSL_EXPORT_CHECK', })) diff --git a/web/__tests__/tools/tool-provider-detail-flow.test.tsx b/web/__tests__/tools/tool-provider-detail-flow.test.tsx index 0101f83f22e..3d66467695a 100644 --- a/web/__tests__/tools/tool-provider-detail-flow.test.tsx +++ b/web/__tests__/tools/tool-provider-detail-flow.test.tsx @@ -153,8 +153,14 @@ vi.mock('@/app/components/base/confirm', () => ({ ), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/ui/toast', () => ({ default: { notify: vi.fn() }, + toast: { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + }, })) vi.mock('@/app/components/base/icons/src/vender/line/general', () => ({ diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx index 8c1df8d63df..26373bd42ae 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx @@ -7,12 +7,11 @@ import type { App } from '@/types/app' import type { I18nKeysByPrefix } from '@/types/i18n' import { useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import AppCard from '@/app/components/app/overview/app-card' import TriggerCard from '@/app/components/app/overview/trigger-card' import { useStore as useAppStore } from '@/app/components/app/store' import Loading from '@/app/components/base/loading' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import MCPServiceCard from '@/app/components/tools/mcp/mcp-service-card' import { isTriggerNode } from '@/app/components/workflow/types' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' @@ -34,7 +33,6 @@ export type ICardViewProps = { const CardView: FC = ({ appId, isInPanel, className }) => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) const appDetail = useAppStore(state => state.appDetail) const setAppDetail = useAppStore(state => state.setAppDetail) @@ -90,10 +88,7 @@ const CardView: FC = ({ appId, isInPanel, className }) => { if (type === 'success') updateAppDetail() - notify({ - type, - message: t(`actionMsg.${message}`, { ns: 'common' }) as string, - }) + toast(t(`actionMsg.${message}`, { ns: 'common' }) as string, { type }) } const onChangeSiteStatus = async (value: boolean) => { diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index 4201d114909..239427159c1 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -13,7 +13,7 @@ import { useTranslation } from 'react-i18next' import Divider from '@/app/components/base/divider' import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing' import Loading from '@/app/components/base/loading' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import Indicator from '@/app/components/header/indicator' import { useAppContext } from '@/context/app-context' import { usePathname } from '@/next/navigation' @@ -43,10 +43,7 @@ const Panel: FC = () => { await updateTracingStatus({ appId, body: tracingStatus }) setTracingStatus(tracingStatus) if (!noToast) { - Toast.notify({ - type: 'success', - message: t('api.success', { ns: 'common' }), - }) + toast(t('api.success', { ns: 'common' }), { type: 'success' }) } } diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx index ff78712c3c2..cc2143faac5 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx @@ -14,7 +14,7 @@ import { PortalToFollowElem, PortalToFollowElemContent, } from '@/app/components/base/portal-to-follow-elem' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { addTracingConfig, removeTracingConfig, updateTracingConfig } from '@/service/apps' import { docURL } from './config' import Field from './field' @@ -155,10 +155,7 @@ const ProviderConfigModal: FC = ({ appId, provider: type, }) - Toast.notify({ - type: 'success', - message: t('api.remove', { ns: 'common' }), - }) + toast(t('api.remove', { ns: 'common' }), { type: 'success' }) onRemoved() hideRemoveConfirm() }, [hideRemoveConfirm, appId, type, t, onRemoved]) @@ -264,10 +261,7 @@ const ProviderConfigModal: FC = ({ return const errorMessage = checkValid() if (errorMessage) { - Toast.notify({ - type: 'error', - message: errorMessage, - }) + toast(errorMessage, { type: 'error' }) return } const action = isEdit ? updateTracingConfig : addTracingConfig @@ -279,10 +273,7 @@ const ProviderConfigModal: FC = ({ tracing_config: config, }, }) - Toast.notify({ - type: 'success', - message: t('api.success', { ns: 'common' }), - }) + toast(t('api.success', { ns: 'common' }), { type: 'success' }) onSaved(config) if (isAdd) onChosen(type) diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index 3fc677d8d80..25e529a2210 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -8,15 +8,14 @@ import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react' import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import ImageInput from '@/app/components/base/app-icon-picker/ImageInput' import getCroppedImg from '@/app/components/base/app-icon-picker/utils' import { Avatar } from '@/app/components/base/avatar' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks' -import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast/context' +import { Dialog, DialogContent } from '@/app/components/base/ui/dialog' +import { toast } from '@/app/components/base/ui/toast' import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config' import { updateUserProfile } from '@/service/common' @@ -25,7 +24,6 @@ type AvatarWithEditProps = AvatarProps & { onSave?: () => void } const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) const [inputImageInfo, setInputImageInfo] = useState() const [isShowAvatarPicker, setIsShowAvatarPicker] = useState(false) @@ -48,24 +46,24 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { await updateUserProfile({ url: 'account/avatar', body: { avatar: uploadedFileId } }) setIsShowAvatarPicker(false) onSave?.() - notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) } catch (e) { - notify({ type: 'error', message: (e as Error).message }) + toast.error((e as Error).message) } - }, [notify, onSave, t]) + }, [onSave, t]) const handleDeleteAvatar = useCallback(async () => { try { await updateUserProfile({ url: 'account/avatar', body: { avatar: '' } }) - notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) setIsShowDeleteConfirm(false) onSave?.() } catch (e) { - notify({ type: 'error', message: (e as Error).message }) + toast.error((e as Error).message) } - }, [notify, onSave, t]) + }, [onSave, t]) const { handleLocalFileUpload } = useLocalFileUploader({ limit: 3, @@ -134,45 +132,39 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { - setIsShowAvatarPicker(false)} - > - - + !open && setIsShowAvatarPicker(false)}> + + + -
- +
+ - -
- + +
+
+
- setIsShowDeleteConfirm(false)} - > -
{t('avatar.deleteTitle', { ns: 'common' })}
-

{t('avatar.deleteDescription', { ns: 'common' })}

+ !open && setIsShowDeleteConfirm(false)}> + +
{t('avatar.deleteTitle', { ns: 'common' })}
+

{t('avatar.deleteDescription', { ns: 'common' })}

-
- +
+ - -
- + +
+
+
) } diff --git a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx index f0dfd4f12fd..2e2d61f2f93 100644 --- a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx +++ b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx @@ -1,14 +1,12 @@ import type { ResponseError } from '@/service/fetch' import { RiCloseLine } from '@remixicon/react' -import { noop } from 'es-toolkit/function' import * as React from 'react' import { useState } from 'react' import { Trans, useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast/context' +import { Dialog, DialogContent } from '@/app/components/base/ui/dialog' +import { toast } from '@/app/components/base/ui/toast' import { useRouter } from '@/next/navigation' import { checkEmailExisted, @@ -34,7 +32,6 @@ enum STEP { const EmailChangeModal = ({ onClose, email, show }: Props) => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) const router = useRouter() const [step, setStep] = useState(STEP.start) const [code, setCode] = useState('') @@ -70,10 +67,7 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { setStepToken(res.data) } catch (error) { - notify({ - type: 'error', - message: `Error sending verification code: ${error ? (error as any).message : ''}`, - }) + toast.error(`Error sending verification code: ${error ? (error as any).message : ''}`) } } @@ -89,17 +83,11 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { callback?.(res.token) } else { - notify({ - type: 'error', - message: 'Verifying email failed', - }) + toast.error('Verifying email failed') } } catch (error) { - notify({ - type: 'error', - message: `Error verifying email: ${error ? (error as any).message : ''}`, - }) + toast.error(`Error verifying email: ${error ? (error as any).message : ''}`) } } @@ -154,10 +142,7 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { const sendCodeToNewEmail = async () => { if (!isValidEmail(mail)) { - notify({ - type: 'error', - message: 'Invalid email format', - }) + toast.error('Invalid email format') return } await sendEmail( @@ -187,10 +172,7 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { handleLogout() } catch (error) { - notify({ - type: 'error', - message: `Error changing email: ${error ? (error as any).message : ''}`, - }) + toast.error(`Error changing email: ${error ? (error as any).message : ''}`) } } @@ -199,187 +181,185 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { } return ( - -
- -
- {step === STEP.start && ( - <> -
{t('account.changeEmail.title', { ns: 'common' })}
-
-
{t('account.changeEmail.authTip', { ns: 'common' })}
-
- }} - values={{ email }} + !open && onClose()}> + +
+ +
+ {step === STEP.start && ( + <> +
{t('account.changeEmail.title', { ns: 'common' })}
+
+
{t('account.changeEmail.authTip', { ns: 'common' })}
+
+ }} + values={{ email }} + /> +
+
+
+
+ + +
+ + )} + {step === STEP.verifyOrigin && ( + <> +
{t('account.changeEmail.verifyEmail', { ns: 'common' })}
+
+
+ }} + values={{ email }} + /> +
+
+
+
{t('account.changeEmail.codeLabel', { ns: 'common' })}
+ setCode(e.target.value)} + maxLength={6} />
-
-
-
- - -
- - )} - {step === STEP.verifyOrigin && ( - <> -
{t('account.changeEmail.verifyEmail', { ns: 'common' })}
-
-
- }} - values={{ email }} +
+ + +
+
+ {t('account.changeEmail.resendTip', { ns: 'common' })} + {time > 0 && ( + {t('account.changeEmail.resendCount', { ns: 'common', count: time })} + )} + {!time && ( + {t('account.changeEmail.resend', { ns: 'common' })} + )} +
+ + )} + {step === STEP.newEmail && ( + <> +
{t('account.changeEmail.newEmail', { ns: 'common' })}
+
+
{t('account.changeEmail.content3', { ns: 'common' })}
+
+
+
{t('account.changeEmail.emailLabel', { ns: 'common' })}
+ handleNewEmailValueChange(e.target.value)} + destructive={newEmailExited || unAvailableEmail} + /> + {newEmailExited && ( +
{t('account.changeEmail.existingEmail', { ns: 'common' })}
+ )} + {unAvailableEmail && ( +
{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}
+ )} +
+
+ + +
+ + )} + {step === STEP.verifyNew && ( + <> +
{t('account.changeEmail.verifyNew', { ns: 'common' })}
+
+
+ }} + values={{ email: mail }} + /> +
+
+
+
{t('account.changeEmail.codeLabel', { ns: 'common' })}
+ setCode(e.target.value)} + maxLength={6} />
-
-
-
{t('account.changeEmail.codeLabel', { ns: 'common' })}
- setCode(e.target.value)} - maxLength={6} - /> -
-
- - -
-
- {t('account.changeEmail.resendTip', { ns: 'common' })} - {time > 0 && ( - {t('account.changeEmail.resendCount', { ns: 'common', count: time })} - )} - {!time && ( - {t('account.changeEmail.resend', { ns: 'common' })} - )} -
- - )} - {step === STEP.newEmail && ( - <> -
{t('account.changeEmail.newEmail', { ns: 'common' })}
-
-
{t('account.changeEmail.content3', { ns: 'common' })}
-
-
-
{t('account.changeEmail.emailLabel', { ns: 'common' })}
- handleNewEmailValueChange(e.target.value)} - destructive={newEmailExited || unAvailableEmail} - /> - {newEmailExited && ( -
{t('account.changeEmail.existingEmail', { ns: 'common' })}
- )} - {unAvailableEmail && ( -
{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}
- )} -
-
- - -
- - )} - {step === STEP.verifyNew && ( - <> -
{t('account.changeEmail.verifyNew', { ns: 'common' })}
-
-
- }} - values={{ email: mail }} - /> +
+ +
-
-
-
{t('account.changeEmail.codeLabel', { ns: 'common' })}
- setCode(e.target.value)} - maxLength={6} - /> -
-
- - -
-
- {t('account.changeEmail.resendTip', { ns: 'common' })} - {time > 0 && ( - {t('account.changeEmail.resendCount', { ns: 'common', count: time })} - )} - {!time && ( - {t('account.changeEmail.resend', { ns: 'common' })} - )} -
- - )} - +
+ {t('account.changeEmail.resendTip', { ns: 'common' })} + {time > 0 && ( + {t('account.changeEmail.resendCount', { ns: 'common', count: time })} + )} + {!time && ( + {t('account.changeEmail.resend', { ns: 'common' })} + )} +
+ + )} + + ) } diff --git a/web/app/account/(commonLayout)/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx index 9a104619da7..7b4a1485300 100644 --- a/web/app/account/(commonLayout)/account-page/index.tsx +++ b/web/app/account/(commonLayout)/account-page/index.tsx @@ -7,13 +7,12 @@ import { import { useQueryClient } from '@tanstack/react-query' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import AppIcon from '@/app/components/base/app-icon' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Modal from '@/app/components/base/modal' import PremiumBadge from '@/app/components/base/premium-badge' -import { ToastContext } from '@/app/components/base/toast/context' +import { Dialog, DialogContent } from '@/app/components/base/ui/dialog' +import { toast } from '@/app/components/base/ui/toast' import Collapse from '@/app/components/header/account-setting/collapse' import { IS_CE_EDITION, validPassword } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' @@ -43,7 +42,6 @@ export default function AccountPage() { const userProfile = userProfileResp?.profile const mutateUserProfile = () => queryClient.invalidateQueries({ queryKey: commonQueryKeys.userProfile }) const { isEducationAccount } = useProviderContext() - const { notify } = useContext(ToastContext) const [editNameModalVisible, setEditNameModalVisible] = useState(false) const [editName, setEditName] = useState('') const [editing, setEditing] = useState(false) @@ -68,22 +66,19 @@ export default function AccountPage() { try { setEditing(true) await updateUserProfile({ url: 'account/name', body: { name: editName } }) - notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) mutateUserProfile() setEditNameModalVisible(false) setEditing(false) } catch (e) { - notify({ type: 'error', message: (e as Error).message }) + toast.error((e as Error).message) setEditing(false) } } const showErrorMessage = (message: string) => { - notify({ - type: 'error', - message, - }) + toast.error(message) } const valid = () => { if (!password.trim()) { @@ -119,14 +114,14 @@ export default function AccountPage() { repeat_new_password: confirmPassword, }, }) - notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) mutateUserProfile() setEditPasswordModalVisible(false) resetPasswordForm() setEditing(false) } catch (e) { - notify({ type: 'error', message: (e as Error).message }) + toast.error((e as Error).message) setEditPasswordModalVisible(false) setEditing(false) } @@ -221,119 +216,112 @@ export default function AccountPage() {
{ editNameModalVisible && ( - setEditNameModalVisible(false)} - className="!w-[420px] !p-6" - > -
{t('account.editName', { ns: 'common' })}
-
{t('account.name', { ns: 'common' })}
- setEditName(e.target.value)} - /> -
- - -
-
+ !open && setEditNameModalVisible(false)}> + +
{t('account.editName', { ns: 'common' })}
+
{t('account.name', { ns: 'common' })}
+ setEditName(e.target.value)} + /> +
+ + +
+
+
) } { editPasswordModalVisible && ( - { - setEditPasswordModalVisible(false) - resetPasswordForm() - }} - className="!w-[420px] !p-6" - > -
{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}
- {userProfile.is_password_set && ( - <> -
{t('account.currentPassword', { ns: 'common' })}
-
- setCurrentPassword(e.target.value)} - /> + !open && (setEditPasswordModalVisible(false), resetPasswordForm())}> + +
{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}
+ {userProfile.is_password_set && ( + <> +
{t('account.currentPassword', { ns: 'common' })}
+
+ setCurrentPassword(e.target.value)} + /> -
- +
+ +
+ + )} +
+ {userProfile.is_password_set ? t('account.newPassword', { ns: 'common' }) : t('account.password', { ns: 'common' })} +
+
+ setPassword(e.target.value)} + /> +
+
- - )} -
- {userProfile.is_password_set ? t('account.newPassword', { ns: 'common' }) : t('account.password', { ns: 'common' })} -
-
- setPassword(e.target.value)} - /> -
+
+
{t('account.confirmPassword', { ns: 'common' })}
+
+ setConfirmPassword(e.target.value)} + /> +
+ +
+
+
+
-
-
{t('account.confirmPassword', { ns: 'common' })}
-
- setConfirmPassword(e.target.value)} - /> -
- -
-
-
- - -
- + +
) } { diff --git a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx index ae73d778f8e..60bd7e5c0da 100644 --- a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx +++ b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx @@ -4,7 +4,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import CustomDialog from '@/app/components/base/dialog' import Textarea from '@/app/components/base/textarea' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useAppContext } from '@/context/app-context' import { useRouter } from '@/next/navigation' import { useLogout } from '@/service/use-common' @@ -28,7 +28,7 @@ export default function FeedBack(props: DeleteAccountProps) { await logout() // Tokens are now stored in cookies and cleared by backend router.push('/signin') - Toast.notify({ type: 'info', message: t('account.deleteSuccessTip', { ns: 'common' }) }) + toast.info(t('account.deleteSuccessTip', { ns: 'common' })) } catch (error) { console.error(error) } }, [router, t]) diff --git a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts index deea28ce3ea..d5eaa4bfe4b 100644 --- a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts +++ b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts @@ -2,7 +2,16 @@ import { act, renderHook } from '@testing-library/react' import { AppModeEnum } from '@/types/app' import { useAppInfoActions } from '../use-app-info-actions' -const mockNotify = vi.fn() +const toastMocks = vi.hoisted(() => { + const call = vi.fn() + return { + call, + api: vi.fn((message: unknown, options?: Record) => call({ message, ...options })), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + } +}) const mockReplace = vi.fn() const mockOnPlanInfoChanged = vi.fn() const mockInvalidateAppList = vi.fn() @@ -27,10 +36,6 @@ vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace }), })) -vi.mock('use-context-selector', () => ({ - useContext: () => ({ notify: mockNotify }), -})) - vi.mock('@/context/provider-context', () => ({ useProviderContext: () => ({ onPlanInfoChanged: mockOnPlanInfoChanged }), })) @@ -42,8 +47,16 @@ vi.mock('@/app/components/app/store', () => ({ }), })) -vi.mock('@/app/components/base/toast/context', () => ({ - ToastContext: {}, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign(toastMocks.api, { + success: vi.fn((message, options) => toastMocks.call({ type: 'success', message, ...options })), + error: vi.fn((message, options) => toastMocks.call({ type: 'error', message, ...options })), + warning: vi.fn((message, options) => toastMocks.call({ type: 'warning', message, ...options })), + info: vi.fn((message, options) => toastMocks.call({ type: 'info', message, ...options })), + dismiss: toastMocks.dismiss, + update: toastMocks.update, + promise: toastMocks.promise, + }), })) vi.mock('@/service/use-apps', () => ({ @@ -175,7 +188,7 @@ describe('useAppInfoActions', () => { expect(mockUpdateAppInfo).toHaveBeenCalled() expect(mockSetAppDetail).toHaveBeenCalledWith(updatedApp) - expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.editDone' }) + expect(toastMocks.call).toHaveBeenCalledWith({ type: 'success', message: 'app.editDone' }) }) it('should notify error on edit failure', async () => { @@ -194,7 +207,7 @@ describe('useAppInfoActions', () => { }) }) - expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.editFailed' }) + expect(toastMocks.call).toHaveBeenCalledWith({ type: 'error', message: 'app.editFailed' }) }) it('should not call updateAppInfo when appDetail is undefined', async () => { @@ -234,7 +247,7 @@ describe('useAppInfoActions', () => { }) expect(mockCopyApp).toHaveBeenCalled() - expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' }) + expect(toastMocks.call).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' }) expect(mockOnPlanInfoChanged).toHaveBeenCalled() }) @@ -252,7 +265,7 @@ describe('useAppInfoActions', () => { }) }) - expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.newApp.appCreateFailed' }) + expect(toastMocks.call).toHaveBeenCalledWith({ type: 'error', message: 'app.newApp.appCreateFailed' }) }) }) @@ -298,7 +311,7 @@ describe('useAppInfoActions', () => { await result.current.onExport() }) - expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) + expect(toastMocks.call).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) }) }) @@ -410,7 +423,7 @@ describe('useAppInfoActions', () => { await result.current.handleConfirmExport() }) - expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) + expect(toastMocks.call).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) }) }) @@ -456,7 +469,7 @@ describe('useAppInfoActions', () => { }) expect(mockDeleteApp).toHaveBeenCalledWith('app-1') - expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.appDeleted' }) + expect(toastMocks.call).toHaveBeenCalledWith({ type: 'success', message: 'app.appDeleted' }) expect(mockInvalidateAppList).toHaveBeenCalled() expect(mockReplace).toHaveBeenCalledWith('/apps') expect(mockSetAppDetail).toHaveBeenCalledWith() @@ -483,7 +496,7 @@ describe('useAppInfoActions', () => { await result.current.onConfirmDelete() }) - expect(mockNotify).toHaveBeenCalledWith({ + expect(toastMocks.call).toHaveBeenCalledWith({ type: 'error', message: expect.stringContaining('app.appDeleteFailed'), }) diff --git a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts index 55ec13e506f..8b559f7bbaa 100644 --- a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts +++ b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts @@ -3,9 +3,8 @@ import type { CreateAppModalProps } from '@/app/components/explore/create-app-mo import type { EnvironmentVariable } from '@/app/components/workflow/types' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import { useStore as useAppStore } from '@/app/components/app/store' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useProviderContext } from '@/context/provider-context' import { useRouter } from '@/next/navigation' @@ -24,7 +23,6 @@ type UseAppInfoActionsParams = { export function useAppInfoActions({ onDetailExpand }: UseAppInfoActionsParams) { const { t } = useTranslation() - const { notify } = useContext(ToastContext) const { replace } = useRouter() const { onPlanInfoChanged } = useProviderContext() const appDetail = useAppStore(state => state.appDetail) @@ -72,13 +70,13 @@ export function useAppInfoActions({ onDetailExpand }: UseAppInfoActionsParams) { max_active_requests, }) closeModal() - notify({ type: 'success', message: t('editDone', { ns: 'app' }) }) + toast(t('editDone', { ns: 'app' }), { type: 'success' }) setAppDetail(app) } catch { - notify({ type: 'error', message: t('editFailed', { ns: 'app' }) }) + toast(t('editFailed', { ns: 'app' }), { type: 'error' }) } - }, [appDetail, closeModal, notify, setAppDetail, t]) + }, [appDetail, closeModal, setAppDetail, t]) const onCopy: DuplicateAppModalProps['onConfirm'] = useCallback(async ({ name, @@ -98,15 +96,15 @@ export function useAppInfoActions({ onDetailExpand }: UseAppInfoActionsParams) { mode: appDetail.mode, }) closeModal() - notify({ type: 'success', message: t('newApp.appCreated', { ns: 'app' }) }) + toast(t('newApp.appCreated', { ns: 'app' }), { type: 'success' }) localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') onPlanInfoChanged() getRedirection(true, newApp, replace) } catch { - notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) }) + toast(t('newApp.appCreateFailed', { ns: 'app' }), { type: 'error' }) } - }, [appDetail, closeModal, notify, onPlanInfoChanged, replace, t]) + }, [appDetail, closeModal, onPlanInfoChanged, replace, t]) const onExport = useCallback(async (include = false) => { if (!appDetail) @@ -117,9 +115,9 @@ export function useAppInfoActions({ onDetailExpand }: UseAppInfoActionsParams) { downloadBlob({ data: file, fileName: `${appDetail.name}.yml` }) } catch { - notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) + toast(t('exportFailed', { ns: 'app' }), { type: 'error' }) } - }, [appDetail, notify, t]) + }, [appDetail, t]) const exportCheck = useCallback(async () => { if (!appDetail) @@ -145,29 +143,26 @@ export function useAppInfoActions({ onDetailExpand }: UseAppInfoActionsParams) { setSecretEnvList(list) } catch { - notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) + toast(t('exportFailed', { ns: 'app' }), { type: 'error' }) } - }, [appDetail, closeModal, notify, onExport, t]) + }, [appDetail, closeModal, onExport, t]) const onConfirmDelete = useCallback(async () => { if (!appDetail) return try { await deleteApp(appDetail.id) - notify({ type: 'success', message: t('appDeleted', { ns: 'app' }) }) + toast(t('appDeleted', { ns: 'app' }), { type: 'success' }) invalidateAppList() onPlanInfoChanged() setAppDetail() replace('/apps') } catch (e: unknown) { - notify({ - type: 'error', - message: `${t('appDeleteFailed', { ns: 'app' })}${e instanceof Error && e.message ? `: ${e.message}` : ''}`, - }) + toast(`${t('appDeleteFailed', { ns: 'app' })}${e instanceof Error && e.message ? `: ${e.message}` : ''}`, { type: 'error' }) } closeModal() - }, [appDetail, closeModal, invalidateAppList, notify, onPlanInfoChanged, replace, setAppDetail, t]) + }, [appDetail, closeModal, invalidateAppList, onPlanInfoChanged, replace, setAppDetail, t]) return { appDetail, diff --git a/web/app/components/app-sidebar/dataset-info/dropdown.tsx b/web/app/components/app-sidebar/dataset-info/dropdown.tsx index 528bac831fb..1d1208e7d3c 100644 --- a/web/app/components/app-sidebar/dataset-info/dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-info/dropdown.tsx @@ -3,6 +3,7 @@ import { RiMoreFill } from '@remixicon/react' import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' +import { toast } from '@/app/components/base/ui/toast' import { useSelector as useAppContextWithSelector } from '@/context/app-context' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useRouter } from '@/next/navigation' @@ -15,7 +16,6 @@ import { downloadBlob } from '@/utils/download' import ActionButton from '../../base/action-button' import Confirm from '../../base/confirm' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' -import Toast from '../../base/toast' import RenameDatasetModal from '../../datasets/rename-modal' import Menu from './menu' @@ -69,7 +69,7 @@ const DropDown = ({ downloadBlob({ data: file, fileName: `${name}.pipeline` }) } catch { - Toast.notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) + toast(t('exportFailed', { ns: 'app' }), { type: 'error' }) } }, [dataset, exportPipelineConfig, handleTrigger, t]) @@ -81,7 +81,7 @@ const DropDown = ({ } catch (e: any) { const res = await e.json() - Toast.notify({ type: 'error', message: res?.message || 'Unknown error' }) + toast(res?.message || 'Unknown error', { type: 'error' }) } finally { handleTrigger() @@ -91,7 +91,7 @@ const DropDown = ({ const onConfirmDelete = useCallback(async () => { try { await deleteDataset(dataset.id) - Toast.notify({ type: 'success', message: t('datasetDeleted', { ns: 'dataset' }) }) + toast(t('datasetDeleted', { ns: 'dataset' }), { type: 'success' }) invalidDatasetList() replace('/datasets') } diff --git a/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx index bad3ceefdf9..14f94d910bc 100644 --- a/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx +++ b/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx @@ -9,10 +9,16 @@ vi.mock('@/context/provider-context', () => ({ })) const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/ui/toast', () => ({ default: { notify: vi.fn(args => mockToastNotify(args)), }, + toast: { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + }, })) vi.mock('@/app/components/billing/annotation-full', () => ({ diff --git a/web/app/components/app/annotation/add-annotation-modal/index.tsx b/web/app/components/app/annotation/add-annotation-modal/index.tsx index a3100d51313..d4cc943a574 100644 --- a/web/app/components/app/annotation/add-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/add-annotation-modal/index.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Checkbox from '@/app/components/base/checkbox' import Drawer from '@/app/components/base/drawer-plus' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import AnnotationFull from '@/app/components/billing/annotation-full' import { useProviderContext } from '@/context/provider-context' import EditItem, { EditItemType } from './edit-item' @@ -47,10 +47,7 @@ const AddAnnotationModal: FC = ({ answer, } if (isValid(payload) !== true) { - Toast.notify({ - type: 'error', - message: isValid(payload) as string, - }) + toast.error(isValid(payload) as string) return } diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx index 55f5ee0564d..847db746195 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx @@ -1,11 +1,28 @@ import type { Props } from './csv-uploader' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' -import { ToastContext } from '@/app/components/base/toast/context' import CSVUploader from './csv-uploader' +const toastMocks = vi.hoisted(() => ({ + notify: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: (message: string, options?: Record) => toastMocks.notify({ type: 'success', message, ...options }), + error: (message: string, options?: Record) => toastMocks.notify({ type: 'error', message, ...options }), + warning: (message: string, options?: Record) => toastMocks.notify({ type: 'warning', message, ...options }), + info: (message: string, options?: Record) => toastMocks.notify({ type: 'info', message, ...options }), + dismiss: toastMocks.dismiss, + update: toastMocks.update, + promise: toastMocks.promise, + }, +})) + describe('CSVUploader', () => { - const notify = vi.fn() const updateFile = vi.fn() const getDropElements = () => { @@ -24,9 +41,8 @@ describe('CSVUploader', () => { ...props, } return render( - - - , + , + ) } @@ -76,7 +92,7 @@ describe('CSVUploader', () => { fireEvent.drop(dropContainer, { dataTransfer: { files: [fileA, fileB] } }) - await waitFor(() => expect(notify).toHaveBeenCalledWith({ + await waitFor(() => expect(toastMocks.notify).toHaveBeenCalledWith({ type: 'error', message: 'datasetCreation.stepOne.uploader.validation.count', })) diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index a969b3d491c..0fbd3974aa2 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -4,10 +4,9 @@ import { RiDeleteBinLine } from '@remixicon/react' import * as React from 'react' import { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { cn } from '@/utils/classnames' export type Props = { @@ -20,7 +19,6 @@ const CSVUploader: FC = ({ updateFile, }) => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) const [dragging, setDragging] = useState(false) const dropRef = useRef(null) const dragRef = useRef(null) @@ -50,7 +48,7 @@ const CSVUploader: FC = ({ return const files = Array.from(e.dataTransfer.files) if (files.length > 1) { - notify({ type: 'error', message: t('stepOne.uploader.validation.count', { ns: 'datasetCreation' }) }) + toast.error(t('stepOne.uploader.validation.count', { ns: 'datasetCreation' })) return } updateFile(files[0]) diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx index 7fdb99fbab1..8929cc292f9 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx @@ -2,17 +2,10 @@ import type { Mock } from 'vitest' import type { IBatchModalProps } from './index' import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' -import Toast from '@/app/components/base/toast' import { useProviderContext } from '@/context/provider-context' import { annotationBatchImport, checkAnnotationBatchImportProgress } from '@/service/annotation' import BatchModal, { ProcessStatus } from './index' -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), - }, -})) - vi.mock('@/service/annotation', () => ({ annotationBatchImport: vi.fn(), checkAnnotationBatchImportProgress: vi.fn(), @@ -49,7 +42,18 @@ vi.mock('@/app/components/billing/annotation-full', () => ({ default: () =>
, })) -const mockNotify = Toast.notify as Mock +const mockNotify = vi.fn() +vi.mock('@/app/components/base/ui/toast', () => ({ + default: { + notify: (args: unknown) => mockNotify(args), + }, + toast: { + success: (message: string) => mockNotify({ type: 'success', message }), + error: (message: string) => mockNotify({ type: 'error', message }), + warning: (message: string) => mockNotify({ type: 'warning', message }), + info: (message: string) => mockNotify({ type: 'info', message }), + }, +})) const useProviderContextMock = useProviderContext as Mock const annotationBatchImportMock = annotationBatchImport as Mock const checkAnnotationBatchImportProgressMock = checkAnnotationBatchImportProgress as Mock diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/index.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/index.tsx index be1518b7085..f6d9512d3da 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/index.tsx @@ -7,7 +7,7 @@ import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import AnnotationFull from '@/app/components/billing/annotation-full' import { useProviderContext } from '@/context/provider-context' import { annotationBatchImport, checkAnnotationBatchImportProgress } from '@/service/annotation' @@ -46,7 +46,6 @@ const BatchModal: FC = ({ }, [isShow]) const [importStatus, setImportStatus] = useState() - const notify = Toast.notify const checkProcess = async (jobID: string) => { try { const res = await checkAnnotationBatchImportProgress({ jobID, appId }) @@ -54,15 +53,15 @@ const BatchModal: FC = ({ if (res.job_status === ProcessStatus.WAITING || res.job_status === ProcessStatus.PROCESSING) setTimeout(() => checkProcess(res.job_id), 2500) if (res.job_status === ProcessStatus.ERROR) - notify({ type: 'error', message: `${t('batchModal.runError', { ns: 'appAnnotation' })}` }) + toast.error(`${t('batchModal.runError', { ns: 'appAnnotation' })}`) if (res.job_status === ProcessStatus.COMPLETED) { - notify({ type: 'success', message: `${t('batchModal.completed', { ns: 'appAnnotation' })}` }) + toast.success(`${t('batchModal.completed', { ns: 'appAnnotation' })}`) onAdded() onCancel() } } catch (e: any) { - notify({ type: 'error', message: `${t('batchModal.runError', { ns: 'appAnnotation' })}${'message' in e ? `: ${e.message}` : ''}` }) + toast.error(`${t('batchModal.runError', { ns: 'appAnnotation' })}${'message' in e ? `: ${e.message}` : ''}`) } } @@ -78,7 +77,7 @@ const BatchModal: FC = ({ checkProcess(res.job_id) } catch (e: any) { - notify({ type: 'error', message: `${t('batchModal.runError', { ns: 'appAnnotation' })}${'message' in e ? `: ${e.message}` : ''}` }) + toast.error(`${t('batchModal.runError', { ns: 'appAnnotation' })}${'message' in e ? `: ${e.message}` : ''}`) } } diff --git a/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx index 0bbd1ab67d4..8f6dec42cfe 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx @@ -1,7 +1,6 @@ -import type { IToastProps, ToastHandle } from '@/app/components/base/toast' import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import EditAnnotationModal from './index' const { mockAddAnnotation, mockEditAnnotation } = vi.hoisted(() => ({ @@ -37,10 +36,8 @@ vi.mock('@/app/components/billing/annotation-full', () => ({ default: () =>
, })) -type ToastNotifyProps = Pick -type ToastWithNotify = typeof Toast & { notify: (props: ToastNotifyProps) => ToastHandle } -const toastWithNotify = Toast as unknown as ToastWithNotify -const toastNotifySpy = vi.spyOn(toastWithNotify, 'notify').mockReturnValue({ clear: vi.fn() }) +const toastSuccessSpy = vi.spyOn(toast, 'success').mockReturnValue('toast-success') +const toastErrorSpy = vi.spyOn(toast, 'error').mockReturnValue('toast-error') describe('EditAnnotationModal', () => { const defaultProps = { @@ -55,7 +52,8 @@ describe('EditAnnotationModal', () => { } afterAll(() => { - toastNotifySpy.mockRestore() + toastSuccessSpy.mockRestore() + toastErrorSpy.mockRestore() }) beforeEach(() => { @@ -437,10 +435,7 @@ describe('EditAnnotationModal', () => { // Assert await waitFor(() => { - expect(toastNotifySpy).toHaveBeenCalledWith({ - message: 'API Error', - type: 'error', - }) + expect(toastErrorSpy).toHaveBeenCalledWith('API Error') }) expect(mockOnAdded).not.toHaveBeenCalled() @@ -475,10 +470,7 @@ describe('EditAnnotationModal', () => { // Assert await waitFor(() => { - expect(toastNotifySpy).toHaveBeenCalledWith({ - message: 'common.api.actionFailed', - type: 'error', - }) + expect(toastErrorSpy).toHaveBeenCalledWith('common.api.actionFailed') }) expect(mockOnAdded).not.toHaveBeenCalled() @@ -517,10 +509,7 @@ describe('EditAnnotationModal', () => { // Assert await waitFor(() => { - expect(toastNotifySpy).toHaveBeenCalledWith({ - message: 'API Error', - type: 'error', - }) + expect(toastErrorSpy).toHaveBeenCalledWith('API Error') }) expect(mockOnEdited).not.toHaveBeenCalled() @@ -557,10 +546,7 @@ describe('EditAnnotationModal', () => { // Assert await waitFor(() => { - expect(toastNotifySpy).toHaveBeenCalledWith({ - message: 'common.api.actionFailed', - type: 'error', - }) + expect(toastErrorSpy).toHaveBeenCalledWith('common.api.actionFailed') }) expect(mockOnEdited).not.toHaveBeenCalled() @@ -641,10 +627,7 @@ describe('EditAnnotationModal', () => { // Assert await waitFor(() => { - expect(toastNotifySpy).toHaveBeenCalledWith({ - message: 'common.api.actionSuccess', - type: 'success', - }) + expect(toastSuccessSpy).toHaveBeenCalledWith('common.api.actionSuccess') }) }) }) diff --git a/web/app/components/app/annotation/edit-annotation-modal/index.tsx b/web/app/components/app/annotation/edit-annotation-modal/index.tsx index 2595ec38b2e..c0e60b65dcd 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/index.tsx @@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next' import Confirm from '@/app/components/base/confirm' import Drawer from '@/app/components/base/drawer-plus' import { MessageCheckRemove } from '@/app/components/base/icons/src/vender/line/communication' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import AnnotationFull from '@/app/components/billing/annotation-full' import { useProviderContext } from '@/context/provider-context' import useTimestamp from '@/hooks/use-timestamp' @@ -72,18 +72,12 @@ const EditAnnotationModal: FC = ({ onAdded(res.id, res.account?.name ?? '', postQuery, postAnswer) } - Toast.notify({ - message: t('api.actionSuccess', { ns: 'common' }) as string, - type: 'success', - }) + toast.success(t('api.actionSuccess', { ns: 'common' }) as string) } catch (error) { const fallbackMessage = t('api.actionFailed', { ns: 'common' }) as string const message = error instanceof Error && error.message ? error.message : fallbackMessage - Toast.notify({ - message, - type: 'error', - }) + toast.error(message) // Re-throw to preserve edit mode behavior for UI components throw error } diff --git a/web/app/components/app/annotation/index.spec.tsx b/web/app/components/app/annotation/index.spec.tsx index d62b60d33dd..5f5e9f74c07 100644 --- a/web/app/components/app/annotation/index.spec.tsx +++ b/web/app/components/app/annotation/index.spec.tsx @@ -3,7 +3,7 @@ import type { AnnotationItem } from './type' import type { App } from '@/types/app' import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useProviderContext } from '@/context/provider-context' import { addAnnotation, @@ -17,10 +17,6 @@ import { AppModeEnum } from '@/types/app' import Annotation from './index' import { JobStatus } from './type' -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: vi.fn() }, -})) - vi.mock('ahooks', () => ({ useDebounce: (value: any) => value, })) @@ -95,7 +91,23 @@ vi.mock('./view-annotation-modal', () => ({ vi.mock('@/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal', () => ({ default: (props: any) => props.isShow ?
: null })) vi.mock('@/app/components/billing/annotation-full/modal', () => ({ default: (props: any) => props.show ?
: null })) -const mockNotify = Toast.notify as Mock +const mockNotify = vi.fn() +vi.spyOn(toast, 'success').mockImplementation((message, options) => { + mockNotify({ type: 'success', message, ...options }) + return 'toast-success-id' +}) +vi.spyOn(toast, 'error').mockImplementation((message, options) => { + mockNotify({ type: 'error', message, ...options }) + return 'toast-error-id' +}) +vi.spyOn(toast, 'warning').mockImplementation((message, options) => { + mockNotify({ type: 'warning', message, ...options }) + return 'toast-warning-id' +}) +vi.spyOn(toast, 'info').mockImplementation((message, options) => { + mockNotify({ type: 'info', message, ...options }) + return 'toast-info-id' +}) const addAnnotationMock = addAnnotation as Mock const delAnnotationMock = delAnnotation as Mock const delAnnotationsMock = delAnnotations as Mock diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index ee276603cc6..0ea25744ff3 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -15,6 +15,7 @@ import { MessageFast } from '@/app/components/base/icons/src/vender/solid/commun import Loading from '@/app/components/base/loading' import Pagination from '@/app/components/base/pagination' import Switch from '@/app/components/base/switch' +import { toast } from '@/app/components/base/ui/toast' import AnnotationFullModal from '@/app/components/billing/annotation-full/modal' import { APP_PAGE_LIMIT } from '@/config' import { useProviderContext } from '@/context/provider-context' @@ -22,7 +23,6 @@ import { addAnnotation, delAnnotation, delAnnotations, fetchAnnotationConfig as import { AppModeEnum } from '@/types/app' import { sleep } from '@/utils' import { cn } from '@/utils/classnames' -import Toast from '../../base/toast' import EmptyElement from './empty-element' import Filter from './filter' import HeaderOpts from './header-opts' @@ -98,14 +98,14 @@ const Annotation: FC = (props) => { const handleAdd = async (payload: AnnotationItemBasic) => { await addAnnotation(appDetail.id, payload) - Toast.notify({ message: t('api.actionSuccess', { ns: 'common' }), type: 'success' }) + toast.success(t('api.actionSuccess', { ns: 'common' })) fetchList() setControlUpdateList(Date.now()) } const handleRemove = async (id: string) => { await delAnnotation(appDetail.id, id) - Toast.notify({ message: t('api.actionSuccess', { ns: 'common' }), type: 'success' }) + toast.success(t('api.actionSuccess', { ns: 'common' })) fetchList() setControlUpdateList(Date.now()) } @@ -113,13 +113,13 @@ const Annotation: FC = (props) => { const handleBatchDelete = async () => { try { await delAnnotations(appDetail.id, selectedIds) - Toast.notify({ message: t('api.actionSuccess', { ns: 'common' }), type: 'success' }) + toast.success(t('api.actionSuccess', { ns: 'common' })) fetchList() setControlUpdateList(Date.now()) setSelectedIds([]) } catch (e: any) { - Toast.notify({ type: 'error', message: e.message || t('api.actionFailed', { ns: 'common' }) }) + toast.error(e.message || t('api.actionFailed', { ns: 'common' })) } } @@ -132,7 +132,7 @@ const Annotation: FC = (props) => { if (!currItem) return await editAnnotation(appDetail.id, currItem.id, { question, answer }) - Toast.notify({ message: t('api.actionSuccess', { ns: 'common' }), type: 'success' }) + toast.success(t('api.actionSuccess', { ns: 'common' })) fetchList() setControlUpdateList(Date.now()) } @@ -170,10 +170,7 @@ const Annotation: FC = (props) => { const { job_id: jobId }: any = await updateAnnotationStatus(appDetail.id, AnnotationEnableStatus.disable, annotationConfig?.embedding_model, annotationConfig?.score_threshold) await ensureJobCompleted(jobId, AnnotationEnableStatus.disable) await fetchAnnotationConfig() - Toast.notify({ - message: t('api.actionSuccess', { ns: 'common' }), - type: 'success', - }) + toast.success(t('api.actionSuccess', { ns: 'common' })) } }} > @@ -263,10 +260,7 @@ const Annotation: FC = (props) => { await updateAnnotationScore(appDetail.id, annotationId, score) await fetchAnnotationConfig() - Toast.notify({ - message: t('api.actionSuccess', { ns: 'common' }), - type: 'success', - }) + toast.success(t('api.actionSuccess', { ns: 'common' })) setIsShowEdit(false) }} annotationConfig={annotationConfig!} diff --git a/web/app/components/app/app-access-control/access-control.spec.tsx b/web/app/components/app/app-access-control/access-control.spec.tsx index 3a5f2272edd..7411676586e 100644 --- a/web/app/components/app/app-access-control/access-control.spec.tsx +++ b/web/app/components/app/app-access-control/access-control.spec.tsx @@ -2,9 +2,9 @@ import type { AccessControlAccount, AccessControlGroup, Subject } from '@/models import type { App } from '@/types/app' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' +import { toast } from '@/app/components/base/ui/toast' import useAccessControlStore from '@/context/access-control-store' import { AccessMode, SubjectType } from '@/models/access-control' -import Toast from '../../base/toast' import AccessControlDialog from './access-control-dialog' import AccessControlItem from './access-control-item' import AddMemberOrGroupDialog from './add-member-or-group-pop' @@ -303,7 +303,7 @@ describe('AccessControl', () => { it('should initialize menu from app and call update on confirm', async () => { const onClose = vi.fn() const onConfirm = vi.fn() - const toastSpy = vi.spyOn(Toast, 'notify').mockReturnValue({}) + const toastSpy = vi.spyOn(toast, 'success').mockReturnValue('toast-success') useAccessControlStore.setState({ specificGroups: [baseGroup], specificMembers: [baseMember], @@ -336,7 +336,7 @@ describe('AccessControl', () => { { subjectId: baseMember.id, subjectType: SubjectType.ACCOUNT }, ], }) - expect(toastSpy).toHaveBeenCalled() + expect(toastSpy).toHaveBeenCalledWith('app.accessControlDialog.updateSuccess') expect(onConfirm).toHaveBeenCalled() }) }) diff --git a/web/app/components/app/app-access-control/index.tsx b/web/app/components/app/app-access-control/index.tsx index 8d46e41a119..0c1c64eadc2 100644 --- a/web/app/components/app/app-access-control/index.tsx +++ b/web/app/components/app/app-access-control/index.tsx @@ -5,12 +5,12 @@ import { Description as DialogDescription, DialogTitle } from '@headlessui/react import { RiBuildingLine, RiGlobalLine, RiVerifiedBadgeLine } from '@remixicon/react' import { useCallback, useEffect } from 'react' import { useTranslation } from 'react-i18next' +import { toast } from '@/app/components/base/ui/toast' import { useGlobalPublicStore } from '@/context/global-public-context' import { AccessMode, SubjectType } from '@/models/access-control' import { useUpdateAccessMode } from '@/service/access-control' import useAccessControlStore from '../../../../context/access-control-store' import Button from '../../base/button' -import Toast from '../../base/toast' import AccessControlDialog from './access-control-dialog' import AccessControlItem from './access-control-item' import SpecificGroupsOrMembers, { WebAppSSONotEnabledTip } from './specific-groups-or-members' @@ -61,7 +61,7 @@ export default function AccessControl(props: AccessControlProps) { submitData.subjects = subjects } await updateAccessMode(submitData) - Toast.notify({ type: 'success', message: t('accessControlDialog.updateSuccess', { ns: 'app' }) }) + toast.success(t('accessControlDialog.updateSuccess', { ns: 'app' })) onConfirm?.() }, [updateAccessMode, app, specificGroups, specificMembers, t, onConfirm, currentMenu]) return ( diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 74d6a19cc1e..649b225b231 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -35,8 +35,8 @@ import { AppModeEnum } from '@/types/app' import { basePath } from '@/utils/var' import Divider from '../../base/divider' import Loading from '../../base/loading' -import Toast from '../../base/toast' import Tooltip from '../../base/tooltip' +import { toast } from '../../base/ui/toast' import ShortcutsName from '../../workflow/shortcuts-name' import { getKeyboardKeyCodeBySystem } from '../../workflow/utils' import AccessControl from '../app-access-control' @@ -219,7 +219,7 @@ const AppPublisher = ({ throw new Error('No app found in Explore') }, { onError: (err) => { - Toast.notify({ type: 'error', message: `${err.message || err}` }) + toast.error(`${err.message || err}`) }, }) }, [appDetail?.id, openAsyncWindow]) diff --git a/web/app/components/app/app-publisher/version-info-modal.tsx b/web/app/components/app/app-publisher/version-info-modal.tsx index ee896cf5837..a1d6edcf04d 100644 --- a/web/app/components/app/app-publisher/version-info-modal.tsx +++ b/web/app/components/app/app-publisher/version-info-modal.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Modal from '@/app/components/base/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import Button from '../../base/button' import Input from '../../base/input' import Textarea from '../../base/textarea' @@ -35,10 +35,7 @@ const VersionInfoModal: FC = ({ const handlePublish = () => { if (title.length > TITLE_MAX_LENGTH) { setTitleError(true) - Toast.notify({ - type: 'error', - message: t('versionHistory.editField.titleLengthLimit', { ns: 'workflow', limit: TITLE_MAX_LENGTH }), - }) + toast.error(t('versionHistory.editField.titleLengthLimit', { ns: 'workflow', limit: TITLE_MAX_LENGTH })) return } else { @@ -48,10 +45,7 @@ const VersionInfoModal: FC = ({ if (releaseNotes.length > RELEASE_NOTES_MAX_LENGTH) { setReleaseNotesError(true) - Toast.notify({ - type: 'error', - message: t('versionHistory.editField.releaseNotesLengthLimit', { ns: 'workflow', limit: RELEASE_NOTES_MAX_LENGTH }), - }) + toast.error(t('versionHistory.editField.releaseNotesLengthLimit', { ns: 'workflow', limit: RELEASE_NOTES_MAX_LENGTH })) return } else { diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index 9625204d814..482f61bb82c 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -20,8 +20,12 @@ import { } from '@/app/components/base/icons/src/vender/line/files' import PromptEditor from '@/app/components/base/prompt-editor' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' -import { useToastContext } from '@/app/components/base/toast/context' -import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from '@/app/components/base/ui/tooltip' import ConfigContext from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' import { useModalContext } from '@/context/modal-context' @@ -74,7 +78,6 @@ const AdvancedPromptInput: FC = ({ showSelectDataSet, externalDataToolsConfig, } = useContext(ConfigContext) - const { notify } = useToastContext() const { setShowExternalDataToolModal } = useModalContext() const handleOpenExternalDataToolModal = () => { setShowExternalDataToolModal({ @@ -94,7 +97,7 @@ const AdvancedPromptInput: FC = ({ onValidateBeforeSaveCallback: (newExternalDataTool: ExternalDataTool) => { for (let i = 0; i < promptVariables.length; i++) { if (promptVariables[i].key === newExternalDataTool.variable) { - notify({ type: 'error', message: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key }) }) + toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key })) return false } } @@ -180,13 +183,18 @@ const AdvancedPromptInput: FC = ({
{t('pageTitle.line1', { ns: 'appDebug' })}
- + + )} + /> +
{t('promptTip', { ns: 'appDebug' })}
- )} - /> +
+
)}
diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index 39a16990636..bc54e0f16dd 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -17,8 +17,12 @@ import { useFeaturesStore } from '@/app/components/base/features/hooks' import PromptEditor from '@/app/components/base/prompt-editor' import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' -import { useToastContext } from '@/app/components/base/toast/context' -import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from '@/app/components/base/ui/tooltip' import ConfigContext from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' import { useModalContext } from '@/context/modal-context' @@ -72,7 +76,6 @@ const Prompt: FC = ({ showSelectDataSet, externalDataToolsConfig, } = useContext(ConfigContext) - const { notify } = useToastContext() const { setShowExternalDataToolModal } = useModalContext() const handleOpenExternalDataToolModal = () => { setShowExternalDataToolModal({ @@ -92,7 +95,7 @@ const Prompt: FC = ({ onValidateBeforeSaveCallback: (newExternalDataTool: ExternalDataTool) => { for (let i = 0; i < promptVariables.length; i++) { if (promptVariables[i].key === newExternalDataTool.variable) { - notify({ type: 'error', message: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key }) }) + toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key })) return false } } @@ -180,13 +183,18 @@ const Prompt: FC = ({
{mode !== AppModeEnum.COMPLETION ? t('chatSubTitle', { ns: 'appDebug' }) : t('completionSubTitle', { ns: 'appDebug' })}
{!readonly && ( - + + )} + /> +
{t('promptTip', { ns: 'appDebug' })}
- )} - /> +
+
)}
diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index 7ea784baa32..b864206b26d 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -15,7 +15,7 @@ import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' import { SimpleSelect } from '@/app/components/base/select' import Textarea from '@/app/components/base/textarea' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { DEFAULT_FILE_UPLOAD_SETTING } from '@/app/components/workflow/constants' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' import FileUploadSetting from '@/app/components/workflow/nodes/_base/components/file-upload-setting' @@ -98,10 +98,7 @@ const ConfigModal: FC = ({ const checkVariableName = useCallback((value: string, canBeEmpty?: boolean) => { const { isValid, errorMessageKey } = checkKeys([value], canBeEmpty) if (!isValid) { - Toast.notify({ - type: 'error', - message: t(`varKeyError.${errorMessageKey}`, { ns: 'appDebug', key: t('variableConfig.varName', { ns: 'appDebug' }) }), - }) + toast.error(t(`varKeyError.${errorMessageKey}`, { ns: 'appDebug', key: t('variableConfig.varName', { ns: 'appDebug' }) })) return false } return true @@ -219,10 +216,7 @@ const ConfigModal: FC = ({ const value = e.target.value const { isValid, errorKey, errorMessageKey } = checkKeys([value], true) if (!isValid) { - Toast.notify({ - type: 'error', - message: t(`varKeyError.${errorMessageKey}`, { ns: 'appDebug', key: errorKey }), - }) + toast.error(t(`varKeyError.${errorMessageKey}`, { ns: 'appDebug', key: errorKey })) return } handlePayloadChange('variable')(e.target.value) @@ -264,7 +258,7 @@ const ConfigModal: FC = ({ return if (!tempPayload.label) { - Toast.notify({ type: 'error', message: t('variableConfig.errorMsg.labelNameRequired', { ns: 'appDebug' }) }) + toast.error(t('variableConfig.errorMsg.labelNameRequired', { ns: 'appDebug' })) return } if (isStringInput || type === InputVarType.number) { @@ -272,7 +266,7 @@ const ConfigModal: FC = ({ } else if (type === InputVarType.select) { if (options?.length === 0) { - Toast.notify({ type: 'error', message: t('variableConfig.errorMsg.atLeastOneOption', { ns: 'appDebug' }) }) + toast.error(t('variableConfig.errorMsg.atLeastOneOption', { ns: 'appDebug' })) return } const obj: Record = {} @@ -285,7 +279,7 @@ const ConfigModal: FC = ({ obj[o] = true }) if (hasRepeatedItem) { - Toast.notify({ type: 'error', message: t('variableConfig.errorMsg.optionRepeat', { ns: 'appDebug' }) }) + toast.error(t('variableConfig.errorMsg.optionRepeat', { ns: 'appDebug' })) return } onConfirm(payloadToSave, moreInfo) @@ -293,12 +287,12 @@ const ConfigModal: FC = ({ else if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) { if (tempPayload.allowed_file_types?.length === 0) { const errorMessages = t('errorMsg.fieldRequired', { ns: 'workflow', field: t('variableConfig.file.supportFileTypes', { ns: 'appDebug' }) }) - Toast.notify({ type: 'error', message: errorMessages }) + toast.error(errorMessages) return } if (tempPayload.allowed_file_types?.includes(SupportUploadFileTypes.custom) && !tempPayload.allowed_file_extensions?.length) { const errorMessages = t('errorMsg.fieldRequired', { ns: 'workflow', field: t('variableConfig.file.custom.name', { ns: 'appDebug' }) }) - Toast.notify({ type: 'error', message: errorMessages }) + toast.error(errorMessages) return } onConfirm(payloadToSave, moreInfo) @@ -308,12 +302,12 @@ const ConfigModal: FC = ({ try { const schema = JSON.parse(normalizedJsonSchema) if (schema?.type !== 'object') { - Toast.notify({ type: 'error', message: t('variableConfig.errorMsg.jsonSchemaMustBeObject', { ns: 'appDebug' }) }) + toast.error(t('variableConfig.errorMsg.jsonSchemaMustBeObject', { ns: 'appDebug' })) return } } catch { - Toast.notify({ type: 'error', message: t('variableConfig.errorMsg.jsonSchemaInvalid', { ns: 'appDebug' }) }) + toast.error(t('variableConfig.errorMsg.jsonSchemaInvalid', { ns: 'appDebug' })) return } } diff --git a/web/app/components/app/configuration/config-var/index.spec.tsx b/web/app/components/app/configuration/config-var/index.spec.tsx index 096358c8058..a48d3233f5b 100644 --- a/web/app/components/app/configuration/config-var/index.spec.tsx +++ b/web/app/components/app/configuration/config-var/index.spec.tsx @@ -5,13 +5,13 @@ import type { PromptVariable } from '@/models/debug' import { act, fireEvent, render, screen, waitFor, within } from '@testing-library/react' import * as React from 'react' import { vi } from 'vitest' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import DebugConfigurationContext from '@/context/debug-configuration' import { AppModeEnum } from '@/types/app' import ConfigVar, { ADD_EXTERNAL_DATA_TOOL } from './index' -const notifySpy = vi.spyOn(Toast, 'notify').mockImplementation(vi.fn()) +const toastErrorSpy = vi.spyOn(toast, 'error').mockReturnValue('toast-error') const setShowExternalDataToolModal = vi.fn() @@ -112,7 +112,7 @@ describe('ConfigVar', () => { latestSortableProps = null subscriptionCallback = null variableIndex = 0 - notifySpy.mockClear() + toastErrorSpy.mockClear() }) it('should show empty state when no variables exist', () => { @@ -152,7 +152,7 @@ describe('ConfigVar', () => { latestSortableProps = null subscriptionCallback = null variableIndex = 0 - notifySpy.mockClear() + toastErrorSpy.mockClear() }) it('should add a text variable when selecting the string option', async () => { @@ -218,7 +218,7 @@ describe('ConfigVar', () => { latestSortableProps = null subscriptionCallback = null variableIndex = 0 - notifySpy.mockClear() + toastErrorSpy.mockClear() }) it('should save updates when editing a basic variable', async () => { @@ -268,7 +268,7 @@ describe('ConfigVar', () => { fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) - expect(Toast.notify).toHaveBeenCalled() + expect(toastErrorSpy).toHaveBeenCalled() expect(onPromptVariablesChange).not.toHaveBeenCalled() }) @@ -294,7 +294,7 @@ describe('ConfigVar', () => { fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) - expect(Toast.notify).toHaveBeenCalled() + expect(toastErrorSpy).toHaveBeenCalled() expect(onPromptVariablesChange).not.toHaveBeenCalled() }) }) @@ -306,7 +306,7 @@ describe('ConfigVar', () => { latestSortableProps = null subscriptionCallback = null variableIndex = 0 - notifySpy.mockClear() + toastErrorSpy.mockClear() }) it('should remove variable directly when context confirmation is not required', () => { @@ -359,7 +359,7 @@ describe('ConfigVar', () => { latestSortableProps = null subscriptionCallback = null variableIndex = 0 - notifySpy.mockClear() + toastErrorSpy.mockClear() }) it('should append external data tool variables from event emitter', () => { diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index 4d9a4e480fc..17f5e2efe53 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -12,8 +12,8 @@ import { useTranslation } from 'react-i18next' import { ReactSortable } from 'react-sortablejs' import { useContext } from 'use-context-selector' import Confirm from '@/app/components/base/confirm' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import { InputVarType } from '@/app/components/workflow/types' import ConfigContext from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' @@ -108,10 +108,7 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar }) const duplicateError = getDuplicateError(newPromptVariables) if (duplicateError) { - Toast.notify({ - type: 'error', - message: t(duplicateError.errorMsgKey as I18nKeysByPrefix<'appDebug', 'duplicateError.'>, { ns: 'appDebug', key: t(duplicateError.typeName as I18nKeysByPrefix<'appDebug', 'duplicateError.'>, { ns: 'appDebug' }) }) as string, - }) + toast.error(t(duplicateError.errorMsgKey as I18nKeysByPrefix<'appDebug', 'duplicateError.'>, { ns: 'appDebug', key: t(duplicateError.typeName as I18nKeysByPrefix<'appDebug', 'duplicateError.'>, { ns: 'appDebug' }) }) as string) return false } @@ -161,7 +158,7 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar onValidateBeforeSaveCallback: (newExternalDataTool: ExternalDataTool) => { for (let i = 0; i < promptVariables.length; i++) { if (promptVariables[i].key === newExternalDataTool.variable && i !== index) { - Toast.notify({ type: 'error', message: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key }) }) + toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key })) return false } } diff --git a/web/app/components/app/configuration/config/agent/prompt-editor.tsx b/web/app/components/app/configuration/config/agent/prompt-editor.tsx index f719d872613..e807c21518e 100644 --- a/web/app/components/app/configuration/config/agent/prompt-editor.tsx +++ b/web/app/components/app/configuration/config/agent/prompt-editor.tsx @@ -12,7 +12,7 @@ import { CopyCheck, } from '@/app/components/base/icons/src/vender/line/files' import PromptEditor from '@/app/components/base/prompt-editor' -import { useToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import ConfigContext from '@/context/debug-configuration' import { useModalContext } from '@/context/modal-context' import { cn } from '@/utils/classnames' @@ -32,8 +32,6 @@ const Editor: FC = ({ }) => { const { t } = useTranslation() - const { notify } = useToastContext() - const [isCopied, setIsCopied] = React.useState(false) const { modelConfig, @@ -59,14 +57,14 @@ const Editor: FC = ({ onValidateBeforeSaveCallback: (newExternalDataTool: ExternalDataTool) => { for (let i = 0; i < promptVariables.length; i++) { if (promptVariables[i].key === newExternalDataTool.variable) { - notify({ type: 'error', message: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key }) }) + toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key })) return false } } for (let i = 0; i < externalDataToolsConfig.length; i++) { if (externalDataToolsConfig[i].variable === newExternalDataTool.variable) { - notify({ type: 'error', message: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: externalDataToolsConfig[i].variable }) }) + toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: externalDataToolsConfig[i].variable })) return false } } diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index 8ad284bcfb8..6c135fdee35 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -23,9 +23,9 @@ import Button from '@/app/components/base/button' import Confirm from '@/app/components/base/confirm' import { Generator } from '@/app/components/base/icons/src/vender/other' import Loading from '@/app/components/base/loading' - import Modal from '@/app/components/base/modal' -import Toast from '@/app/components/base/toast' + +import { toast } from '@/app/components/base/ui/toast' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' @@ -159,13 +159,10 @@ const GetAutomaticRes: FC = ({ const isValid = () => { if (instruction.trim() === '') { - Toast.notify({ - type: 'error', - message: t('errorMsg.fieldRequired', { - ns: 'common', - field: t('generate.instruction', { ns: 'appDebug' }), - }), - }) + toast.error(t('errorMsg.fieldRequired', { + ns: 'common', + field: t('generate.instruction', { ns: 'appDebug' }), + })) return false } return true @@ -242,10 +239,7 @@ const GetAutomaticRes: FC = ({ } as GenRes if (error) { hasError = true - Toast.notify({ - type: 'error', - message: error, - }) + toast.error(error) } } else { @@ -260,10 +254,7 @@ const GetAutomaticRes: FC = ({ apiRes = res if (error) { hasError = true - Toast.notify({ - type: 'error', - message: error, - }) + toast.error(error) } } if (!hasError) diff --git a/web/app/components/app/configuration/config/automatic/result.tsx b/web/app/components/app/configuration/config/automatic/result.tsx index ef82007e515..776d774bd8e 100644 --- a/web/app/components/app/configuration/config/automatic/result.tsx +++ b/web/app/components/app/configuration/config/automatic/result.tsx @@ -6,7 +6,7 @@ import copy from 'copy-to-clipboard' import * as React from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import CodeEditor from '@/app/components/workflow/nodes/llm/components/json-schema-config-modal/code-editor' import PromptRes from './prompt-res' import PromptResInWorkflow from './prompt-res-in-workflow' @@ -54,7 +54,7 @@ const Result: FC = ({ className="px-2" onClick={() => { copy(current.modified) - Toast.notify({ type: 'success', message: t('actionMsg.copySuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.copySuccessfully', { ns: 'common' })) }} > diff --git a/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx index a7bc7ab97b7..6bdb59fa173 100644 --- a/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx +++ b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx @@ -15,7 +15,7 @@ import Confirm from '@/app/components/base/confirm' import { Generator } from '@/app/components/base/icons/src/vender/other' import Loading from '@/app/components/base/loading' import Modal from '@/app/components/base/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' @@ -97,13 +97,10 @@ export const GetCodeGeneratorResModal: FC = ( const isValid = () => { if (instruction.trim() === '') { - Toast.notify({ - type: 'error', - message: t('errorMsg.fieldRequired', { - ns: 'common', - field: t('code.instruction', { ns: 'appDebug' }), - }), - }) + toast.error(t('errorMsg.fieldRequired', { + ns: 'common', + field: t('code.instruction', { ns: 'appDebug' }), + })) return false } return true @@ -149,10 +146,7 @@ export const GetCodeGeneratorResModal: FC = ( res.modified = (res as any).code if (error) { - Toast.notify({ - type: 'error', - message: error, - }) + toast.error(error) } else { addVersion(res) diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx index 2cd8418c656..8a53e9a8b0f 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx @@ -5,7 +5,7 @@ import type { DatasetConfigs } from '@/models/debug' import type { RetrievalConfig } from '@/types/app' import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel, @@ -46,7 +46,7 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () const mockedUseModelListAndDefaultModelAndCurrentProviderAndModel = useModelListAndDefaultModelAndCurrentProviderAndModel as MockedFunction const mockedUseCurrentProviderAndModel = useCurrentProviderAndModel as MockedFunction -let toastNotifySpy: MockInstance +let toastErrorSpy: MockInstance const baseRetrievalConfig: RetrievalConfig = { search_method: RETRIEVE_METHOD.semantic, @@ -172,7 +172,7 @@ const createDatasetConfigs = (overrides: Partial = {}): DatasetC describe('ConfigContent', () => { beforeEach(() => { vi.clearAllMocks() - toastNotifySpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({})) + toastErrorSpy = vi.spyOn(toast, 'error').mockReturnValue('toast-error') mockedUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ modelList: [], defaultModel: undefined, @@ -186,7 +186,7 @@ describe('ConfigContent', () => { }) afterEach(() => { - toastNotifySpy.mockRestore() + toastErrorSpy.mockRestore() }) // State management @@ -331,10 +331,7 @@ describe('ConfigContent', () => { await user.click(screen.getByText('common.modelProvider.rerankModel.key')) // Assert - expect(toastNotifySpy).toHaveBeenCalledWith({ - type: 'error', - message: 'workflow.errorMsg.rerankModelRequired', - }) + expect(toastErrorSpy).toHaveBeenCalledWith('workflow.errorMsg.rerankModelRequired') expect(onChange).toHaveBeenCalledWith( expect.objectContaining({ reranking_mode: RerankingModeEnum.RerankingModel, @@ -373,10 +370,7 @@ describe('ConfigContent', () => { await user.click(screen.getByRole('switch')) // Assert - expect(toastNotifySpy).toHaveBeenCalledWith({ - type: 'error', - message: 'workflow.errorMsg.rerankModelRequired', - }) + expect(toastErrorSpy).toHaveBeenCalledWith('workflow.errorMsg.rerankModelRequired') expect(onChange).toHaveBeenCalledWith( expect.objectContaining({ reranking_enable: true, diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 6dd03d217e0..be0d1d9394c 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -15,8 +15,8 @@ import Divider from '@/app/components/base/divider' import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item' import TopKItem from '@/app/components/base/param-item/top-k-item' import Switch from '@/app/components/base/switch' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' @@ -136,7 +136,7 @@ const ConfigContent: FC = ({ return if (mode === RerankingModeEnum.RerankingModel && !currentRerankModel) - Toast.notify({ type: 'error', message: t('errorMsg.rerankModelRequired', { ns: 'workflow' }) }) + toast.error(t('errorMsg.rerankModelRequired', { ns: 'workflow' })) onChange({ ...datasetConfigs, @@ -179,7 +179,7 @@ const ConfigContent: FC = ({ const handleManuallyToggleRerank = useCallback((enable: boolean) => { if (!currentRerankModel && enable) - Toast.notify({ type: 'error', message: t('errorMsg.rerankModelRequired', { ns: 'workflow' }) }) + toast.error(t('errorMsg.rerankModelRequired', { ns: 'workflow' })) onChange({ ...datasetConfigs, reranking_enable: enable, diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx index ea70725ea81..7fdf9d0a23d 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx @@ -3,7 +3,6 @@ import type { DataSet } from '@/models/datasets' import type { RetrievalConfig } from '@/types/app' import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' -import { ToastContext } from '@/app/components/base/toast/context' import { IndexingType } from '@/app/components/datasets/create/step-two' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' @@ -13,7 +12,24 @@ import { useMembers } from '@/service/use-common' import { RETRIEVE_METHOD } from '@/types/app' import SettingsModal from './index' -const mockNotify = vi.fn() +const toastMocks = vi.hoisted(() => ({ + call: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign(toastMocks.call, { + success: vi.fn((message: string, options?: Record) => toastMocks.call({ type: 'success', message, ...options })), + error: vi.fn((message: string, options?: Record) => toastMocks.call({ type: 'error', message, ...options })), + warning: vi.fn((message: string, options?: Record) => toastMocks.call({ type: 'warning', message, ...options })), + info: vi.fn((message: string, options?: Record) => toastMocks.call({ type: 'info', message, ...options })), + dismiss: toastMocks.dismiss, + update: toastMocks.update, + promise: toastMocks.promise, + }), +})) const mockOnCancel = vi.fn() const mockOnSave = vi.fn() const mockSetShowAccountSettingModal = vi.fn() @@ -183,13 +199,12 @@ const createDataset = (overrides: Partial = {}, retrievalOverrides: Par const renderWithProviders = (dataset: DataSet) => { return render( - - - , + , + ) } @@ -378,7 +393,7 @@ describe('SettingsModal', () => { await user.click(screen.getByRole('button', { name: 'common.operation.save' })) // Assert - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + expect(toastMocks.call).toHaveBeenCalledWith(expect.objectContaining({ type: 'error', message: 'datasetSettings.form.nameError', })) @@ -402,7 +417,7 @@ describe('SettingsModal', () => { await user.click(screen.getByRole('button', { name: 'common.operation.save' })) // Assert - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + expect(toastMocks.call).toHaveBeenCalledWith(expect.objectContaining({ type: 'error', message: 'appDebug.datasetConfig.rerankModelRequired', })) @@ -444,7 +459,7 @@ describe('SettingsModal', () => { permission: DatasetPermission.allTeamMembers, }), })) - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + expect(toastMocks.call).toHaveBeenCalledWith(expect.objectContaining({ type: 'success', message: 'common.actionMsg.modifiedSuccessfully', })) @@ -528,7 +543,7 @@ describe('SettingsModal', () => { // Assert await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' })) + expect(toastMocks.call).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' })) }) }) }) diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index bc534599dec..8b2c4270cdc 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' -import { useToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' import { IndexingType } from '@/app/components/datasets/create/step-two' import IndexMethod from '@/app/components/datasets/settings/index-method' @@ -51,7 +51,6 @@ const SettingsModal: FC = ({ const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank) const { t } = useTranslation() const docLink = useDocLink() - const { notify } = useToastContext() const ref = useRef(null) const isExternal = currentDataset.provider === 'external' const { setShowAccountSettingModal } = useModalContext() @@ -96,7 +95,7 @@ const SettingsModal: FC = ({ if (loading) return if (!localeCurrentDataset.name?.trim()) { - notify({ type: 'error', message: t('form.nameError', { ns: 'datasetSettings' }) }) + toast.error(t('form.nameError', { ns: 'datasetSettings' })) return } if ( @@ -106,7 +105,7 @@ const SettingsModal: FC = ({ indexMethod, }) ) { - notify({ type: 'error', message: t('datasetConfig.rerankModelRequired', { ns: 'appDebug' }) }) + toast.error(t('datasetConfig.rerankModelRequired', { ns: 'appDebug' })) return } try { @@ -146,7 +145,7 @@ const SettingsModal: FC = ({ }) } await updateDatasetSetting(requestParams) - notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) onSave({ ...localeCurrentDataset, indexing_technique: indexMethod, @@ -154,7 +153,7 @@ const SettingsModal: FC = ({ }) } catch { - notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) } finally { setLoading(false) diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx index a75516a43ff..910a8fd2b58 100644 --- a/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx @@ -386,13 +386,6 @@ vi.mock('@/context/event-emitter', () => ({ })), })) -// Mock toast context -vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: vi.fn(() => ({ - notify: vi.fn(), - })), -})) - // Mock hooks/use-timestamp vi.mock('@/hooks/use-timestamp', () => ({ default: vi.fn(() => ({ diff --git a/web/app/components/app/configuration/debug/index.spec.tsx b/web/app/components/app/configuration/debug/index.spec.tsx index e94695f1ef1..61fe6730795 100644 --- a/web/app/components/app/configuration/debug/index.spec.tsx +++ b/web/app/components/app/configuration/debug/index.spec.tsx @@ -1,7 +1,6 @@ import type { ComponentProps } from 'react' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' -import { ToastContext } from '@/app/components/base/toast/context' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import ConfigContext from '@/context/debug-configuration' import { AppModeEnum, ModelModeType, TransferMethod } from '@/types/app' @@ -16,6 +15,10 @@ const mockState = vi.hoisted(() => ({ mockHandleRestart: vi.fn(), mockSetFeatures: vi.fn(), mockEventEmitterEmit: vi.fn(), + mockToastCall: vi.fn(), + mockToastDismiss: vi.fn(), + mockToastUpdate: vi.fn(), + mockToastPromise: vi.fn(), mockText2speechDefaultModel: null as unknown, mockStoreState: { currentLogItem: null as unknown, @@ -43,6 +46,22 @@ const mockState = vi.hoisted(() => ({ }, })) +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign(mockState.mockToastCall, { + success: vi.fn((message: string, options?: Record) => + mockState.mockToastCall({ type: 'success', message, ...options })), + error: vi.fn((message: string, options?: Record) => + mockState.mockToastCall({ type: 'error', message, ...options })), + warning: vi.fn((message: string, options?: Record) => + mockState.mockToastCall({ type: 'warning', message, ...options })), + info: vi.fn((message: string, options?: Record) => + mockState.mockToastCall({ type: 'info', message, ...options })), + dismiss: mockState.mockToastDismiss, + update: mockState.mockToastUpdate, + promise: mockState.mockToastPromise, + }), +})) + vi.mock('@/app/components/app/configuration/debug/chat-user-input', () => ({ default: () =>
ChatUserInput
, })) @@ -215,19 +234,27 @@ vi.mock('./debug-with-multiple-model', () => ({ ), })) -vi.mock('./debug-with-single-model', () => ({ - default: React.forwardRef((props: { checkCanSend: () => boolean }, ref) => { +vi.mock('./debug-with-single-model', () => { + function DebugWithSingleModelMock({ + checkCanSend, + ref, + }: { + checkCanSend: () => boolean + ref?: React.Ref<{ handleRestart: () => void }> + }) { React.useImperativeHandle(ref, () => ({ handleRestart: mockState.mockHandleRestart, })) return (
- +
) - }), -})) + } + + return { default: DebugWithSingleModelMock } +}) const createContextValue = (overrides: Partial = {}): DebugContextValue => ({ readonly: false, @@ -376,7 +403,6 @@ const renderDebug = (options: { props?: Partial } = {}) => { const onSetting = vi.fn() - const notify = vi.fn() const props: ComponentProps = { isAPIKeySet: true, onSetting, @@ -392,14 +418,16 @@ const renderDebug = (options: { } render( - - - - - , + React.createElement( + ConfigContext.Provider, + { + value: createContextValue(options.contextValue), + children: , + }, + ), ) - return { onSetting, notify, props } + return { onSetting, notify: mockState.mockToastCall, props } } describe('Debug', () => { diff --git a/web/app/components/app/configuration/debug/index.tsx b/web/app/components/app/configuration/debug/index.tsx index cd07885f0cb..36cd4c34454 100644 --- a/web/app/components/app/configuration/debug/index.tsx +++ b/web/app/components/app/configuration/debug/index.tsx @@ -29,8 +29,12 @@ import Button from '@/app/components/base/button' import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' import { RefreshCcw01 } from '@/app/components/base/icons/src/vender/line/arrows' import PromptLogModal from '@/app/components/base/prompt-log-modal' -import { ToastContext } from '@/app/components/base/toast/context' -import TooltipPlus from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from '@/app/components/base/ui/tooltip' import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config' @@ -139,22 +143,20 @@ const Debug: FC = ({ setIsShowFormattingChangeConfirm(false) setFormattingChanged(false) } - - const { notify } = useContext(ToastContext) const logError = useCallback((message: string) => { - notify({ type: 'error', message }) - }, [notify]) + toast.error(message) + }, []) const [completionFiles, setCompletionFiles] = useState([]) const checkCanSend = useCallback(() => { if (isAdvancedMode && mode !== AppModeEnum.COMPLETION) { if (modelModeType === ModelModeType.completion) { if (!hasSetBlockStatus.history) { - notify({ type: 'error', message: t('otherError.historyNoBeEmpty', { ns: 'appDebug' }) }) + toast.error(t('otherError.historyNoBeEmpty', { ns: 'appDebug' })) return false } if (!hasSetBlockStatus.query) { - notify({ type: 'error', message: t('otherError.queryNoBeEmpty', { ns: 'appDebug' }) }) + toast.error(t('otherError.queryNoBeEmpty', { ns: 'appDebug' })) return false } } @@ -180,7 +182,7 @@ const Debug: FC = ({ } if (completionFiles.find(item => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)) { - notify({ type: 'info', message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }) }) + toast.info(t('errorMessage.waitForFileUpload', { ns: 'appDebug' })) return false } return !hasEmptyInput @@ -194,7 +196,6 @@ const Debug: FC = ({ modelConfig.configs.prompt_variables, t, logError, - notify, modelModeType, ]) @@ -205,7 +206,7 @@ const Debug: FC = ({ const sendTextCompletion = async () => { if (isResponding) { - notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) }) + toast.info(t('errorMessage.waitForResponse', { ns: 'appDebug' })) return false } @@ -420,27 +421,24 @@ const Debug: FC = ({ <> { !readonly && ( - - - - - - + + } /> + + {t('operation.refresh', { ns: 'common' })} + + ) } { varList.length > 0 && (
- - !readonly && setExpanded(!expanded)}> - - - + + !readonly && setExpanded(!expanded)}>} /> + + {t('panel.userInputField', { ns: 'workflow' })} + + {expanded &&
}
) diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index aa1bbe0a163..08df556d8e5 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -26,7 +26,6 @@ import { produce } from 'immer' import * as React from 'react' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import { useShallow } from 'zustand/react/shallow' import AppPublisher from '@/app/components/app/app-publisher/features-wrapper' import Config from '@/app/components/app/configuration/config' @@ -48,8 +47,7 @@ import { FeaturesProvider } from '@/app/components/base/features' import NewFeaturePanel from '@/app/components/base/features/new-feature-panel' import Loading from '@/app/components/base/loading' import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' -import Toast from '@/app/components/base/toast' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { @@ -93,7 +91,6 @@ type PublishConfig = { const Configuration: FC = () => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) const { isLoadingCurrentWorkspace, currentWorkspace } = useAppContext() const { appDetail, showAppConfigureFeaturesModal, setAppSidebarExpand, setShowAppConfigureFeaturesModal } = useAppStore(useShallow(state => ({ @@ -492,11 +489,11 @@ const Configuration: FC = () => { isAdvancedMode, ) if (Object.keys(removedDetails).length) - Toast.notify({ type: 'warning', message: `${t('modelProvider.parametersInvalidRemoved', { ns: 'common' })}: ${Object.entries(removedDetails).map(([k, reason]) => `${k} (${reason})`).join(', ')}` }) + toast.warning(`${t('modelProvider.parametersInvalidRemoved', { ns: 'common' })}: ${Object.entries(removedDetails).map(([k, reason]) => `${k} (${reason})`).join(', ')}`) setCompletionParams(filtered) } catch { - Toast.notify({ type: 'error', message: t('error', { ns: 'common' }) }) + toast.error(t('error', { ns: 'common' })) setCompletionParams({}) } } @@ -767,23 +764,23 @@ const Configuration: FC = () => { const promptVariables = modelConfig.configs.prompt_variables if (promptEmpty) { - notify({ type: 'error', message: t('otherError.promptNoBeEmpty', { ns: 'appDebug' }) }) + toast.error(t('otherError.promptNoBeEmpty', { ns: 'appDebug' })) return } if (isAdvancedMode && mode !== AppModeEnum.COMPLETION) { if (modelModeType === ModelModeType.completion) { if (!hasSetBlockStatus.history) { - notify({ type: 'error', message: t('otherError.historyNoBeEmpty', { ns: 'appDebug' }) }) + toast.error(t('otherError.historyNoBeEmpty', { ns: 'appDebug' })) return } if (!hasSetBlockStatus.query) { - notify({ type: 'error', message: t('otherError.queryNoBeEmpty', { ns: 'appDebug' }) }) + toast.error(t('otherError.queryNoBeEmpty', { ns: 'appDebug' })) return } } } if (contextVarEmpty) { - notify({ type: 'error', message: t('feature.dataSet.queryVariable.contextVarNotEmpty', { ns: 'appDebug' }) }) + toast.error(t('feature.dataSet.queryVariable.contextVarNotEmpty', { ns: 'appDebug' })) return } const postDatasets = dataSets.map(({ id }) => ({ @@ -849,7 +846,7 @@ const Configuration: FC = () => { modelConfig: newModelConfig, completionParams, }) - notify({ type: 'success', message: t('api.success', { ns: 'common' }) }) + toast.success(t('api.success', { ns: 'common' })) setCanReturnToSimpleMode(false) return true diff --git a/web/app/components/app/configuration/tools/external-data-tool-modal.tsx b/web/app/components/app/configuration/tools/external-data-tool-modal.tsx index dd7a0c6a6cc..1c9adca1d1f 100644 --- a/web/app/components/app/configuration/tools/external-data-tool-modal.tsx +++ b/web/app/components/app/configuration/tools/external-data-tool-modal.tsx @@ -11,9 +11,9 @@ import Button from '@/app/components/base/button' import EmojiPicker from '@/app/components/base/emoji-picker' import FormGeneration from '@/app/components/base/features/new-feature-panel/moderation/form-generation' import { BookOpen01 } from '@/app/components/base/icons/src/vender/line/education' -import Modal from '@/app/components/base/modal' -import { SimpleSelect } from '@/app/components/base/select' -import { useToastContext } from '@/app/components/base/toast/context' +import { Dialog, DialogContent } from '@/app/components/base/ui/dialog' +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/app/components/base/ui/select' +import { toast } from '@/app/components/base/ui/toast' import ApiBasedExtensionSelector from '@/app/components/header/account-setting/api-based-extension-page/selector' import { useDocLink, useLocale } from '@/context/i18n' import { LanguagesSupported } from '@/i18n-config/language' @@ -39,7 +39,6 @@ const ExternalDataToolModal: FC = ({ }) => { const { t } = useTranslation() const docLink = useDocLink() - const { notify } = useToastContext() const locale = useLocale() const [localeData, setLocaleData] = useState(data.type ? data : { ...data, type: 'api' }) const [showEmojiPicker, setShowEmojiPicker] = useState(false) @@ -133,37 +132,34 @@ const ExternalDataToolModal: FC = ({ const handleSave = () => { if (!localeData.type) { - notify({ type: 'error', message: t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: t('feature.tools.modal.toolType.title', { ns: 'appDebug' }) }) }) + toast.error(t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: t('feature.tools.modal.toolType.title', { ns: 'appDebug' }) })) return } if (!localeData.label) { - notify({ type: 'error', message: t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: t('feature.tools.modal.name.title', { ns: 'appDebug' }) }) }) + toast.error(t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: t('feature.tools.modal.name.title', { ns: 'appDebug' }) })) return } if (!localeData.variable) { - notify({ type: 'error', message: t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: t('feature.tools.modal.variableName.title', { ns: 'appDebug' }) }) }) + toast.error(t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: t('feature.tools.modal.variableName.title', { ns: 'appDebug' }) })) return } if (localeData.variable && !/^[a-z_]\w{0,29}$/i.test(localeData.variable)) { - notify({ type: 'error', message: t('varKeyError.notValid', { ns: 'appDebug', key: t('feature.tools.modal.variableName.title', { ns: 'appDebug' }) }) }) + toast.error(t('varKeyError.notValid', { ns: 'appDebug', key: t('feature.tools.modal.variableName.title', { ns: 'appDebug' }) })) return } if (localeData.type === 'api' && !localeData.config?.api_based_extension_id) { - notify({ type: 'error', message: t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: locale !== LanguagesSupported[1] ? 'API Extension' : 'API 扩展' }) }) + toast.error(t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: locale !== LanguagesSupported[1] ? 'API Extension' : 'API 扩展' })) return } if (systemTypes.findIndex(t => t === localeData.type) < 0 && currentProvider?.form_schema) { for (let i = 0; i < currentProvider.form_schema.length; i++) { if (!localeData.config?.[currentProvider.form_schema[i].variable] && currentProvider.form_schema[i].required) { - notify({ - type: 'error', - message: t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: locale !== LanguagesSupported[1] ? currentProvider.form_schema[i].label['en-US'] : currentProvider.form_schema[i].label['zh-Hans'] }), - }) + toast.error(t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: locale !== LanguagesSupported[1] ? currentProvider.form_schema[i].label['en-US'] : currentProvider.form_schema[i].label['zh-Hans'] })) return } } @@ -180,122 +176,128 @@ const ExternalDataToolModal: FC = ({ const action = data.type ? t('operation.edit', { ns: 'common' }) : t('operation.add', { ns: 'common' }) return ( - -
- {`${action} ${t('variableConfig.apiBasedVar', { ns: 'appDebug' })}`} -
-
-
- {t('apiBasedExtension.type', { ns: 'common' })} + +
+ {`${action} ${t('variableConfig.apiBasedVar', { ns: 'appDebug' })}`}
- { - return { - value: option.key, - name: option.name, - } - })} - onSelect={item => handleDataTypeChange(item.value as string)} - /> -
-
-
- {t('feature.tools.modal.name.title', { ns: 'appDebug' })} +
+
+ {t('apiBasedExtension.type', { ns: 'common' })} +
+
-
- handleValueChange({ label: e.target.value })} - className="mr-2 block h-9 grow appearance-none rounded-lg bg-components-input-bg-normal px-3 text-sm text-components-input-text-filled outline-none" - placeholder={t('feature.tools.modal.name.placeholder', { ns: 'appDebug' }) || ''} - /> - { setShowEmojiPicker(true) }} - className="!h-9 !w-9 cursor-pointer rounded-lg border-[0.5px] border-components-panel-border" - icon={localeData.icon} - background={localeData.icon_background} - /> -
-
-
-
- {t('feature.tools.modal.variableName.title', { ns: 'appDebug' })} -
- handleValueChange({ variable: e.target.value })} - className="block h-9 w-full appearance-none rounded-lg bg-components-input-bg-normal px-3 text-sm text-components-input-text-filled outline-none" - placeholder={t('feature.tools.modal.variableName.placeholder', { ns: 'appDebug' }) || ''} - /> -
- { - localeData.type === 'api' && ( -
-
- {t('apiBasedExtension.selector.title', { ns: 'common' })} - - - {t('apiBasedExtension.link', { ns: 'common' })} - -
- +
+ {t('feature.tools.modal.name.title', { ns: 'appDebug' })} +
+
+ handleValueChange({ label: e.target.value })} + className="mr-2 block h-9 grow appearance-none rounded-lg bg-components-input-bg-normal px-3 text-sm text-components-input-text-filled outline-none" + placeholder={t('feature.tools.modal.name.placeholder', { ns: 'appDebug' }) || ''} + /> + { setShowEmojiPicker(true) }} + className="!h-9 !w-9 cursor-pointer rounded-lg border-[0.5px] border-components-panel-border" + icon={localeData.icon} + background={localeData.icon_background} />
- ) - } - { - systemTypes.findIndex(t => t === localeData.type) < 0 - && currentProvider?.form_schema - && ( - +
+
+ {t('feature.tools.modal.variableName.title', { ns: 'appDebug' })} +
+ handleValueChange({ variable: e.target.value })} + className="block h-9 w-full appearance-none rounded-lg bg-components-input-bg-normal px-3 text-sm text-components-input-text-filled outline-none" + placeholder={t('feature.tools.modal.variableName.placeholder', { ns: 'appDebug' }) || ''} /> - ) - } -
- - -
- { - showEmojiPicker && ( - { - handleValueChange({ icon, icon_background }) - setShowEmojiPicker(false) - }} - onClose={() => { - handleValueChange({ icon: '', icon_background: '' }) - setShowEmojiPicker(false) - }} - /> - ) - } - +
+ { + localeData.type === 'api' && ( +
+
+ {t('apiBasedExtension.selector.title', { ns: 'common' })} + + + {t('apiBasedExtension.link', { ns: 'common' })} + +
+ +
+ ) + } + { + systemTypes.findIndex(t => t === localeData.type) < 0 + && currentProvider?.form_schema + && ( + + ) + } +
+ + +
+ { + showEmojiPicker && ( + { + handleValueChange({ icon, icon_background }) + setShowEmojiPicker(false) + }} + onClose={() => { + handleValueChange({ icon: '', icon_background: '' }) + setShowEmojiPicker(false) + }} + /> + ) + } + + ) } diff --git a/web/app/components/app/configuration/tools/index.tsx b/web/app/components/app/configuration/tools/index.tsx index 51a9e87a973..8ab71c73cf5 100644 --- a/web/app/components/app/configuration/tools/index.tsx +++ b/web/app/components/app/configuration/tools/index.tsx @@ -5,7 +5,6 @@ import { RiDeleteBinLine, } from '@remixicon/react' import copy from 'copy-to-clipboard' -// abandoned import { useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -15,14 +14,17 @@ import { } from '@/app/components/base/icons/src/vender/line/general' import { Tool03 } from '@/app/components/base/icons/src/vender/solid/general' import Switch from '@/app/components/base/switch' -import { useToastContext } from '@/app/components/base/toast/context' -import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from '@/app/components/base/ui/tooltip' import ConfigContext from '@/context/debug-configuration' import { useModalContext } from '@/context/modal-context' const Tools = () => { const { t } = useTranslation() - const { notify } = useToastContext() const { setShowExternalDataToolModal } = useModalContext() const { externalDataToolsConfig, @@ -48,7 +50,7 @@ const Tools = () => { const promptVariables = modelConfig?.configs?.prompt_variables || [] for (let i = 0; i < promptVariables.length; i++) { if (promptVariables[i].key === newExternalDataTool.variable) { - notify({ type: 'error', message: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key }) }) + toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: promptVariables[i].key })) return false } } @@ -66,7 +68,7 @@ const Tools = () => { for (let i = 0; i < existedExternalDataTools.length; i++) { if (existedExternalDataTools[i].variable === newExternalDataTool.variable) { - notify({ type: 'error', message: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: existedExternalDataTools[i].variable }) }) + toast.error(t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: existedExternalDataTools[i].variable })) return false } } @@ -110,13 +112,14 @@ const Tools = () => {
{t('feature.tools.title', { ns: 'appDebug' })}
- + } /> +
{t('feature.tools.tips', { ns: 'appDebug' })}
- )} - /> +
+
{ !expanded && !!externalDataToolsConfig.length && ( @@ -151,18 +154,23 @@ const Tools = () => { background={item.icon_background} />
{item.label}
- -
{ - copy(item.variable || '') - setCopied(true) - }} - > - {item.variable} -
+ + { + copy(item.variable || '') + setCopied(true) + }} + > + {item.variable} +
+ )} + /> + + {copied ? t('copied', { ns: 'appApi' }) : `${item.variable}, ${t('copy', { ns: 'appApi' })}`} +
({ vi.mock('@/service/apps', () => ({ createApp: vi.fn(), })) +const toastMocks = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), + mockToastError: vi.fn(), +})) +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: toastMocks.mockToastSuccess, + error: toastMocks.mockToastError, + }, +})) vi.mock('@/utils/app-redirection', () => ({ getRedirection: vi.fn(), })) @@ -48,7 +57,6 @@ vi.mock('@/hooks/use-theme', () => ({ default: () => ({ theme: 'light' }), })) -const mockNotify = vi.fn() const mockUseRouter = vi.mocked(useRouter) const mockPush = vi.fn() const mockCreateApp = vi.mocked(createApp) @@ -56,6 +64,7 @@ const mockTrackEvent = vi.mocked(trackEvent) const mockGetRedirection = vi.mocked(getRedirection) const mockUseProviderContext = vi.mocked(useProviderContext) const mockUseAppContext = vi.mocked(useAppContext) +const { mockToastSuccess, mockToastError } = toastMocks const defaultPlanUsage = { buildApps: 0, @@ -70,11 +79,7 @@ const defaultPlanUsage = { const renderModal = () => { const onClose = vi.fn() const onSuccess = vi.fn() - render( - - - , - ) + render() return { onClose, onSuccess } } @@ -140,7 +145,7 @@ describe('CreateAppModal', () => { app_mode: AppModeEnum.ADVANCED_CHAT, description: '', }) - expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' }) + expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated') expect(onSuccess).toHaveBeenCalled() expect(onClose).toHaveBeenCalled() await waitFor(() => expect(mockSetItem).toHaveBeenCalledWith(NEED_REFRESH_APP_LIST_KEY, '1')) @@ -156,7 +161,7 @@ describe('CreateAppModal', () => { fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ })) await waitFor(() => expect(mockCreateApp).toHaveBeenCalled()) - expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' }) + expect(mockToastError).toHaveBeenCalledWith('boom') expect(onClose).not.toHaveBeenCalled() }) }) diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index 556773c3411..8750b732b14 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -6,7 +6,6 @@ import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill } from '@remixicon import { useDebounceFn, useKeyPress } from 'ahooks' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import { trackEvent } from '@/app/components/base/amplitude' import AppIcon from '@/app/components/base/app-icon' import Button from '@/app/components/base/button' @@ -15,7 +14,7 @@ import FullScreenModal from '@/app/components/base/fullscreen-modal' import { BubbleTextMod, ChatBot, ListSparkle, Logic } from '@/app/components/base/icons/src/vender/solid/communication' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' @@ -40,7 +39,6 @@ type CreateAppProps = { function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: CreateAppProps) { const { t } = useTranslation() const { push } = useRouter() - const { notify } = useContext(ToastContext) const [appMode, setAppMode] = useState(defaultAppMode || AppModeEnum.ADVANCED_CHAT) const [appIcon, setAppIcon] = useState({ type: 'emoji', icon: '🤖', background: '#FFEAD5' }) @@ -62,11 +60,11 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: const onCreate = useCallback(async () => { if (!appMode) { - notify({ type: 'error', message: t('newApp.appTypeRequired', { ns: 'app' }) }) + toast.error(t('newApp.appTypeRequired', { ns: 'app' })) return } if (!name.trim()) { - notify({ type: 'error', message: t('newApp.nameNotEmpty', { ns: 'app' }) }) + toast.error(t('newApp.nameNotEmpty', { ns: 'app' })) return } if (isCreatingRef.current) @@ -88,20 +86,17 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: description, }) - notify({ type: 'success', message: t('newApp.appCreated', { ns: 'app' }) }) + toast.success(t('newApp.appCreated', { ns: 'app' })) onSuccess() onClose() localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') getRedirection(isCurrentWorkspaceEditor, app, push) } catch (e: any) { - notify({ - type: 'error', - message: e.message || t('newApp.appCreateFailed', { ns: 'app' }), - }) + toast.error(e.message || t('newApp.appCreateFailed', { ns: 'app' })) } isCreatingRef.current = false - }, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, push, isCurrentWorkspaceEditor]) + }, [name, t, appMode, appIcon, description, onSuccess, onClose, push, isCurrentWorkspaceEditor]) const { run: handleCreateApp } = useDebounceFn(onCreate, { wait: 300 }) useKeyPress(['meta.enter', 'ctrl.enter'], () => { diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index eaaee509733..dd17655e3cb 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -6,12 +6,11 @@ import { useDebounceFn, useKeyPress } from 'ahooks' import { noop } from 'es-toolkit/function' import { useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import { trackEvent } from '@/app/components/base/amplitude' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' @@ -48,7 +47,6 @@ export enum CreateFromDSLModalTab { const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDSLModalTab.FROM_FILE, dslUrl = '', droppedFile }: CreateFromDSLModalProps) => { const { push } = useRouter() const { t } = useTranslation() - const { notify } = useContext(ToastContext) const [currentFile, setDSLFile] = useState(droppedFile) const [fileContent, setFileContent] = useState() const [currentTab, setCurrentTab] = useState(activeTab) @@ -126,10 +124,11 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS if (onClose) onClose() - notify({ + toast(t(status === DSLImportStatus.COMPLETED ? 'newApp.appCreated' : 'newApp.caution', { ns: 'app' }), { type: status === DSLImportStatus.COMPLETED ? 'success' : 'warning', - message: t(status === DSLImportStatus.COMPLETED ? 'newApp.appCreated' : 'newApp.caution', { ns: 'app' }), - children: status === DSLImportStatus.COMPLETED_WITH_WARNINGS && t('newApp.appCreateDSLWarning', { ns: 'app' }), + description: status === DSLImportStatus.COMPLETED_WITH_WARNINGS + ? t('newApp.appCreateDSLWarning', { ns: 'app' }) + : undefined, }) localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') if (app_id) @@ -147,12 +146,12 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS setImportId(id) } else { - notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) }) + toast.error(t('newApp.appCreateFailed', { ns: 'app' })) } } // eslint-disable-next-line unused-imports/no-unused-vars catch (e) { - notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) }) + toast.error(t('newApp.appCreateFailed', { ns: 'app' })) } isCreatingRef.current = false } @@ -185,22 +184,19 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS if (onClose) onClose() - notify({ - type: 'success', - message: t('newApp.appCreated', { ns: 'app' }), - }) + toast.success(t('newApp.appCreated', { ns: 'app' })) if (app_id) await handleCheckPluginDependencies(app_id) localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') getRedirection(isCurrentWorkspaceEditor, { id: app_id!, mode: app_mode }, push) } else if (status === DSLImportStatus.FAILED) { - notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) }) + toast.error(t('newApp.appCreateFailed', { ns: 'app' })) } } // eslint-disable-next-line unused-imports/no-unused-vars catch (e) { - notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) }) + toast.error(t('newApp.appCreateFailed', { ns: 'app' })) } } diff --git a/web/app/components/app/create-from-dsl-modal/uploader.tsx b/web/app/components/app/create-from-dsl-modal/uploader.tsx index 74c8e5f48ec..3dcab1c6d6f 100644 --- a/web/app/components/app/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/app/create-from-dsl-modal/uploader.tsx @@ -7,10 +7,9 @@ import { import * as React from 'react' import { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import ActionButton from '@/app/components/base/action-button' import { Yaml as YamlIcon } from '@/app/components/base/icons/src/public/files' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { cn } from '@/utils/classnames' import { formatFileSize } from '@/utils/format' @@ -30,7 +29,6 @@ const Uploader: FC = ({ displayName = 'YAML', }) => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) const [dragging, setDragging] = useState(false) const dropRef = useRef(null) const dragRef = useRef(null) @@ -60,7 +58,7 @@ const Uploader: FC = ({ return const files = Array.from(e.dataTransfer.files) if (files.length > 1) { - notify({ type: 'error', message: t('stepOne.uploader.validation.count', { ns: 'datasetCreation' }) }) + toast.error(t('stepOne.uploader.validation.count', { ns: 'datasetCreation' })) return } updateFile(files[0]) diff --git a/web/app/components/app/duplicate-modal/index.spec.tsx b/web/app/components/app/duplicate-modal/index.spec.tsx index ef126465715..e70329a1052 100644 --- a/web/app/components/app/duplicate-modal/index.spec.tsx +++ b/web/app/components/app/duplicate-modal/index.spec.tsx @@ -2,7 +2,7 @@ import type { ProviderContextState } from '@/context/provider-context' import { render, screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { Plan } from '@/app/components/billing/type' import { baseProviderContextValue } from '@/context/provider-context' import DuplicateAppModal from './index' @@ -129,7 +129,7 @@ describe('DuplicateAppModal', () => { it('should show error toast when name is empty', async () => { const user = userEvent.setup() - const toastSpy = vi.spyOn(Toast, 'notify') + const toastSpy = vi.spyOn(toast, 'error').mockReturnValue('toast-error') // Arrange const { onConfirm, onHide } = renderComponent() @@ -138,7 +138,7 @@ describe('DuplicateAppModal', () => { await user.click(screen.getByRole('button', { name: 'app.duplicate' })) // Assert - expect(toastSpy).toHaveBeenCalledWith({ type: 'error', message: 'explore.appCustomize.nameRequired' }) + expect(toastSpy).toHaveBeenCalledWith('explore.appCustomize.nameRequired') expect(onConfirm).not.toHaveBeenCalled() expect(onHide).not.toHaveBeenCalled() }) diff --git a/web/app/components/app/duplicate-modal/index.tsx b/web/app/components/app/duplicate-modal/index.tsx index 7d5b122f699..b2ba7f1d0f7 100644 --- a/web/app/components/app/duplicate-modal/index.tsx +++ b/web/app/components/app/duplicate-modal/index.tsx @@ -9,7 +9,7 @@ import AppIcon from '@/app/components/base/app-icon' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { useProviderContext } from '@/context/provider-context' import { cn } from '@/utils/classnames' @@ -57,7 +57,7 @@ const DuplicateAppModal = ({ const submit = () => { if (!name.trim()) { - Toast.notify({ type: 'error', message: t('appCustomize.nameRequired', { ns: 'explore' }) }) + toast.error(t('appCustomize.nameRequired', { ns: 'explore' })) return } onConfirm({ diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 453c7c9d4c8..4a22a0c85fb 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -30,8 +30,8 @@ import Drawer from '@/app/components/base/drawer' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' import Loading from '@/app/components/base/loading' import MessageLogModal from '@/app/components/base/message-log-modal' -import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import { WorkflowContextProvider } from '@/app/components/workflow/context' import { useAppContext } from '@/context/app-context' @@ -223,7 +223,6 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { const { userProfile: { timezone } } = useAppContext() const { formatTime } = useTimestamp() const { onClose, appDetail } = useContext(DrawerContext) - const { notify } = useContext(ToastContext) const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow((state: AppStoreState) => ({ currentLogItem: state.currentLogItem, setCurrentLogItem: state.setCurrentLogItem, @@ -413,14 +412,14 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { return item })) - notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) return true } catch { - notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) return false } - }, [allChatItems, appDetail?.id, notify, t]) + }, [allChatItems, appDetail?.id, t]) const fetchInitiated = useRef(false) @@ -734,7 +733,6 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { const CompletionConversationDetailComp: FC<{ appId?: string, conversationId?: string }> = ({ appId, conversationId }) => { // Text Generator App Session Details Including Message List const { data: conversationDetail, refetch: conversationDetailMutate } = useCompletionConversationDetail(appId, conversationId) - const { notify } = useContext(ToastContext) const { t } = useTranslation() const handleFeedback = async (mid: string, { rating, content }: FeedbackType): Promise => { @@ -744,11 +742,11 @@ const CompletionConversationDetailComp: FC<{ appId?: string, conversationId?: st body: { message_id: mid, rating, content: content ?? undefined }, }) conversationDetailMutate() - notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) return true } catch { - notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) return false } } @@ -757,11 +755,11 @@ const CompletionConversationDetailComp: FC<{ appId?: string, conversationId?: st try { await updateLogMessageAnnotations({ url: `/apps/${appId}/annotations`, body: { message_id: mid, content: value } }) conversationDetailMutate() - notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) return true } catch { - notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) return false } } @@ -783,7 +781,6 @@ const CompletionConversationDetailComp: FC<{ appId?: string, conversationId?: st */ const ChatConversationDetailComp: FC<{ appId?: string, conversationId?: string }> = ({ appId, conversationId }) => { const { data: conversationDetail } = useChatConversationDetail(appId, conversationId) - const { notify } = useContext(ToastContext) const { t } = useTranslation() const handleFeedback = async (mid: string, { rating, content }: FeedbackType): Promise => { @@ -792,11 +789,11 @@ const ChatConversationDetailComp: FC<{ appId?: string, conversationId?: string } url: `/apps/${appId}/feedbacks`, body: { message_id: mid, rating, content: content ?? undefined }, }) - notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) return true } catch { - notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) return false } } @@ -804,11 +801,11 @@ const ChatConversationDetailComp: FC<{ appId?: string, conversationId?: string } const handleAnnotation = async (mid: string, value: string): Promise => { try { await updateLogMessageAnnotations({ url: `/apps/${appId}/annotations`, body: { message_id: mid, content: value } }) - notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) return true } catch { - notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) return false } } diff --git a/web/app/components/app/overview/settings/index.spec.tsx b/web/app/components/app/overview/settings/index.spec.tsx index e933855ca8d..d6f9612f751 100644 --- a/web/app/components/app/overview/settings/index.spec.tsx +++ b/web/app/components/app/overview/settings/index.spec.tsx @@ -29,7 +29,24 @@ vi.mock('react-i18next', async () => { } }) -const mockNotify = vi.fn() +const toastMocks = vi.hoisted(() => ({ + call: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign(toastMocks.call, { + success: vi.fn((message: string, options?: Record) => toastMocks.call({ type: 'success', message, ...options })), + error: vi.fn((message: string, options?: Record) => toastMocks.call({ type: 'error', message, ...options })), + warning: vi.fn((message: string, options?: Record) => toastMocks.call({ type: 'warning', message, ...options })), + info: vi.fn((message: string, options?: Record) => toastMocks.call({ type: 'info', message, ...options })), + dismiss: toastMocks.dismiss, + update: toastMocks.update, + promise: toastMocks.promise, + }), +})) const mockOnClose = vi.fn() const mockOnSave = vi.fn() const mockSetShowPricingModal = vi.fn() @@ -56,13 +73,6 @@ vi.mock('@/context/modal-context', () => ({ useModalContext: () => buildModalContext(), })) -vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: () => ({ - notify: mockNotify, - close: vi.fn(), - }), -})) - vi.mock('@/context/i18n', async () => { const actual = await vi.importActual('@/context/i18n') return { @@ -112,7 +122,7 @@ const renderSettingsModal = () => render( describe('SettingsModal', () => { beforeEach(() => { - mockNotify.mockClear() + toastMocks.call.mockClear() mockOnClose.mockClear() mockOnSave.mockClear() mockSetShowPricingModal.mockClear() @@ -152,7 +162,7 @@ describe('SettingsModal', () => { fireEvent.click(screen.getByText('common.operation.save')) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ message: 'app.newApp.nameNotEmpty' })) + expect(toastMocks.call).toHaveBeenCalledWith(expect.objectContaining({ message: 'app.newApp.nameNotEmpty' })) }) expect(mockOnSave).not.toHaveBeenCalled() }) @@ -164,7 +174,7 @@ describe('SettingsModal', () => { fireEvent.click(screen.getByText('common.operation.save')) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + expect(toastMocks.call).toHaveBeenCalledWith(expect.objectContaining({ message: 'appOverview.overview.appInfo.settings.invalidHexMessage', })) }) @@ -180,7 +190,7 @@ describe('SettingsModal', () => { fireEvent.click(screen.getByText('common.operation.save')) await waitFor(() => { - expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + expect(toastMocks.call).toHaveBeenCalledWith(expect.objectContaining({ message: 'appOverview.overview.appInfo.settings.invalidPrivacyPolicy', })) }) diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 13dacde4245..0d77d32ec41 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -19,8 +19,8 @@ import PremiumBadge from '@/app/components/base/premium-badge' import { SimpleSelect } from '@/app/components/base/select' import Switch from '@/app/components/base/switch' import Textarea from '@/app/components/base/textarea' -import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' @@ -65,7 +65,6 @@ const SettingsModal: FC = ({ onClose, onSave, }) => { - const { notify } = useToastContext() const [isShowMore, setIsShowMore] = useState(false) const { title, @@ -159,7 +158,7 @@ const SettingsModal: FC = ({ const onClickSave = async () => { if (!inputInfo.title) { - notify({ type: 'error', message: t('newApp.nameNotEmpty', { ns: 'app' }) }) + toast.error(t('newApp.nameNotEmpty', { ns: 'app' })) return } @@ -181,11 +180,11 @@ const SettingsModal: FC = ({ if (inputInfo !== null) { if (!validateColorHex(inputInfo.chatColorTheme)) { - notify({ type: 'error', message: t(`${prefixSettings}.invalidHexMessage`, { ns: 'appOverview' }) }) + toast.error(t(`${prefixSettings}.invalidHexMessage`, { ns: 'appOverview' })) return } if (!validatePrivacyPolicy(inputInfo.privacyPolicy)) { - notify({ type: 'error', message: t(`${prefixSettings}.invalidPrivacyPolicy`, { ns: 'appOverview' }) }) + toast.error(t(`${prefixSettings}.invalidPrivacyPolicy`, { ns: 'appOverview' })) return } } diff --git a/web/app/components/app/switch-app-modal/index.spec.tsx b/web/app/components/app/switch-app-modal/index.spec.tsx index 53007b986b2..147edeb5edd 100644 --- a/web/app/components/app/switch-app-modal/index.spec.tsx +++ b/web/app/components/app/switch-app-modal/index.spec.tsx @@ -3,7 +3,6 @@ import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' import { useStore as useAppStore } from '@/app/components/app/store' -import { ToastContext } from '@/app/components/base/toast/context' import { Plan } from '@/app/components/billing/type' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { AppModeEnum } from '@/types/app' @@ -108,27 +107,44 @@ const createMockApp = (overrides: Partial = {}): App => ({ ...overrides, }) +const toastMocks = vi.hoisted(() => ({ + notify: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: (message: string, options?: Record) => toastMocks.notify({ type: 'success', message, ...options }), + error: (message: string, options?: Record) => toastMocks.notify({ type: 'error', message, ...options }), + warning: (message: string, options?: Record) => toastMocks.notify({ type: 'warning', message, ...options }), + info: (message: string, options?: Record) => toastMocks.notify({ type: 'info', message, ...options }), + dismiss: toastMocks.dismiss, + update: toastMocks.update, + promise: toastMocks.promise, + }, +})) + const renderComponent = (overrides: Partial> = {}) => { - const notify = vi.fn() const onClose = vi.fn() const onSuccess = vi.fn() const appDetail = createMockApp() const utils = render( - - - , + , + ) return { ...utils, - notify, + notify: toastMocks.notify, onClose, onSuccess, appDetail, diff --git a/web/app/components/app/switch-app-modal/index.tsx b/web/app/components/app/switch-app-modal/index.tsx index 7c3269d52cf..ffa5dc6ef46 100644 --- a/web/app/components/app/switch-app-modal/index.tsx +++ b/web/app/components/app/switch-app-modal/index.tsx @@ -5,7 +5,6 @@ import { RiCloseLine } from '@remixicon/react' import { noop } from 'es-toolkit/function' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import { useStore as useAppStore } from '@/app/components/app/store' import AppIcon from '@/app/components/base/app-icon' import Button from '@/app/components/base/button' @@ -14,7 +13,7 @@ import Confirm from '@/app/components/base/confirm' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' @@ -37,7 +36,6 @@ type SwitchAppModalProps = { const SwitchAppModal = ({ show, appDetail, inAppDetail = false, onSuccess, onClose }: SwitchAppModalProps) => { const { push, replace } = useRouter() const { t } = useTranslation() - const { notify } = useContext(ToastContext) const setAppDetail = useAppStore(s => s.setAppDetail) const { isCurrentWorkspaceEditor } = useAppContext() @@ -68,7 +66,7 @@ const SwitchAppModal = ({ show, appDetail, inAppDetail = false, onSuccess, onClo onSuccess() if (onClose) onClose() - notify({ type: 'success', message: t('newApp.appCreated', { ns: 'app' }) }) + toast.success(t('newApp.appCreated', { ns: 'app' })) if (inAppDetail) setAppDetail() if (removeOriginal) @@ -84,7 +82,7 @@ const SwitchAppModal = ({ show, appDetail, inAppDetail = false, onSuccess, onClo ) } catch { - notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) }) + toast.error(t('newApp.appCreateFailed', { ns: 'app' })) } } diff --git a/web/app/components/app/text-generate/item/index.tsx b/web/app/components/app/text-generate/item/index.tsx index d22375a2926..ab96077f679 100644 --- a/web/app/components/app/text-generate/item/index.tsx +++ b/web/app/components/app/text-generate/item/index.tsx @@ -28,7 +28,7 @@ import { useChatContext } from '@/app/components/base/chat/chat/context' import Loading from '@/app/components/base/loading' import { Markdown } from '@/app/components/base/markdown' import NewAudioButton from '@/app/components/base/new-audio-button' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useParams } from '@/next/navigation' import { fetchTextGenerationMessage } from '@/service/debug' import { AppSourceType, fetchMoreLikeThis, submitHumanInputForm, updateFeedback } from '@/service/share' @@ -145,7 +145,7 @@ const GenerationItem: FC = ({ const handleMoreLikeThis = async () => { if (isQuerying || !messageId) { - Toast.notify({ type: 'warning', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) }) + toast.warning(t('errorMessage.waitForResponse', { ns: 'appDebug' })) return } startQuerying() @@ -366,7 +366,7 @@ const GenerationItem: FC = ({ copy(copyContent) else copy(JSON.stringify(copyContent)) - Toast.notify({ type: 'success', message: t('actionMsg.copySuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.copySuccessfully', { ns: 'common' })) }} > diff --git a/web/app/components/app/text-generate/saved-items/index.spec.tsx b/web/app/components/app/text-generate/saved-items/index.spec.tsx index b45a1cca6c6..dff0950f897 100644 --- a/web/app/components/app/text-generate/saved-items/index.spec.tsx +++ b/web/app/components/app/text-generate/saved-items/index.spec.tsx @@ -4,7 +4,7 @@ import copy from 'copy-to-clipboard' import * as React from 'react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import SavedItems from './index' vi.mock('copy-to-clipboard', () => ({ @@ -16,7 +16,7 @@ vi.mock('@/next/navigation', () => ({ })) const mockCopy = vi.mocked(copy) -const toastNotifySpy = vi.spyOn(Toast, 'notify') +const toastSuccessSpy = vi.spyOn(toast, 'success').mockReturnValue('toast-success') const baseProps: ISavedItemsProps = { list: [ @@ -30,7 +30,7 @@ const baseProps: ISavedItemsProps = { describe('SavedItems', () => { beforeEach(() => { vi.clearAllMocks() - toastNotifySpy.mockClear() + toastSuccessSpy.mockClear() }) it('renders saved answers with metadata and controls', () => { @@ -58,7 +58,7 @@ describe('SavedItems', () => { fireEvent.click(copyButton) expect(mockCopy).toHaveBeenCalledWith('hello world') - expect(toastNotifySpy).toHaveBeenCalledWith({ type: 'success', message: 'common.actionMsg.copySuccessfully' }) + expect(toastSuccessSpy).toHaveBeenCalledWith('common.actionMsg.copySuccessfully') fireEvent.click(deleteButton) expect(handleRemove).toHaveBeenCalledWith('1') diff --git a/web/app/components/app/text-generate/saved-items/index.tsx b/web/app/components/app/text-generate/saved-items/index.tsx index 36006402c4f..cd43f354f32 100644 --- a/web/app/components/app/text-generate/saved-items/index.tsx +++ b/web/app/components/app/text-generate/saved-items/index.tsx @@ -11,7 +11,7 @@ import { useTranslation } from 'react-i18next' import ActionButton from '@/app/components/base/action-button' import { Markdown } from '@/app/components/base/markdown' import NewAudioButton from '@/app/components/base/new-audio-button' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { cn } from '@/utils/classnames' import NoData from './no-data' @@ -60,7 +60,7 @@ const SavedItems: FC = ({ {isShowTextToSpeech && } { copy(answer) - Toast.notify({ type: 'success', message: t('actionMsg.copySuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.copySuccessfully', { ns: 'common' })) }} > diff --git a/web/app/components/apps/__tests__/app-card.spec.tsx b/web/app/components/apps/__tests__/app-card.spec.tsx index 86c87e0c5bc..d1e89b7a850 100644 --- a/web/app/components/apps/__tests__/app-card.spec.tsx +++ b/web/app/components/apps/__tests__/app-card.spec.tsx @@ -17,16 +17,36 @@ vi.mock('@/next/navigation', () => ({ }), })) -// Mock use-context-selector with stable mockNotify reference for tracking calls +const toastMocks = vi.hoisted(() => { + const record = vi.fn() + const api = vi.fn((message: unknown, options?: Record) => record({ message, ...options })) + return { + record, + api: Object.assign(api, { + success: vi.fn((message: unknown, options?: Record) => record({ type: 'success', message, ...options })), + error: vi.fn((message: unknown, options?: Record) => record({ type: 'error', message, ...options })), + warning: vi.fn((message: unknown, options?: Record) => record({ type: 'warning', message, ...options })), + info: vi.fn((message: unknown, options?: Record) => record({ type: 'info', message, ...options })), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), + } +}) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: toastMocks.api, +})) + +// Mock use-context-selector with stable toast reference for tracking calls // Include createContext for components that use it (like Toast) -const mockNotify = vi.fn() vi.mock('use-context-selector', () => ({ createContext: (defaultValue: T) => React.createContext(defaultValue), useContext: () => ({ - notify: mockNotify, + notify: toastMocks.api, }), useContextSelector: (_context: unknown, selector: (state: Record) => unknown) => selector({ - notify: mockNotify, + notify: toastMocks.api, }), })) @@ -591,7 +611,7 @@ describe('AppCard', () => { await waitFor(() => { expect(mockDeleteAppMutation).toHaveBeenCalled() - expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: expect.stringContaining('Delete failed') }) + expect(toastMocks.record).toHaveBeenCalledWith({ type: 'error', message: expect.stringContaining('Delete failed') }) }) }) @@ -670,7 +690,7 @@ describe('AppCard', () => { await waitFor(() => { expect(appsService.copyApp).toHaveBeenCalled() - expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.newApp.appCreateFailed' }) + expect(toastMocks.record).toHaveBeenCalledWith({ type: 'error', message: 'app.newApp.appCreateFailed' }) }) }) @@ -699,7 +719,7 @@ describe('AppCard', () => { await waitFor(() => { expect(appsService.exportAppConfig).toHaveBeenCalled() - expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) + expect(toastMocks.record).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) }) }) }) @@ -945,7 +965,7 @@ describe('AppCard', () => { await waitFor(() => { expect(appsService.updateAppInfo).toHaveBeenCalled() - expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: expect.stringContaining('Edit failed') }) + expect(toastMocks.record).toHaveBeenCalledWith({ type: 'error', message: expect.stringContaining('Edit failed') }) }) }) @@ -998,7 +1018,7 @@ describe('AppCard', () => { await waitFor(() => { expect(workflowService.fetchWorkflowDraft).toHaveBeenCalled() - expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) + expect(toastMocks.record).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) }) }) }) diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index 9a8abf64439..c1131ad2d44 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -10,14 +10,11 @@ import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill, RiVerifiedBadgeLi import * as React from 'react' import { useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import { AppTypeIcon } from '@/app/components/app/type-selector' import AppIcon from '@/app/components/base/app-icon' import Divider from '@/app/components/base/divider' import CustomPopover from '@/app/components/base/popover' import TagSelector from '@/app/components/base/tag-management/selector' -import Toast from '@/app/components/base/toast' -import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { AlertDialog, @@ -28,6 +25,7 @@ import { AlertDialogDescription, AlertDialogTitle, } from '@/app/components/base/ui/alert-dialog' +import { toast } from '@/app/components/base/ui/toast' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' @@ -71,7 +69,6 @@ export type AppCardProps = { const AppCard = ({ app, onRefresh }: AppCardProps) => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) const { isCurrentWorkspaceEditor } = useAppContext() const { onPlanInfoChanged } = useProviderContext() @@ -90,20 +87,17 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { const onConfirmDelete = useCallback(async () => { try { await mutateDeleteApp(app.id) - notify({ type: 'success', message: t('appDeleted', { ns: 'app' }) }) + toast.success(t('appDeleted', { ns: 'app' })) onPlanInfoChanged() } catch (e: any) { - notify({ - type: 'error', - message: `${t('appDeleteFailed', { ns: 'app' })}${'message' in e ? `: ${e.message}` : ''}`, - }) + toast.error(`${t('appDeleteFailed', { ns: 'app' })}${'message' in e ? `: ${e.message}` : ''}`) } finally { setShowConfirmDelete(false) setConfirmDeleteInput('') } - }, [app.id, mutateDeleteApp, notify, onPlanInfoChanged, t]) + }, [app.id, mutateDeleteApp, onPlanInfoChanged, t]) const onDeleteDialogOpenChange = useCallback((open: boolean) => { if (isDeleting) @@ -135,20 +129,14 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { max_active_requests, }) setShowEditModal(false) - notify({ - type: 'success', - message: t('editDone', { ns: 'app' }), - }) + toast.success(t('editDone', { ns: 'app' })) if (onRefresh) onRefresh() } catch (e: any) { - notify({ - type: 'error', - message: e.message || t('editFailed', { ns: 'app' }), - }) + toast.error(e.message || t('editFailed', { ns: 'app' })) } - }, [app.id, notify, onRefresh, t]) + }, [app.id, onRefresh, t]) const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => { try { @@ -161,10 +149,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { mode: app.mode, }) setShowDuplicateModal(false) - notify({ - type: 'success', - message: t('newApp.appCreated', { ns: 'app' }), - }) + toast.success(t('newApp.appCreated', { ns: 'app' })) localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') if (onRefresh) onRefresh() @@ -172,7 +157,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { getRedirection(isCurrentWorkspaceEditor, newApp, push) } catch { - notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) }) + toast.error(t('newApp.appCreateFailed', { ns: 'app' })) } } @@ -186,7 +171,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { downloadBlob({ data: file, fileName: `${app.name}.yml` }) } catch { - notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) + toast.error(t('exportFailed', { ns: 'app' })) } } @@ -205,7 +190,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { setSecretEnvList(list) } catch { - notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) + toast.error(t('exportFailed', { ns: 'app' })) } } @@ -274,13 +259,13 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { throw new Error('No app found in Explore') }, { onError: (err) => { - Toast.notify({ type: 'error', message: `${err.message || err}` }) + toast.error(`${err.message || err}`) }, }) } catch (e: unknown) { const message = e instanceof Error ? e.message : `${e}` - Toast.notify({ type: 'error', message }) + toast.error(message) } } return ( diff --git a/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx b/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx index 8b796435e09..6ce1e54a47c 100644 --- a/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx +++ b/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx @@ -1,16 +1,32 @@ -import type { ComponentProps } from 'react' +import type { ComponentProps, ReactNode } from 'react' import type { IChatItem } from '@/app/components/base/chat/chat/type' import type { AgentLogDetailResponse } from '@/models/log' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { useStore as useAppStore } from '@/app/components/app/store' -import { ToastContext } from '@/app/components/base/toast/context' import { fetchAgentLogDetail } from '@/service/log' import AgentLogDetail from '../detail' +const { mockToast } = vi.hoisted(() => { + const mockToast = Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }) + return { mockToast } +}) + vi.mock('@/service/log', () => ({ fetchAgentLogDetail: vi.fn(), })) +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: mockToast, +})) + vi.mock('@/app/components/app/store', () => ({ useStore: vi.fn(selector => selector({ appDetail: { id: 'app-id' } })), })) @@ -22,7 +38,7 @@ vi.mock('@/app/components/workflow/run/status', () => ({ })) vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({ - default: ({ title, value }: { title: React.ReactNode, value: string | object }) => ( + default: ({ title, value }: { title: ReactNode, value: string | object }) => (
{title} {typeof value === 'string' ? value : JSON.stringify(value)} @@ -76,19 +92,13 @@ const createMockResponse = (overrides: Partial = {}): Ag }) describe('AgentLogDetail', () => { - const notify = vi.fn() - const renderComponent = (props: Partial> = {}) => { const defaultProps: ComponentProps = { conversationID: 'conv-id', messageID: 'msg-id', log: createMockLog(), } - return render( - ['value']}> - - , - ) + return render() } const renderAndWaitForData = async (props: Partial> = {}) => { @@ -212,10 +222,7 @@ describe('AgentLogDetail', () => { renderComponent() await waitFor(() => { - expect(notify).toHaveBeenCalledWith({ - type: 'error', - message: 'Error: API Error', - }) + expect(mockToast.error).toHaveBeenCalledWith('Error: API Error') }) }) diff --git a/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx b/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx index b2db5244535..d1581c40b52 100644 --- a/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx +++ b/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx @@ -1,14 +1,30 @@ import type { IChatItem } from '@/app/components/base/chat/chat/type' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { useClickAway } from 'ahooks' -import { ToastContext } from '@/app/components/base/toast/context' import { fetchAgentLogDetail } from '@/service/log' import AgentLogModal from '../index' +const { mockToast } = vi.hoisted(() => { + const mockToast = Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }) + return { mockToast } +}) + vi.mock('@/service/log', () => ({ fetchAgentLogDetail: vi.fn(), })) +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: mockToast, +})) + vi.mock('@/app/components/app/store', () => ({ useStore: vi.fn(selector => selector({ appDetail: { id: 'app-id' } })), })) @@ -94,11 +110,7 @@ describe('AgentLogModal', () => { }) it('should render correctly when log item is provided', async () => { - render( - ['value']}> - - , - ) + render() expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() @@ -110,11 +122,7 @@ describe('AgentLogModal', () => { it('should call onCancel when close button is clicked', () => { vi.mocked(fetchAgentLogDetail).mockReturnValue(new Promise(() => {})) - render( - ['value']}> - - , - ) + render() const closeBtn = screen.getByRole('heading', { name: /appLog.runDetail.workflowTitle/i }).nextElementSibling! fireEvent.click(closeBtn) @@ -130,11 +138,7 @@ describe('AgentLogModal', () => { clickAwayHandler = callback }) - render( - ['value']}> - - , - ) + render() clickAwayHandler(new Event('click')) expect(mockProps.onCancel).toHaveBeenCalledTimes(1) @@ -150,11 +154,7 @@ describe('AgentLogModal', () => { } }) - render( - ['value']}> - - , - ) + render() expect(mockProps.onCancel).not.toHaveBeenCalled() }) diff --git a/web/app/components/base/agent-log-modal/detail.tsx b/web/app/components/base/agent-log-modal/detail.tsx index 21ed0be7e8c..6550b305f87 100644 --- a/web/app/components/base/agent-log-modal/detail.tsx +++ b/web/app/components/base/agent-log-modal/detail.tsx @@ -7,10 +7,9 @@ import { flatten } from 'es-toolkit/compat' import * as React from 'react' import { useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import { useStore as useAppStore } from '@/app/components/app/store' import Loading from '@/app/components/base/loading' -import { ToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { fetchAgentLogDetail } from '@/service/log' import { cn } from '@/utils/classnames' import ResultPanel from './result' @@ -22,28 +21,19 @@ export type AgentLogDetailProps = { log: IChatItem messageID: string } - -const AgentLogDetail: FC = ({ - activeTab = 'DETAIL', - conversationID, - messageID, - log, -}) => { +const AgentLogDetail: FC = ({ activeTab = 'DETAIL', conversationID, messageID, log }) => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) const [currentTab, setCurrentTab] = useState(activeTab) const appDetail = useAppStore(s => s.appDetail) const [loading, setLoading] = useState(true) const [runDetail, setRunDetail] = useState() const [list, setList] = useState([]) - const tools = useMemo(() => { const res = uniq(flatten(runDetail?.iterations.map((iteration) => { return iteration.tool_calls.map((tool: any) => tool.tool_name).filter(Boolean) })).filter(Boolean)) return res }, [runDetail]) - const getLogDetail = useCallback(async (appID: string, conversationID: string, messageID: string) => { try { const res = await fetchAgentLogDetail({ @@ -57,51 +47,30 @@ const AgentLogDetail: FC = ({ setList(res.iterations) } catch (err) { - notify({ - type: 'error', - message: `${err}`, - }) + toast.error(`${err}`) } - }, [notify]) - + }, []) const getData = async (appID: string, conversationID: string, messageID: string) => { setLoading(true) await getLogDetail(appID, conversationID, messageID) setLoading(false) } - const switchTab = async (tab: string) => { setCurrentTab(tab) } - useEffect(() => { // fetch data if (appDetail) getData(appDetail.id, conversationID, messageID) }, [appDetail, conversationID, messageID]) - return (
{/* tab */}
-
switchTab('DETAIL')} - > +
switchTab('DETAIL')}> {t('detail', { ns: 'runLog' })}
-
switchTab('TRACING')} - > +
switchTab('TRACING')}> {t('tracing', { ns: 'runLog' })}
@@ -112,29 +81,10 @@ const AgentLogDetail: FC = ({
)} - {!loading && currentTab === 'DETAIL' && runDetail && ( - - )} - {!loading && currentTab === 'TRACING' && ( - - )} + {!loading && currentTab === 'DETAIL' && runDetail && ()} + {!loading && currentTab === 'TRACING' && ()}
) } - export default AgentLogDetail diff --git a/web/app/components/base/agent-log-modal/index.stories.tsx b/web/app/components/base/agent-log-modal/index.stories.tsx index 87318848b4f..e8b49600a57 100644 --- a/web/app/components/base/agent-log-modal/index.stories.tsx +++ b/web/app/components/base/agent-log-modal/index.stories.tsx @@ -3,7 +3,7 @@ import type { IChatItem } from '@/app/components/base/chat/chat/type' import type { AgentLogDetailResponse } from '@/models/log' import { useEffect, useRef } from 'react' import { useStore as useAppStore } from '@/app/components/app/store' -import { ToastProvider } from '@/app/components/base/toast' +import { ToastHost } from '@/app/components/base/ui/toast' import AgentLogModal from '.' const MOCK_RESPONSE: AgentLogDetailResponse = { @@ -109,7 +109,8 @@ const AgentLogModalDemo = ({ }, [setAppDetail]) return ( - + <> +
-
+ ) } diff --git a/web/app/components/base/audio-btn/__tests__/audio.spec.ts b/web/app/components/base/audio-btn/__tests__/audio.spec.ts index 00ffea2dfb0..4399cb40fd0 100644 --- a/web/app/components/base/audio-btn/__tests__/audio.spec.ts +++ b/web/app/components/base/audio-btn/__tests__/audio.spec.ts @@ -6,9 +6,9 @@ import AudioPlayer from '../audio' const mockToastNotify = vi.hoisted(() => vi.fn()) const mockTextToAudioStream = vi.hoisted(() => vi.fn()) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: (...args: unknown[]) => mockToastNotify(...args), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: (message: string) => mockToastNotify({ type: 'error', message }), }, })) diff --git a/web/app/components/base/audio-btn/audio.ts b/web/app/components/base/audio-btn/audio.ts index abfcad7c2f7..5afe2bb656f 100644 --- a/web/app/components/base/audio-btn/audio.ts +++ b/web/app/components/base/audio-btn/audio.ts @@ -1,4 +1,4 @@ -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { AppSourceType, textToAudioStream } from '@/service/share' declare global { @@ -7,7 +7,6 @@ declare global { ManagedMediaSource: any } } - export default class AudioPlayer { mediaSource: MediaSource | null audio: HTMLAudioElement @@ -22,7 +21,6 @@ export default class AudioPlayer { url: string isPublic: boolean callback: ((event: string) => void) | null - constructor(streamUrl: string, isPublic: boolean, msgId: string | undefined, msgContent: string | null | undefined, voice: string | undefined, callback: ((event: string) => void) | null) { this.audioContext = new AudioContext() this.msgId = msgId @@ -31,14 +29,10 @@ export default class AudioPlayer { this.isPublic = isPublic this.voice = voice this.callback = callback - // Compatible with iphone ios17 ManagedMediaSource const MediaSource = window.ManagedMediaSource || window.MediaSource if (!MediaSource) { - Toast.notify({ - message: 'Your browser does not support audio streaming, if you are using an iPhone, please update to iOS 17.1 or later.', - type: 'error', - }) + toast.error('Your browser does not support audio streaming, if you are using an iPhone, please update to iOS 17.1 or later.') } this.mediaSource = MediaSource ? new MediaSource() : null this.audio = new Audio() @@ -49,7 +43,6 @@ export default class AudioPlayer { } this.audio.src = this.mediaSource ? URL.createObjectURL(this.mediaSource) : '' this.audio.autoplay = true - const source = this.audioContext.createMediaElementSource(this.audio) source.connect(this.audioContext.destination) this.listenMediaSource('audio/mpeg') @@ -63,7 +56,6 @@ export default class AudioPlayer { this.mediaSource?.addEventListener('sourceopen', () => { if (this.sourceBuffer) return - this.sourceBuffer = this.mediaSource?.addSourceBuffer(contentType) }) } @@ -106,22 +98,18 @@ export default class AudioPlayer { voice: this.voice, text: this.msgContent, }) - if (audioResponse.status !== 200) { this.isLoadData = false if (this.callback) this.callback('error') } - const reader = audioResponse.body.getReader() while (true) { const { value, done } = await reader.read() - if (done) { this.receiveAudioData(value) break } - this.receiveAudioData(value) } } @@ -167,7 +155,6 @@ export default class AudioPlayer { this.theEndOfStream() clearInterval(timer) } - if (this.cacheBuffers.length && !this.sourceBuffer?.updating) { const arrayBuffer = this.cacheBuffers.shift()! this.sourceBuffer?.appendBuffer(arrayBuffer) @@ -180,7 +167,6 @@ export default class AudioPlayer { this.finishStream() return } - const audioContent = Buffer.from(audio, 'base64') this.receiveAudioData(new Uint8Array(audioContent)) if (play) { @@ -196,7 +182,6 @@ export default class AudioPlayer { this.callback?.('play') } else if (this.audio.played) { /* empty */ } - else { this.audio.play() this.callback?.('play') @@ -221,7 +206,6 @@ export default class AudioPlayer { this.finishStream() return } - if (this.sourceBuffer?.updating) { this.cacheBuffers.push(audioData) } diff --git a/web/app/components/base/audio-gallery/AudioPlayer.tsx b/web/app/components/base/audio-gallery/AudioPlayer.tsx index cbf50ddc13f..5a0a753ecf5 100644 --- a/web/app/components/base/audio-gallery/AudioPlayer.tsx +++ b/web/app/components/base/audio-gallery/AudioPlayer.tsx @@ -1,7 +1,7 @@ import { t } from 'i18next' import * as React from 'react' import { useCallback, useEffect, useRef, useState } from 'react' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import useTheme from '@/hooks/use-theme' import { Theme } from '@/types/app' import { cn } from '@/utils/classnames' @@ -10,7 +10,6 @@ type AudioPlayerProps = { src?: string // Keep backward compatibility srcs?: string[] // Support multiple sources } - const AudioPlayer: React.FC = ({ src, srcs }) => { const [isPlaying, setIsPlaying] = useState(false) const [currentTime, setCurrentTime] = useState(0) @@ -23,43 +22,34 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { const [hoverTime, setHoverTime] = useState(0) const [isAudioAvailable, setIsAudioAvailable] = useState(true) const { theme } = useTheme() - useEffect(() => { const audio = audioRef.current /* v8 ignore next 2 - @preserve */ if (!audio) return - const handleError = () => { setIsAudioAvailable(false) } - const setAudioData = () => { setDuration(audio.duration) } - const setAudioTime = () => { setCurrentTime(audio.currentTime) } - const handleProgress = () => { if (audio.buffered.length > 0) setBufferedTime(audio.buffered.end(audio.buffered.length - 1)) } - const handleEnded = () => { setIsPlaying(false) } - audio.addEventListener('loadedmetadata', setAudioData) audio.addEventListener('timeupdate', setAudioTime) audio.addEventListener('progress', handleProgress) audio.addEventListener('ended', handleEnded) audio.addEventListener('error', handleError) - // Preload audio metadata audio.load() - // Use the first source or src to generate waveform const primarySrc = srcs?.[0] || src if (primarySrc) { @@ -76,17 +66,12 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { } } }, [src, srcs]) - const generateWaveformData = async (audioSrc: string) => { if (!window.AudioContext && !(window as any).webkitAudioContext) { setIsAudioAvailable(false) - Toast.notify({ - type: 'error', - message: 'Web Audio API is not supported in this browser', - }) + toast.error('Web Audio API is not supported in this browser') return null } - const primarySrc = srcs?.[0] || src const url = primarySrc ? new URL(primarySrc) : null const isHttp = url ? (url.protocol === 'http:' || url.protocol === 'https:') : false @@ -94,53 +79,43 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { setIsAudioAvailable(false) return null } - const audioContext = new (window.AudioContext || (window as any).webkitAudioContext)() const samples = 70 - try { const response = await fetch(audioSrc, { mode: 'cors' }) if (!response || !response.ok) { setIsAudioAvailable(false) return null } - const arrayBuffer = await response.arrayBuffer() const audioBuffer = await audioContext.decodeAudioData(arrayBuffer) const channelData = audioBuffer.getChannelData(0) const blockSize = Math.floor(channelData.length / samples) const waveformData: number[] = [] - for (let i = 0; i < samples; i++) { let sum = 0 for (let j = 0; j < blockSize; j++) sum += Math.abs(channelData[i * blockSize + j]) - // Apply nonlinear scaling to enhance small amplitudes waveformData.push((sum / blockSize) * 5) } - // Normalized waveform data const maxAmplitude = Math.max(...waveformData) const normalizedWaveform = waveformData.map(amp => amp / maxAmplitude) - setWaveformData(normalizedWaveform) setIsAudioAvailable(true) } catch { const waveform: number[] = [] let prevValue = Math.random() - for (let i = 0; i < samples; i++) { const targetValue = Math.random() const interpolatedValue = prevValue + (targetValue - prevValue) * 0.3 waveform.push(interpolatedValue) prevValue = interpolatedValue } - const maxAmplitude = Math.max(...waveform) const randomWaveform = waveform.map(amp => amp / maxAmplitude) - setWaveformData(randomWaveform) setIsAudioAvailable(true) } @@ -148,7 +123,6 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { await audioContext.close() } } - const togglePlay = useCallback(() => { const audio = audioRef.current if (audio && isAudioAvailable) { @@ -160,99 +134,75 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { setHasStartedPlaying(true) audio.play().catch(error => console.error('Error playing audio:', error)) } - setIsPlaying(!isPlaying) } else { - Toast.notify({ - type: 'error', - message: 'Audio element not found', - }) + toast.error('Audio element not found') setIsAudioAvailable(false) } }, [isAudioAvailable, isPlaying]) - const handleCanvasInteraction = useCallback((e: React.MouseEvent | React.TouchEvent) => { e.preventDefault() - const getClientX = (event: React.MouseEvent | React.TouchEvent): number => { if ('touches' in event) return event.touches[0].clientX return event.clientX } - const updateProgress = (clientX: number) => { const canvas = canvasRef.current const audio = audioRef.current if (!canvas || !audio) return - const rect = canvas.getBoundingClientRect() const percent = Math.min(Math.max(0, clientX - rect.left), rect.width) / rect.width const newTime = percent * duration - // Removes the buffer check, allowing drag to any location audio.currentTime = newTime setCurrentTime(newTime) - if (!isPlaying) { setIsPlaying(true) audio.play().catch((error) => { - Toast.notify({ - type: 'error', - message: `Error playing audio: ${error}`, - }) + toast.error(`Error playing audio: ${error}`) setIsPlaying(false) }) } } - updateProgress(getClientX(e)) }, [duration, isPlaying]) - const formatTime = (time: number) => { const minutes = Math.floor(time / 60) const seconds = Math.floor(time % 60) return `${minutes}:${seconds.toString().padStart(2, '0')}` } - const drawWaveform = useCallback(() => { const canvas = canvasRef.current /* v8 ignore next 2 - @preserve */ if (!canvas) return - const ctx = canvas.getContext('2d') if (!ctx) return - const width = canvas.width const height = canvas.height const data = waveformData - ctx.clearRect(0, 0, width, height) - const barWidth = width / data.length const playedWidth = (currentTime / duration) * width const cornerRadius = 2 - // Draw waveform bars data.forEach((value, index) => { let color - if (index * barWidth <= playedWidth) color = theme === Theme.light ? '#296DFF' : '#84ABFF' else if ((index * barWidth / width) * duration <= hoverTime) color = theme === Theme.light ? 'rgba(21,90,239,.40)' : 'rgba(200, 206, 218, 0.28)' else color = theme === Theme.light ? 'rgba(21,90,239,.20)' : 'rgba(200, 206, 218, 0.14)' - const barHeight = value * height const rectX = index * barWidth const rectY = (height - barHeight) / 2 const rectWidth = barWidth * 0.5 const rectHeight = barHeight - ctx.lineWidth = 1 ctx.fillStyle = color if (ctx.roundRect) { @@ -265,27 +215,22 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { } }) }, [currentTime, duration, hoverTime, theme, waveformData]) - useEffect(() => { drawWaveform() }, [drawWaveform, bufferedTime, hasStartedPlaying]) - const handleMouseMove = useCallback((e: React.MouseEvent | React.TouchEvent) => { const canvas = canvasRef.current const audio = audioRef.current if (!canvas || !audio) return - const clientX = 'touches' in e ? e.touches[0]?.clientX ?? e.changedTouches[0]?.clientX : e.clientX if (clientX === undefined) return - const rect = canvas.getBoundingClientRect() const percent = Math.min(Math.max(0, clientX - rect.left), rect.width) / rect.width const time = percent * duration - // Check if the hovered position is within a buffered range before updating hoverTime for (let i = 0; i < audio.buffered.length; i++) { if (time >= audio.buffered.start(i) && time <= audio.buffered.end(i)) { @@ -294,38 +239,20 @@ const AudioPlayer: React.FC = ({ src, srcs }) => { } } }, [duration]) - return (
- ) } - const placeholder = '' const editAreaClassName = 'focus:outline-none bg-transparent text-sm' - const textAreaContent = (
!readonly && setIsEditing(true)}> {isEditing @@ -134,10 +105,10 @@ const BlockInput: FC = ({ onBlur={() => { blur() setIsEditing(false) - // click confirm also make blur. Then outer value is change. So below code has problem. - // setTimeout(() => { - // handleCancel() - // }, 1000) + // click confirm also make blur. Then outer value is change. So below code has problem. + // setTimeout(() => { + // handleCancel() + // }, 1000) }} />
@@ -145,7 +116,6 @@ const BlockInput: FC = ({ : }
) - return (
{textAreaContent} @@ -159,5 +129,4 @@ const BlockInput: FC = ({
) } - export default React.memo(BlockInput) diff --git a/web/app/components/base/chat/chat-with-history/__tests__/hooks.spec.tsx b/web/app/components/base/chat/chat-with-history/__tests__/hooks.spec.tsx index b004a1bee67..f4c8ef0c458 100644 --- a/web/app/components/base/chat/chat-with-history/__tests__/hooks.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/__tests__/hooks.spec.tsx @@ -4,7 +4,7 @@ import type { InstalledApp } from '@/models/explore' import type { AppConversationData, AppData, AppMeta, ConversationItem } from '@/models/share' import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { act, renderHook, waitFor } from '@testing-library/react' -import { ToastProvider } from '@/app/components/base/toast' +import { ToastHost } from '@/app/components/base/ui/toast' import { AppSourceType, delConversation, @@ -95,7 +95,8 @@ const createQueryClient = () => new QueryClient({ const createWrapper = (queryClient: QueryClient) => { return ({ children }: { children: ReactNode }) => ( - {children} + + {children} ) } diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index 23936111ce0..e6f5657ff5b 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -1,47 +1,21 @@ import type { ExtraContent } from '../chat/type' -import type { - Callback, - ChatConfig, - ChatItem, - Feedback, -} from '../types' +import type { Callback, ChatConfig, ChatItem, Feedback } from '../types' import type { InstalledApp } from '@/models/explore' -import type { - AppData, - ConversationItem, -} from '@/models/share' +import type { AppData, ConversationItem } from '@/models/share' import type { HumanInputFilledFormData, HumanInputFormData } from '@/types/workflow' import { useLocalStorageState } from 'ahooks' import { noop } from 'es-toolkit/function' import { produce } from 'immer' -import { - useCallback, - useEffect, - useMemo, - useRef, - useState, -} from 'react' +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' -import { useToastContext } from '@/app/components/base/toast/context' +import { toast } from '@/app/components/base/ui/toast' import { InputVarType } from '@/app/components/workflow/types' import { useWebAppStore } from '@/context/web-app-context' import { useAppFavicon } from '@/hooks/use-app-favicon' import { changeLanguage } from '@/i18n-config/client' -import { - AppSourceType, - delConversation, - pinConversation, - renameConversation, - unpinConversation, - updateFeedback, -} from '@/service/share' -import { - useInvalidateShareConversations, - useShareChatList, - useShareConversationName, - useShareConversations, -} from '@/service/use-share' +import { AppSourceType, delConversation, pinConversation, renameConversation, unpinConversation, updateFeedback } from '@/service/share' +import { useInvalidateShareConversations, useShareChatList, useShareConversationName, useShareConversations } from '@/service/use-share' import { TransferMethod } from '@/types/app' import { addFileInfos, sortAgentSorts } from '../../../tools/utils' import { CONVERSATION_ID_INFO } from '../constants' @@ -93,14 +67,12 @@ function getFormattedChatList(messages: any[]) { }) return newChatList } - export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { const isInstalledApp = useMemo(() => !!installedAppInfo, [installedAppInfo]) const appSourceType = isInstalledApp ? AppSourceType.installedApp : AppSourceType.webApp const appInfo = useWebAppStore(s => s.appInfo) const appParams = useWebAppStore(s => s.appParams) const appMeta = useWebAppStore(s => s.appMeta) - useAppFavicon({ enable: !installedAppInfo, icon_type: appInfo?.site.icon_type, @@ -108,7 +80,6 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { icon_background: appInfo?.site.icon_background, icon_url: appInfo?.site.icon_url, }) - const appData = useMemo(() => { if (isInstalledApp) { const { id, app } = installedAppInfo! @@ -129,18 +100,15 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { custom_config: null, } as AppData } - return appInfo }, [isInstalledApp, installedAppInfo, appInfo]) const appId = useMemo(() => appData?.app_id, [appData]) - const [userId, setUserId] = useState() useEffect(() => { getProcessedSystemVariablesFromUrlParams().then(({ user_id }) => { setUserId(user_id) }) }, []) - useEffect(() => { const setLocaleFromProps = async () => { if (appData?.site.default_language) @@ -148,7 +116,6 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { } setLocaleFromProps() }, [appData]) - const [sidebarCollapseState, setSidebarCollapseState] = useState(() => { if (typeof window !== 'undefined') { try { @@ -192,15 +159,12 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { }) } }, [appId, conversationIdInfo, setConversationIdInfo, userId]) - const [newConversationId, setNewConversationId] = useState('') const chatShouldReloadKey = useMemo(() => { if (currentConversationId === newConversationId) return '' - return currentConversationId }, [currentConversationId, newConversationId]) - const { data: appPinnedConversationData } = useShareConversations({ appSourceType, appId, @@ -211,10 +175,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { refetchOnWindowFocus: false, refetchOnReconnect: false, }) - const { - data: appConversationData, - isLoading: appConversationDataLoading, - } = useShareConversations({ + const { data: appConversationData, isLoading: appConversationDataLoading } = useShareConversations({ appSourceType, appId, pinned: false, @@ -224,10 +185,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { refetchOnWindowFocus: false, refetchOnReconnect: false, }) - const { - data: appChatListData, - isLoading: appChatListDataLoading, - } = useShareChatList({ + const { data: appChatListData, isLoading: appChatListDataLoading } = useShareChatList({ conversationId: chatShouldReloadKey, appSourceType, appId, @@ -237,18 +195,12 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { refetchOnReconnect: false, }) const invalidateShareConversations = useInvalidateShareConversations() - const [clearChatList, setClearChatList] = useState(false) const [isResponding, setIsResponding] = useState(false) - const appPrevChatTree = useMemo( - () => (currentConversationId && appChatListData?.data.length) - ? buildChatItemTree(getFormattedChatList(appChatListData.data)) - : [], - [appChatListData, currentConversationId], - ) - + const appPrevChatTree = useMemo(() => (currentConversationId && appChatListData?.data.length) + ? buildChatItemTree(getFormattedChatList(appChatListData.data)) + : [], [appChatListData, currentConversationId]) const [showNewConversationItemInList, setShowNewConversationItemInList] = useState(false) - const pinnedConversationList = useMemo(() => { return appPinnedConversationData?.data || [] }, [appPinnedConversationData]) @@ -267,7 +219,6 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { let value = initInputs[item.paragraph.variable] if (value && item.paragraph.max_length && value.length > item.paragraph.max_length) value = value.slice(0, item.paragraph.max_length) - return { ...item.paragraph, default: value || item.default || item.paragraph.default, @@ -282,7 +233,6 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { type: 'number', } } - if (item.checkbox) { const preset = initInputs[item.checkbox.variable] === true return { @@ -291,7 +241,6 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { type: 'checkbox', } } - if (item.select) { const isInputInOptions = item.select.options.includes(initInputs[item.select.variable]) return { @@ -300,32 +249,27 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { type: 'select', } } - if (item['file-list']) { return { ...item['file-list'], type: 'file-list', } } - if (item.file) { return { ...item.file, type: 'file', } } - if (item.json_object) { return { ...item.json_object, type: 'json_object', } } - let value = initInputs[item['text-input'].variable] if (value && item['text-input'].max_length && value.length > item['text-input'].max_length) value = value.slice(0, item['text-input'].max_length) - return { ...item['text-input'], default: value || item.default || item['text-input'].default, @@ -333,11 +277,9 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { } }) }, [initInputs, appParams]) - const allInputsHidden = useMemo(() => { return inputsForms.length > 0 && inputsForms.every(item => item.hide === true) }, [inputsForms]) - useEffect(() => { // init inputs from url params (async () => { @@ -347,16 +289,13 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { setInitUserVariables(userVariables) })() }, []) - useEffect(() => { const conversationInputs: Record = {} - inputsForms.forEach((item: any) => { conversationInputs[item.variable] = item.default || null }) handleNewConversationInputsChange(conversationInputs) }, [handleNewConversationInputsChange, inputsForms]) - const { data: newConversation } = useShareConversationName({ conversationId: newConversationId, appSourceType, @@ -372,7 +311,6 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { }, [appConversationData, appConversationDataLoading]) const conversationList = useMemo(() => { const data = originConversationList.slice() - if (showNewConversationItemInList && data[0]?.id !== '') { data.unshift({ id: '', @@ -383,12 +321,10 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { } return data }, [originConversationList, showNewConversationItemInList, t]) - useEffect(() => { if (newConversation) { setOriginConversationList(produce((draft) => { const index = draft.findIndex(item => item.id === newConversation.id) - if (index > -1) draft[index] = newConversation else @@ -396,16 +332,12 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { })) } }, [newConversation]) - const currentConversationItem = useMemo(() => { let conversationItem = conversationList.find(item => item.id === currentConversationId) - if (!conversationItem && pinnedConversationList.length) conversationItem = pinnedConversationList.find(item => item.id === currentConversationId) - return conversationItem }, [conversationList, currentConversationId, pinnedConversationList]) - const currentConversationLatestInputs = useMemo(() => { if (!currentConversationId || !appChatListData?.data.length) return newConversationInputsRef.current || {} @@ -416,12 +348,9 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { if (currentConversationItem) setCurrentConversationInputs(currentConversationLatestInputs || {}) }, [currentConversationItem, currentConversationLatestInputs]) - - const { notify } = useToastContext() const checkInputsRequired = useCallback((silent?: boolean) => { if (allInputsHidden) return true - let hasEmptyInput = '' let fileIsUploading = false const requiredVars = inputsForms.filter(({ required, type }) => required && type !== InputVarType.checkbox) @@ -429,13 +358,10 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { requiredVars.forEach(({ variable, label, type }) => { if (hasEmptyInput) return - if (fileIsUploading) return - if (!newConversationInputsRef.current[variable] && !silent) hasEmptyInput = label as string - if ((type === InputVarType.singleFile || type === InputVarType.multiFiles) && newConversationInputsRef.current[variable] && !silent) { const files = newConversationInputsRef.current[variable] if (Array.isArray(files)) @@ -445,26 +371,25 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { } }) } - if (hasEmptyInput) { - notify({ type: 'error', message: t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: hasEmptyInput }) }) + toast.error(t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: hasEmptyInput })) return false } - if (fileIsUploading) { - notify({ type: 'info', message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }) }) + toast.info(t('errorMessage.waitForFileUpload', { ns: 'appDebug' })) return } - return true - }, [inputsForms, notify, t, allInputsHidden]) + }, [inputsForms, t, allInputsHidden]) const handleStartChat = useCallback((callback: any) => { if (checkInputsRequired()) { setShowNewConversationItemInList(true) callback?.() } }, [setShowNewConversationItemInList, checkInputsRequired]) - const currentChatInstanceRef = useRef<{ handleStop: () => void }>({ handleStop: noop }) + const currentChatInstanceRef = useRef<{ + handleStop: () => void + }>({ handleStop: noop }) const handleChangeConversation = useCallback((conversationId: string) => { currentChatInstanceRef.current.handleStop() setNewConversationId('') @@ -486,76 +411,48 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { const handleUpdateConversationList = useCallback(() => { invalidateShareConversations() }, [invalidateShareConversations]) - const handlePinConversation = useCallback(async (conversationId: string) => { await pinConversation(appSourceType, appId, conversationId) - notify({ type: 'success', message: t('api.success', { ns: 'common' }) }) + toast.success(t('api.success', { ns: 'common' })) handleUpdateConversationList() - }, [appSourceType, appId, notify, t, handleUpdateConversationList]) - + }, [appSourceType, appId, t, handleUpdateConversationList]) const handleUnpinConversation = useCallback(async (conversationId: string) => { await unpinConversation(appSourceType, appId, conversationId) - notify({ type: 'success', message: t('api.success', { ns: 'common' }) }) + toast.success(t('api.success', { ns: 'common' })) handleUpdateConversationList() - }, [appSourceType, appId, notify, t, handleUpdateConversationList]) - + }, [appSourceType, appId, t, handleUpdateConversationList]) const [conversationDeleting, setConversationDeleting] = useState(false) - const handleDeleteConversation = useCallback(async ( - conversationId: string, - { - onSuccess, - }: Callback, - ) => { + const handleDeleteConversation = useCallback(async (conversationId: string, { onSuccess }: Callback) => { if (conversationDeleting) return - try { setConversationDeleting(true) await delConversation(appSourceType, appId, conversationId) - notify({ type: 'success', message: t('api.success', { ns: 'common' }) }) + toast.success(t('api.success', { ns: 'common' })) onSuccess() } finally { setConversationDeleting(false) } - if (conversationId === currentConversationId) handleNewConversation() - handleUpdateConversationList() - }, [isInstalledApp, appId, notify, t, handleUpdateConversationList, handleNewConversation, currentConversationId, conversationDeleting]) - + }, [isInstalledApp, appId, t, handleUpdateConversationList, handleNewConversation, currentConversationId, conversationDeleting]) const [conversationRenaming, setConversationRenaming] = useState(false) - const handleRenameConversation = useCallback(async ( - conversationId: string, - newName: string, - { - onSuccess, - }: Callback, - ) => { + const handleRenameConversation = useCallback(async (conversationId: string, newName: string, { onSuccess }: Callback) => { if (conversationRenaming) return - if (!newName.trim()) { - notify({ - type: 'error', - message: t('chat.conversationNameCanNotEmpty', { ns: 'common' }), - }) + toast.error(t('chat.conversationNameCanNotEmpty', { ns: 'common' })) return } - setConversationRenaming(true) try { await renameConversation(appSourceType, appId, conversationId, newName) - - notify({ - type: 'success', - message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }), - }) + toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) setOriginConversationList(produce((draft) => { const index = originConversationList.findIndex(item => item.id === conversationId) const item = draft[index] - draft[index] = { ...item, name: newName, @@ -566,20 +463,17 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { finally { setConversationRenaming(false) } - }, [isInstalledApp, appId, notify, t, conversationRenaming, originConversationList]) - + }, [isInstalledApp, appId, t, conversationRenaming, originConversationList]) const handleNewConversationCompleted = useCallback((newConversationId: string) => { setNewConversationId(newConversationId) handleConversationIdInfoChange(newConversationId) setShowNewConversationItemInList(false) invalidateShareConversations() }, [handleConversationIdInfoChange, invalidateShareConversations]) - const handleFeedback = useCallback(async (messageId: string, feedback: Feedback) => { await updateFeedback({ url: `/messages/${messageId}/feedbacks`, body: { rating: feedback.rating, content: feedback.content } }, appSourceType, appId) - notify({ type: 'success', message: t('api.success', { ns: 'common' }) }) - }, [appSourceType, appId, t, notify]) - + toast.success(t('api.success', { ns: 'common' })) + }, [appSourceType, appId, t]) return { isInstalledApp, appId, diff --git a/web/app/components/base/chat/chat/__tests__/check-input-forms-hooks.spec.tsx b/web/app/components/base/chat/chat/__tests__/check-input-forms-hooks.spec.tsx index 6afbc26582a..1e96c1f798f 100644 --- a/web/app/components/base/chat/chat/__tests__/check-input-forms-hooks.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/check-input-forms-hooks.spec.tsx @@ -5,8 +5,16 @@ import { TransferMethod } from '@/types/app' import { useCheckInputsForms } from '../check-input-forms-hooks' const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: () => ({ notify: mockNotify }), +vi.mock('@/app/components/base/ui/toast', () => ({ + default: { + notify: (args: unknown) => mockNotify(args), + }, + toast: { + success: (message: string) => mockNotify({ type: 'success', message }), + error: (message: string) => mockNotify({ type: 'error', message }), + warning: (message: string) => mockNotify({ type: 'warning', message }), + info: (message: string) => mockNotify({ type: 'info', message }), + }, })) describe('useCheckInputsForms', () => { diff --git a/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx b/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx index 92fa9ea42ee..89327341de8 100644 --- a/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/hooks.spec.tsx @@ -20,8 +20,14 @@ vi.mock('@/app/components/base/audio-btn/audio.player.manager', () => ({ }, })) -vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: () => ({ notify: vi.fn() }), +vi.mock('@/app/components/base/ui/toast', () => ({ + default: { notify: vi.fn() }, + toast: { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + }, })) vi.mock('@/hooks/use-timestamp', () => ({ diff --git a/web/app/components/base/chat/chat/__tests__/question.spec.tsx b/web/app/components/base/chat/chat/__tests__/question.spec.tsx index e9392adb8a5..9d49be3a156 100644 --- a/web/app/components/base/chat/chat/__tests__/question.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/question.spec.tsx @@ -5,7 +5,7 @@ import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import copy from 'copy-to-clipboard' import * as React from 'react' -import Toast from '../../../toast' +import { toast } from '@/app/components/base/ui/toast' import { ThemeBuilder } from '../../embedded-chatbot/theme/theme-context' import { ChatContextProvider } from '../context-provider' import Question from '../question' @@ -179,7 +179,7 @@ describe('Question component', () => { it('should call copy-to-clipboard and show a toast when copy action is clicked', async () => { const user = userEvent.setup() - const toastSpy = vi.spyOn(Toast, 'notify') + const toastSpy = vi.spyOn(toast, 'success').mockReturnValue('toast-success') renderWithProvider(makeItem()) diff --git a/web/app/components/base/chat/chat/answer/__tests__/operation.spec.tsx b/web/app/components/base/chat/chat/answer/__tests__/operation.spec.tsx index 836397a5862..588b261323e 100644 --- a/web/app/components/base/chat/chat/answer/__tests__/operation.spec.tsx +++ b/web/app/components/base/chat/chat/answer/__tests__/operation.spec.tsx @@ -29,7 +29,7 @@ const { vi.mock('copy-to-clipboard', () => ({ default: vi.fn() })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/ui/toast', () => ({ default: { notify: vi.fn() }, })) diff --git a/web/app/components/base/chat/chat/answer/operation.tsx b/web/app/components/base/chat/chat/answer/operation.tsx index f0d077975c0..26a4b6bd99e 100644 --- a/web/app/components/base/chat/chat/answer/operation.tsx +++ b/web/app/components/base/chat/chat/answer/operation.tsx @@ -17,8 +17,8 @@ import AnnotationCtrlButton from '@/app/components/base/features/new-feature-pan import Modal from '@/app/components/base/modal/modal' import NewAudioButton from '@/app/components/base/new-audio-button' import Textarea from '@/app/components/base/textarea' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import { cn } from '@/utils/classnames' import { useChatContext } from '../context' @@ -302,7 +302,7 @@ const Operation: FC = ({ { copy(content) - Toast.notify({ type: 'success', message: t('actionMsg.copySuccessfully', { ns: 'common' }) }) + toast.success(t('actionMsg.copySuccessfully', { ns: 'common' })) }} data-testid="copy-btn" > diff --git a/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx b/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx index f628b7de827..1a8dd55f616 100644 --- a/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx @@ -175,8 +175,16 @@ vi.mock('@/app/components/base/features/hooks', () => ({ // --------------------------------------------------------------------------- // Toast context // --------------------------------------------------------------------------- -vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: () => ({ notify: mockNotify, close: vi.fn() }), +vi.mock('@/app/components/base/ui/toast', () => ({ + default: { + notify: (args: unknown) => mockNotify(args), + }, + toast: { + success: (message: string) => mockNotify({ type: 'success', message }), + error: (message: string) => mockNotify({ type: 'error', message }), + warning: (message: string) => mockNotify({ type: 'warning', message }), + info: (message: string) => mockNotify({ type: 'info', message }), + }, })) // --------------------------------------------------------------------------- diff --git a/web/app/components/base/chat/chat/chat-input-area/index.tsx b/web/app/components/base/chat/chat/chat-input-area/index.tsx index 8b5ca185850..0ea928d6d6c 100644 --- a/web/app/components/base/chat/chat/chat-input-area/index.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/index.tsx @@ -1,28 +1,18 @@ import type { Theme } from '../../embedded-chatbot/theme/theme-context' -import type { - EnableType, - OnSend, -} from '../../types' +import type { EnableType, OnSend } from '../../types' import type { InputForm } from '../type' import type { FileUpload } from '@/app/components/base/features/types' import { noop } from 'es-toolkit/function' import { decode } from 'html-entities' import Recorder from 'js-audio-recorder' -import { - useCallback, - useRef, - useState, -} from 'react' +import { useCallback, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Textarea from 'react-textarea-autosize' import FeatureBar from '@/app/components/base/features/new-feature-panel/feature-bar' import { FileListInChatInput } from '@/app/components/base/file-uploader' import { useFile } from '@/app/components/base/file-uploader/hooks' -import { - FileContextProvider, - useFileStore, -} from '@/app/components/base/file-uploader/store' -import { useToastContext } from '@/app/components/base/toast/context' +import { FileContextProvider, useFileStore } from '@/app/components/base/file-uploader/store' +import { toast } from '@/app/components/base/ui/toast' import VoiceInput from '@/app/components/base/voice-input' import { TransferMethod } from '@/types/app' import { cn } from '@/utils/classnames' @@ -53,71 +43,34 @@ type ChatInputAreaProps = { */ sendOnEnter?: boolean } -const ChatInputArea = ({ - readonly, - botName, - showFeatureBar, - showFileUpload, - featureBarDisabled, - onFeatureBarClick, - visionConfig, - speechToTextConfig = { enabled: true }, - onSend, - inputs = {}, - inputsForm = [], - theme, - isResponding, - disabled, - sendOnEnter = true, -}: ChatInputAreaProps) => { +const ChatInputArea = ({ readonly, botName, showFeatureBar, showFileUpload, featureBarDisabled, onFeatureBarClick, visionConfig, speechToTextConfig = { enabled: true }, onSend, inputs = {}, inputsForm = [], theme, isResponding, disabled, sendOnEnter = true }: ChatInputAreaProps) => { const { t } = useTranslation() - const { notify } = useToastContext() - const { - wrapperRef, - textareaRef, - textValueRef, - holdSpaceRef, - handleTextareaResize, - isMultipleLine, - } = useTextAreaHeight() + const { wrapperRef, textareaRef, textValueRef, holdSpaceRef, handleTextareaResize, isMultipleLine } = useTextAreaHeight() const [query, setQuery] = useState('') const [showVoiceInput, setShowVoiceInput] = useState(false) const filesStore = useFileStore() - const { - handleDragFileEnter, - handleDragFileLeave, - handleDragFileOver, - handleDropFile, - handleClipboardPasteFile, - isDragActive, - } = useFile(visionConfig!, false) + const { handleDragFileEnter, handleDragFileLeave, handleDragFileOver, handleDropFile, handleClipboardPasteFile, isDragActive } = useFile(visionConfig!, false) const { checkInputsForm } = useCheckInputsForms() const historyRef = useRef(['']) const [currentIndex, setCurrentIndex] = useState(-1) const isComposingRef = useRef(false) - - const handleQueryChange = useCallback( - (value: string) => { - setQuery(value) - setTimeout(handleTextareaResize, 0) - }, - [handleTextareaResize], - ) - + const handleQueryChange = useCallback((value: string) => { + setQuery(value) + setTimeout(handleTextareaResize, 0) + }, [handleTextareaResize]) const handleSend = () => { if (isResponding) { - notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) }) + toast.info(t('errorMessage.waitForResponse', { ns: 'appDebug' })) return } - if (onSend) { const { files, setFiles } = filesStore.getState() if (files.some(item => item.transferMethod === TransferMethod.local_file && !item.uploadedId)) { - notify({ type: 'info', message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }) }) + toast.info(t('errorMessage.waitForFileUpload', { ns: 'appDebug' })) return } if (!query || !query.trim()) { - notify({ type: 'info', message: t('errorMessage.queryRequired', { ns: 'appAnnotation' }) }) + toast.info(t('errorMessage.queryRequired', { ns: 'appAnnotation' })) return } if (checkInputsForm(inputs, inputsForm)) { @@ -145,7 +98,6 @@ const ChatInputArea = ({ const isSendCombo = sendOnEnter ? (e.key === 'Enter' && !e.shiftKey) : (e.key === 'Enter' && e.shiftKey) - if (isSendCombo && !e.nativeEvent.isComposing) { // if isComposing, exit if (isComposingRef.current) @@ -176,101 +128,36 @@ const ChatInputArea = ({ } } } - const handleShowVoiceInput = useCallback(() => { (Recorder as any).getPermission().then(() => { setShowVoiceInput(true) }, () => { - notify({ type: 'error', message: t('voiceInput.notAllow', { ns: 'common' }) }) + toast.error(t('voiceInput.notAllow', { ns: 'common' })) }) - }, [t, notify]) - - const operation = ( - - ) - + }, [t]) + const operation = () return ( <> -
+
-
+
-
+
{query}
-