Compare commits

..

29 Commits

Author SHA1 Message Date
Yansong Zhang
ce644e1549 add console_ns import 2026-02-04 16:03:48 +08:00
Stephen Zhou
468990cc39 fix: remove api reference doc link en prefix (#31910) 2026-02-04 14:58:26 +08:00
Coding On Star
64e769f96e refactor: plugin detail panel components for better maintainability and code organization. (#31870)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-02-04 14:51:47 +08:00
Sean Kenneth Doherty
778aabb485 refactor(api): replace reqparse with Pydantic models in trial.py (#31789)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-02-04 14:36:52 +08:00
Stephen Zhou
d8402f686e fix: base url in client (#31902)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-04 12:09:22 +08:00
Tomo
8bd8dee767 fix(docker): improve IRIS data persistence with proper Durable %SYS (#31901)
Co-authored-by: Tomo Okuyama <tomo.okuyama@intersystems.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-04 11:39:26 +08:00
Tomo
05f2764d7c fix(docker): persist IRIS data across container recreation using Durable %SYS (#31899)
Co-authored-by: Tomo Okuyama <tomo.okuyama@intersystems.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-04 09:57:46 +08:00
Asuka Minato
f5d6c250ed fix: "refactor: port api/controllers/console/tag/tags.py to ov3" (#31887)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-03 22:18:53 +08:00
niveshdandyan
45daec7541 refactor: replace line-clamp package with native CSS (#31877)
Co-authored-by: OSS Contributor <oss-contributor@example.com>
Co-authored-by: Claude (claude-opus-4-5) <noreply@anthropic.com>
Co-authored-by: niveshdandyan <niveshdandyan@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-03 22:14:18 +08:00
盐粒 Yanli
c14a8bb437 chore(dev): use strict bash mode for pytest (#31873) 2026-02-03 19:42:42 +08:00
Stephen Zhou
b76c8fa853 test: fix test (#31880)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-03 18:47:05 +08:00
QuantumGhost
8c3e77cd0c chore: update version to 1.12.0 (#31878) 2026-02-03 18:08:15 +08:00
Stephen Zhou
476946f122 test: fix test (#31869) 2026-02-03 17:43:27 +08:00
Joel
62a698a883 fix: create app from template not support review (#31866) 2026-02-03 16:40:35 +08:00
Coding On Star
ebca36ffbb refactor: update oauth_new_user handling in AppInitializer to use parseAsBoolean (#31862)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-02-03 15:20:26 +08:00
Coding On Star
aa7fe42615 test: enhance CommandSelector and GotoAnythingProvider tests (#31743)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-02-03 13:47:30 +08:00
Stephen Zhou
b55c0ec4de fix: revert "refactor: api/controllers/console/feature.py (test)" (#31850) 2026-02-03 12:26:47 +08:00
dependabot[bot]
8b50c0d920 chore(deps-dev): bump types-psutil from 7.0.0.20251116 to 7.2.2.20260130 in /api (#31814)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-03 09:59:29 +08:00
Asuka Minato
47f8de3f8e refactor: port api/controllers/console/app/annotation.py api/controllers/console/explore/trial.py api/controllers/console/workspace/account.py api/controllers/console/workspace/members.py api/controllers/service_api/app/annotation.py to basemodel (#31833)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-03 09:59:00 +08:00
Asuka Minato
491fa9923b refactor: port api/controllers/console/datasets/data_source.py /datasets/metadata.py /service_api/dataset/metadata.py /nodes/agent/agent_node.py api/core/workflow/nodes/datasource/datasource_node.py api/services/dataset_service.py to match case (#31836) 2026-02-02 21:03:16 +09:00
Asuka Minato
ce2c41bbf5 refactor: port api/controllers/console/datasets/datasets_document.py api/controllers/service_api/app/annotation.py api/core/app/app_config/easy_ui_based_app/agent/manager.py api/core/app/apps/pipeline/pipeline_generator.py api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py to match case (#31832) 2026-02-02 19:07:30 +09:00
Asuka Minato
920db69ef2 refactor: if to match (#31799) 2026-02-02 18:12:03 +09:00
Asuka Minato
ac222a4dd4 refactor: port api/controllers/console/app/audio.py api/controllers/console/app/message.py api/controllers/console/auth/data_source_oauth.py api/controllers/console/auth/forgot_password.py api/controllers/console/workspace/endpoint.py (#30680) 2026-02-02 18:03:07 +09:00
Asuka Minato
840a975fef refactor: add test for api/controllers/console/workspace/tool_pr… (#29886)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-02 14:54:16 +09:00
QuantumGhost
9fb72c151c refactor: "chore: update version to 1.12.0" (#31817) 2026-02-02 11:18:18 +08:00
-LAN-
603a896c49 chore(CODEOWNERS): assign .agents/skills to @hyoban (#31816)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2026-02-02 11:12:04 +08:00
FFXN
41177757e6 fix: summary index bug (#31810)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Yansong Zhang <916125788@qq.com>
Co-authored-by: hj24 <mambahj24@gmail.com>
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-02 09:45:17 +08:00
yyh
4f826b4641 refactor(typing): use enum types for workflow status fields (#31792) 2026-02-02 09:41:34 +08:00
Asuka Minato
3216b67bfa refactor: examples of use match case (#31312)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-01 19:25:54 +09:00
128 changed files with 11493 additions and 4842 deletions

3
.github/CODEOWNERS vendored
View File

@@ -9,6 +9,9 @@
# CODEOWNERS file
/.github/CODEOWNERS @laipz8200 @crazywoola
# Agents
/.agents/skills/ @hyoban
# Docs
/docs/ @crazywoola

View File

@@ -1450,54 +1450,58 @@ def clear_orphaned_file_records(force: bool):
all_ids_in_tables = []
for ids_table in ids_tables:
query = ""
if ids_table["type"] == "uuid":
click.echo(
click.style(
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white"
match ids_table["type"]:
case "uuid":
click.echo(
click.style(
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}",
fg="white",
)
)
)
query = (
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
elif ids_table["type"] == "text":
click.echo(
click.style(
f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}",
fg="white",
c = ids_table["column"]
query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL"
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
case "text":
t = ids_table["table"]
click.echo(
click.style(
f"- Listing file-id-like strings in column {ids_table['column']} in table {t}",
fg="white",
)
)
)
query = (
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
elif ids_table["type"] == "json":
click.echo(
click.style(
(
f"- Listing file-id-like JSON string in column {ids_table['column']} "
f"in table {ids_table['table']}"
),
fg="white",
query = (
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
)
query = (
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
case "json":
click.echo(
click.style(
(
f"- Listing file-id-like JSON string in column {ids_table['column']} "
f"in table {ids_table['table']}"
),
fg="white",
)
)
query = (
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
case _:
pass
click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white"))
except Exception as e:
@@ -1737,59 +1741,18 @@ def file_usage(
if src_filter != src:
continue
if ids_table["type"] == "uuid":
# Direct UUID match
query = (
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
ref_file_id = str(row[1])
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
elif ids_table["type"] in ("text", "json"):
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
match ids_table["type"]:
case "uuid":
# Direct UUID match
query = (
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
ref_file_id = str(row[1])
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
@@ -1812,6 +1775,50 @@ def file_usage(
)
total_count += 1
case "text" | "json":
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
case _:
pass
# Output results
if output_json:
result = {

View File

@@ -1,10 +1,11 @@
from typing import Any, Literal
from flask import abort, make_response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
@@ -16,9 +17,11 @@ from controllers.console.wraps import (
)
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
annotation_fields,
annotation_hit_history_fields,
build_annotation_model,
Annotation,
AnnotationExportList,
AnnotationHitHistory,
AnnotationHitHistoryList,
AnnotationList,
)
from libs.helper import uuid_value
from libs.login import login_required
@@ -89,6 +92,14 @@ reg(CreateAnnotationPayload)
reg(UpdateAnnotationPayload)
reg(AnnotationReplyStatusQuery)
reg(AnnotationFilePayload)
register_schema_models(
console_ns,
Annotation,
AnnotationList,
AnnotationExportList,
AnnotationHitHistory,
AnnotationHitHistoryList,
)
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
@@ -107,10 +118,11 @@ class AnnotationReplyActionApi(Resource):
def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id)
args = AnnotationReplyPayload.model_validate(console_ns.payload)
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@@ -201,33 +213,33 @@ class AnnotationApi(Resource):
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response = AnnotationList(
data=annotation_models,
has_more=len(annotation_list) == limit,
limit=limit,
total=total,
page=page,
)
return response.model_dump(mode="json"), 200
@console_ns.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
@console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
@edit_permission_required
def post(self, app_id):
app_id = str(app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
return annotation
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@login_required
@@ -264,7 +276,7 @@ class AnnotationExportApi(Resource):
@console_ns.response(
200,
"Annotations exported successfully",
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
console_ns.models[AnnotationExportList.__name__],
)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@@ -274,7 +286,8 @@ class AnnotationExportApi(Resource):
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response_data = {"data": marshal(annotation_list, annotation_fields)}
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json")
# Create response with secure headers for CSV export
response = make_response(response_data, 200)
@@ -289,7 +302,7 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.doc("update_delete_annotation")
@console_ns.doc(description="Update or delete an annotation")
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
@console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__])
@console_ns.response(204, "Annotation deleted successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@@ -298,7 +311,6 @@ class AnnotationUpdateDeleteApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
@@ -306,7 +318,7 @@ class AnnotationUpdateDeleteApi(Resource):
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
return annotation
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@login_required
@@ -414,14 +426,7 @@ class AnnotationHitHistoryListApi(Resource):
@console_ns.response(
200,
"Hit histories retrieved successfully",
console_ns.model(
"AnnotationHitHistoryList",
{
"data": fields.List(
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
)
},
),
console_ns.models[AnnotationHitHistoryList.__name__],
)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@@ -436,11 +441,14 @@ class AnnotationHitHistoryListApi(Resource):
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
app_id, annotation_id, page, limit
)
response = {
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
"has_more": len(annotation_hit_history_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response
history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python(
annotation_hit_history_list, from_attributes=True
)
response = AnnotationHitHistoryList(
data=history_models,
has_more=len(annotation_hit_history_list) == limit,
limit=limit,
total=total,
page=page,
)
return response.model_dump(mode="json")

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
AppUnavailableError,
@@ -33,7 +34,6 @@ from services.errors.audio import (
)
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TextToSpeechPayload(BaseModel):
@@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel):
language: str = Field(..., description="Language code")
console_ns.schema_model(
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechVoiceQuery.__name__,
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
class AudioTranscriptResponse(BaseModel):
text: str = Field(description="Transcribed text from audio")
register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
@@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource):
@console_ns.response(
200,
"Audio transcription successful",
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
console_ns.models[AudioTranscriptResponse.__name__],
)
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
@console_ns.response(413, "Audio file too large")

View File

@@ -508,16 +508,19 @@ class ChatConversationApi(Resource):
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args.annotation_status == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
match args.annotation_status:
case "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
case "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
case "all":
pass
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
@@ -35,7 +36,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ChatMessagesQuery(BaseModel):
@@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel):
raise ValueError("has_comment must be a boolean value")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class AnnotationCountResponse(BaseModel):
count: int = Field(description="Number of annotations")
reg(ChatMessagesQuery)
reg(MessageFeedbackPayload)
reg(FeedbackExportQuery)
class SuggestedQuestionsResponse(BaseModel):
data: list[str] = Field(description="Suggested question")
register_schema_models(
console_ns,
ChatMessagesQuery,
MessageFeedbackPayload,
FeedbackExportQuery,
AnnotationCountResponse,
SuggestedQuestionsResponse,
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@@ -231,7 +240,7 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = ChatMessagesQuery.model_validate(request.args.to_dict())
conversation = (
db.session.query(Conversation)
@@ -356,7 +365,7 @@ class MessageAnnotationCountApi(Resource):
@console_ns.response(
200,
"Annotation count retrieved successfully",
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
console_ns.models[AnnotationCountResponse.__name__],
)
@get_app_model
@setup_required
@@ -376,9 +385,7 @@ class MessageSuggestedQuestionApi(Resource):
@console_ns.response(
200,
"Suggested questions retrieved successfully",
console_ns.model(
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
),
console_ns.models[SuggestedQuestionsResponse.__name__],
)
@console_ns.response(404, "Message or conversation not found")
@setup_required
@@ -428,7 +435,7 @@ class MessageFeedbackExportApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = FeedbackExportQuery.model_validate(request.args.to_dict())
# Import the service function
from services.feedback_service import FeedbackService

View File

@@ -2,9 +2,11 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field
from configs import dify_config
from controllers.common.schema import register_schema_models
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
@@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required,
logger = logging.getLogger(__name__)
class OAuthDataSourceResponse(BaseModel):
data: str = Field(description="Authorization URL or 'internal' for internal setup")
class OAuthDataSourceBindingResponse(BaseModel):
result: str = Field(description="Operation result")
class OAuthDataSourceSyncResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(
console_ns,
OAuthDataSourceResponse,
OAuthDataSourceBindingResponse,
OAuthDataSourceSyncResponse,
)
def get_oauth_providers():
with current_app.app_context():
notion_oauth = NotionOAuth(
@@ -34,10 +56,7 @@ class OAuthDataSource(Resource):
@console_ns.response(
200,
"Authorization URL or internal setup success",
console_ns.model(
"OAuthDataSourceResponse",
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
),
console_ns.models[OAuthDataSourceResponse.__name__],
)
@console_ns.response(400, "Invalid provider")
@console_ns.response(403, "Admin privileges required")
@@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource):
@console_ns.response(
200,
"Data source binding success",
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
console_ns.models[OAuthDataSourceBindingResponse.__name__],
)
@console_ns.response(400, "Invalid provider or code")
def get(self, provider: str):
@@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource):
@console_ns.response(
200,
"Data source sync success",
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
console_ns.models[OAuthDataSourceSyncResponse.__name__],
)
@console_ns.response(400, "Invalid provider or sync failed")
@setup_required

View File

@@ -2,10 +2,11 @@ import base64
import secrets
from flask import request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailCodeError,
@@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel):
return valid_password(value)
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class ForgotPasswordEmailResponse(BaseModel):
result: str = Field(description="Operation result")
data: str | None = Field(default=None, description="Reset token")
code: str | None = Field(default=None, description="Error code if account not found")
class ForgotPasswordCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether code is valid")
email: EmailStr = Field(description="Email address")
token: str = Field(description="New reset token")
class ForgotPasswordResetResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(
console_ns,
ForgotPasswordSendPayload,
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordEmailResponse,
ForgotPasswordCheckResponse,
ForgotPasswordResetResponse,
)
@console_ns.route("/forgot-password")
@@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
@console_ns.response(
200,
"Email sent successfully",
console_ns.model(
"ForgotPasswordEmailResponse",
{
"result": fields.String(description="Operation result"),
"data": fields.String(description="Reset token"),
"code": fields.String(description="Error code if account not found"),
},
),
console_ns.models[ForgotPasswordEmailResponse.__name__],
)
@console_ns.response(400, "Invalid email or rate limit exceeded")
@setup_required
@@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource):
@console_ns.response(
200,
"Code verified successfully",
console_ns.model(
"ForgotPasswordCheckResponse",
{
"is_valid": fields.Boolean(description="Whether code is valid"),
"email": fields.String(description="Email address"),
"token": fields.String(description="New reset token"),
},
),
console_ns.models[ForgotPasswordCheckResponse.__name__],
)
@console_ns.response(400, "Invalid code or token")
@setup_required
@@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource):
@console_ns.response(
200,
"Password reset successfully",
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
console_ns.models[ForgotPasswordResetResponse.__name__],
)
@console_ns.response(400, "Invalid token or password mismatch")
@setup_required

View File

@@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource):
grant_type = OAuthGrantType(payload.grant_type)
except ValueError:
raise BadRequest("invalid grant_type")
match grant_type:
case OAuthGrantType.AUTHORIZATION_CODE:
if not payload.code:
raise BadRequest("code is required")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not payload.code:
raise BadRequest("code is required")
if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
case OAuthGrantType.REFRESH_TOKEN:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
@console_ns.route("/oauth/provider/account")

View File

@@ -1,6 +1,6 @@
import json
from collections.abc import Generator
from typing import Any, cast
from typing import Any, Literal, cast
from flask import request
from flask_restx import Resource, fields, marshal_with
@@ -157,9 +157,8 @@ class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, binding_id, action):
def patch(self, binding_id, action: Literal["enable", "disable"]):
binding_id = str(binding_id)
action = str(action)
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
@@ -167,23 +166,24 @@ class DataSourceApi(Resource):
if data_source_binding is None:
raise NotFound("Data source binding not found.")
# enable binding
if action == "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is not disabled.")
# disable binding
if action == "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is disabled.")
match action:
case "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is not disabled.")
# disable binding
case "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is disabled.")
return {"result": "success"}, 200

View File

@@ -576,63 +576,62 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
match document.data_source_type:
case "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
if document.data_source_type == "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
if file_detail is None:
raise NotFound("File not found.")
if file_detail is None:
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
case "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
case "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
else:
raise ValueError("Data source type not support")
case _:
raise ValueError("Data source type not support")
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
@@ -954,23 +953,24 @@ class DocumentProcessingApi(DocumentResource):
if not current_user.is_dataset_editor:
raise Forbidden()
if action == "pause":
if document.indexing_status != "indexing":
raise InvalidActionError("Document not in indexing state.")
match action:
case "pause":
if document.indexing_status != "indexing":
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
document.is_paused = True
db.session.commit()
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
document.is_paused = True
db.session.commit()
elif action == "resume":
if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.")
case "resume":
if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None
document.paused_at = None
document.is_paused = False
db.session.commit()
document.paused_by = None
document.paused_at = None
document.is_paused = False
db.session.commit()
return {"result": "success"}, 200
@@ -1339,6 +1339,18 @@ class DocumentGenerateSummaryApi(Resource):
missing_ids = set(document_list) - found_ids
raise NotFound(f"Some documents not found: {list(missing_ids)}")
# Update need_summary to True for documents that don't have it set
# This handles the case where documents were created when summary_index_setting was disabled
documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"]
if documents_to_update:
document_ids_to_update = [str(doc.id) for doc in documents_to_update]
DocumentService.update_documents_need_summary(
dataset_id=dataset_id,
document_ids=document_ids_to_update,
need_summary=True,
)
# Dispatch async tasks for each document
for document in documents:
# Skip qa_model documents as they don't generate summaries

View File

@@ -126,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
if action == "enable":
MetadataService.enable_built_in_field(dataset)
elif action == "disable":
MetadataService.disable_built_in_field(dataset)
match action:
case "enable":
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200

View File

@@ -1,8 +1,9 @@
import logging
from typing import Any, cast
from typing import Any, Literal, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
@@ -51,7 +52,7 @@ from fields.app_fields import (
tag_fields,
)
from fields.dataset_fields import dataset_fields
from fields.member_fields import build_simple_account_model
from fields.member_fields import simple_account_fields
from fields.workflow_fields import (
conversation_variable_fields,
pipeline_variable_fields,
@@ -103,7 +104,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
simple_account_model = build_simple_account_model(console_ns)
simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields)
conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
@@ -117,7 +118,56 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel
workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
# Pydantic models for request validation
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowRunRequest(BaseModel):
inputs: dict
files: list | None = None
class ChatRequest(BaseModel):
inputs: dict
query: str
files: list | None = None
conversation_id: str | None = None
parent_message_id: str | None = None
retriever_from: str = "explore_app"
class TextToSpeechRequest(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
class CompletionRequest(BaseModel):
inputs: dict
query: str = ""
files: list | None = None
response_mode: Literal["blocking", "streaming"] | None = None
retriever_from: str = "explore_app"
# Register schemas for Swagger documentation
console_ns.schema_model(
WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class TrialAppWorkflowRunApi(TrialAppResource):
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
def post(self, trial_app):
"""
Run workflow
@@ -129,10 +179,8 @@ class TrialAppWorkflowRunApi(TrialAppResource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
request_data = WorkflowRunRequest.model_validate(console_ns.payload)
args = request_data.model_dump()
assert current_user is not None
try:
app_id = app_model.id
@@ -183,6 +231,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
class TrialChatApi(TrialAppResource):
@console_ns.expect(console_ns.models[ChatRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
@@ -190,14 +239,14 @@ class TrialChatApi(TrialAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
request_data = ChatRequest.model_validate(console_ns.payload)
args = request_data.model_dump()
# Validate UUID values if provided
if args.get("conversation_id"):
args["conversation_id"] = uuid_value(args["conversation_id"])
if args.get("parent_message_id"):
args["parent_message_id"] = uuid_value(args["parent_message_id"])
args["auto_generate_name"] = False
@@ -320,20 +369,16 @@ class TrialChatAudioApi(TrialAppResource):
class TrialChatTextApi(TrialAppResource):
@console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
request_data = TextToSpeechRequest.model_validate(console_ns.payload)
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
message_id = request_data.message_id
text = request_data.text
voice = request_data.voice
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
@@ -371,19 +416,15 @@ class TrialChatTextApi(TrialAppResource):
class TrialCompletionApi(TrialAppResource):
@console_ns.expect(console_ns.models[CompletionRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
request_data = CompletionRequest.model_validate(console_ns.payload)
args = request_data.model_dump()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False

View File

@@ -1,58 +1,60 @@
from pydantic import BaseModel, Field
from flask_restx import Resource, fields
from werkzeug.exceptions import Unauthorized
from controllers.fastopenapi import console_router
from libs.login import current_account_with_tenant, current_user, login_required
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
from services.feature_service import FeatureService
from . import console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required
class FeatureResponse(BaseModel):
features: FeatureModel = Field(description="Feature configuration object")
@console_ns.route("/features")
class FeatureApi(Resource):
@console_ns.doc("get_tenant_features")
@console_ns.doc(description="Get feature configuration for current tenant")
@console_ns.response(
200,
"Success",
console_ns.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
)
@setup_required
@login_required
@account_initialization_required
@cloud_utm_record
def get(self):
"""Get feature configuration for current tenant"""
_, current_tenant_id = current_account_with_tenant()
return FeatureService.get_features(current_tenant_id).model_dump()
class SystemFeatureResponse(BaseModel):
features: SystemFeatureModel = Field(description="System feature configuration object")
@console_ns.route("/system-features")
class SystemFeatureApi(Resource):
@console_ns.doc("get_system_features")
@console_ns.doc(description="Get system-wide feature configuration")
@console_ns.response(
200,
"Success",
console_ns.model(
"SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}
),
)
def get(self):
"""Get system-wide feature configuration
NOTE: This endpoint is unauthenticated by design, as it provides system features
data required for dashboard initialization.
@console_router.get(
"/features",
response_model=FeatureResponse,
tags=["console"],
)
@setup_required
@login_required
@account_initialization_required
@cloud_utm_record
def get_tenant_features() -> FeatureResponse:
"""Get feature configuration for current tenant."""
_, current_tenant_id = current_account_with_tenant()
Authentication would create circular dependency (can't login without dashboard loading).
return FeatureResponse(features=FeatureService.get_features(current_tenant_id))
@console_router.get(
"/system-features",
response_model=SystemFeatureResponse,
tags=["console"],
)
def get_system_features() -> SystemFeatureResponse:
"""Get system-wide feature configuration
NOTE: This endpoint is unauthenticated by design, as it provides system features
data required for dashboard initialization.
Authentication would create circular dependency (can't login without dashboard loading).
Only non-sensitive configuration data should be returned by this endpoint.
"""
# NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
# without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
# in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
# raise `Unauthorized` exception if authentication token is not provided.
try:
is_authenticated = current_user.is_authenticated
except Unauthorized:
is_authenticated = False
return SystemFeatureResponse(features=FeatureService.get_system_features(is_authenticated=is_authenticated))
Only non-sensitive configuration data should be returned by this endpoint.
"""
# NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
# without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
# in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
# raise `Unauthorized` exception if authentication token is not provided.
try:
is_authenticated = current_user.is_authenticated
except Unauthorized:
is_authenticated = False
return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump()

View File

@@ -1,14 +1,27 @@
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Namespace, Resource, fields, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.fastopenapi import console_router
from libs.login import current_account_with_tenant, login_required
from services.tag_service import TagService
dataset_tag_fields = {
"id": fields.String,
"name": fields.String,
"type": fields.String,
"binding_count": fields.String,
}
def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)
class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)
@@ -32,129 +45,115 @@ class TagListQueryParam(BaseModel):
keyword: str | None = Field(None, description="Search keyword")
class TagResponse(BaseModel):
id: str = Field(description="Tag ID")
name: str = Field(description="Tag name")
type: str = Field(description="Tag type")
binding_count: int = Field(description="Number of bindings")
class TagBindingResult(BaseModel):
result: Literal["success"] = Field(description="Operation result", examples=["success"])
@console_router.get(
"/tags",
response_model=list[TagResponse],
tags=["console"],
register_schema_models(
console_ns,
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagListQueryParam,
)
@setup_required
@login_required
@account_initialization_required
def list_tags(query: TagListQueryParam) -> list[TagResponse]:
_, current_tenant_id = current_account_with_tenant()
tags = TagService.get_tags(query.type, current_tenant_id, query.keyword)
return [
TagResponse(
id=tag.id,
name=tag.name,
type=tag.type,
binding_count=int(tag.binding_count),
)
for tag in tags
]
@console_router.post(
"/tags",
response_model=TagResponse,
tags=["console"],
)
@setup_required
@login_required
@account_initialization_required
def create_tag(payload: TagBasePayload) -> TagResponse:
current_user, _ = current_account_with_tenant()
# The role of the current user in the tag table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@console_ns.route("/tags")
class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.doc(
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
)
@marshal_with(dataset_tag_fields)
def get(self):
_, current_tenant_id = current_account_with_tenant()
raw_args = request.args.to_dict()
param = TagListQueryParam.model_validate(raw_args)
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
tag = TagService.save_tags(payload.model_dump())
return tags, 200
return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=0)
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(payload.model_dump())
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200
@console_router.patch(
"/tags/<uuid:tag_id>",
response_model=TagResponse,
tags=["console"],
)
@setup_required
@login_required
@account_initialization_required
def update_tag(tag_id: UUID, payload: TagBasePayload) -> TagResponse:
current_user, _ = current_account_with_tenant()
tag_id_str = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource):
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def patch(self, tag_id):
current_user, _ = current_account_with_tenant()
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
tag = TagService.update_tags(payload.model_dump(), tag_id_str)
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(payload.model_dump(), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id_str)
binding_count = TagService.get_tag_binding_count(tag_id)
return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=binding_count)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
return response, 200
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, tag_id):
tag_id = str(tag_id)
TagService.delete_tag(tag_id)
return 204
@console_router.delete(
"/tags/<uuid:tag_id>",
tags=["console"],
status_code=204,
)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete_tag(tag_id: UUID) -> None:
tag_id_str = str(tag_id)
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
TagService.delete_tag(tag_id_str)
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(payload.model_dump())
return {"result": "success"}, 200
@console_router.post(
"/tag-bindings/create",
response_model=TagBindingResult,
tags=["console"],
)
@setup_required
@login_required
@account_initialization_required
def create_tag_binding(payload: TagBindingPayload) -> TagBindingResult:
current_user, _ = current_account_with_tenant()
# The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
TagService.save_tag_binding(payload.model_dump())
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(payload.model_dump())
return TagBindingResult(result="success")
@console_router.post(
"/tag-bindings/remove",
response_model=TagBindingResult,
tags=["console"],
)
@setup_required
@login_required
@account_initialization_required
def delete_tag_binding(payload: TagBindingRemovePayload) -> TagBindingResult:
current_user, _ = current_account_with_tenant()
# The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
TagService.delete_tag_binding(payload.model_dump())
return TagBindingResult(result="success")
return {"result": "success"}, 200

View File

@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
@@ -37,7 +38,7 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_fields
from fields.member_fields import Account as AccountResponse
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
@@ -170,6 +171,12 @@ reg(ChangeEmailSendPayload)
reg(ChangeEmailValidityPayload)
reg(ChangeEmailResetPayload)
reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)
def _serialize_account(account) -> dict:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
integrate_fields = {
"provider": fields.String,
@@ -236,11 +243,11 @@ class AccountProfileApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@enterprise_license_required
def get(self):
current_user, _ = current_account_with_tenant()
return current_user
return _serialize_account(current_user)
@console_ns.route("/account/name")
@@ -249,14 +256,14 @@ class AccountNameApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/avatar")
@@ -265,7 +272,7 @@ class AccountAvatarApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@@ -273,7 +280,7 @@ class AccountAvatarApi(Resource):
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/interface-language")
@@ -282,7 +289,7 @@ class AccountInterfaceLanguageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@@ -290,7 +297,7 @@ class AccountInterfaceLanguageApi(Resource):
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/interface-theme")
@@ -299,7 +306,7 @@ class AccountInterfaceThemeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@@ -307,7 +314,7 @@ class AccountInterfaceThemeApi(Resource):
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/timezone")
@@ -316,7 +323,7 @@ class AccountTimezoneApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@@ -324,7 +331,7 @@ class AccountTimezoneApi(Resource):
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/password")
@@ -333,7 +340,7 @@ class AccountPasswordApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@@ -344,7 +351,7 @@ class AccountPasswordApi(Resource):
except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError()
return {"result": "success"}
return _serialize_account(current_user)
@console_ns.route("/account/integrates")
@@ -620,7 +627,7 @@ class ChangeEmailResetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
@@ -649,7 +656,7 @@ class ChangeEmailResetApi(Resource):
email=normalized_new_email,
)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/change-email/check-email-unique")

View File

@@ -1,9 +1,10 @@
from typing import Any
from flask import request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery):
plugin_id: str
class EndpointCreateResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointListResponse(BaseModel):
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
class PluginEndpointListResponse(BaseModel):
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
class EndpointDeleteResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointUpdateResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointEnableResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointDisableResponse(BaseModel):
success: bool = Field(description="Operation success")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(EndpointCreatePayload)
reg(EndpointIdPayload)
reg(EndpointUpdatePayload)
reg(EndpointListQuery)
reg(EndpointListForPluginQuery)
register_schema_models(
console_ns,
EndpointCreatePayload,
EndpointIdPayload,
EndpointUpdatePayload,
EndpointListQuery,
EndpointListForPluginQuery,
EndpointCreateResponse,
EndpointListResponse,
PluginEndpointListResponse,
EndpointDeleteResponse,
EndpointUpdateResponse,
EndpointEnableResponse,
EndpointDisableResponse,
)
@console_ns.route("/workspaces/current/endpoints/create")
@@ -57,7 +96,7 @@ class EndpointCreateApi(Resource):
@console_ns.response(
200,
"Endpoint created successfully",
console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointCreateResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@@ -91,9 +130,7 @@ class EndpointListApi(Resource):
@console_ns.response(
200,
"Success",
console_ns.model(
"EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
console_ns.models[EndpointListResponse.__name__],
)
@setup_required
@login_required
@@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource):
@console_ns.response(
200,
"Success",
console_ns.model(
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
console_ns.models[PluginEndpointListResponse.__name__],
)
@setup_required
@login_required
@@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource):
@console_ns.response(
200,
"Endpoint deleted successfully",
console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointDeleteResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource):
@console_ns.response(
200,
"Endpoint updated successfully",
console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointUpdateResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@@ -221,7 +256,7 @@ class EndpointEnableApi(Resource):
@console_ns.response(
200,
"Endpoint enabled successfully",
console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointEnableResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@@ -248,7 +283,7 @@ class EndpointDisableApi(Resource):
@console_ns.response(
200,
"Endpoint disabled successfully",
console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointDisableResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required

View File

@@ -1,12 +1,12 @@
from urllib import parse
from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter
import services
from configs import dify_config
from controllers.common.schema import get_or_create_model, register_enum_models
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
@@ -25,7 +25,7 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_with_role_fields, account_with_role_list_fields
from fields.member_fields import AccountWithRole, AccountWithRoleList
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
@@ -69,12 +69,7 @@ reg(OwnerTransferEmailPayload)
reg(OwnerTransferCheckPayload)
reg(OwnerTransferPayload)
register_enum_models(console_ns, TenantAccountRole)
account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields)
account_with_role_list_fields_copy = account_with_role_list_fields.copy()
account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model))
account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy)
register_schema_models(console_ns, AccountWithRole, AccountWithRoleList)
@console_ns.route("/workspaces/current/members")
@@ -84,13 +79,15 @@ class MemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_with_role_list_model)
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = AccountWithRoleList(accounts=member_models)
return response.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/members/invite-email")
@@ -235,13 +232,15 @@ class DatasetOperatorMemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_with_role_list_model)
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = AccountWithRoleList(accounts=member_models)
return response.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")

View File

@@ -1,16 +1,16 @@
from typing import Literal
from flask import request
from flask_restx import Namespace, Resource, fields
from flask_restx import Resource
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, TypeAdapter
from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model
from fields.annotation_fields import Annotation, AnnotationList
from models.model import App
from services.annotation_service import AppAnnotationService
@@ -26,7 +26,9 @@ class AnnotationReplyActionPayload(BaseModel):
embedding_model_name: str = Field(description="Embedding model name")
register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload)
register_schema_models(
service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList
)
@service_api_ns.route("/apps/annotation-reply/<string:action>")
@@ -45,10 +47,11 @@ class AnnotationReplyActionApi(Resource):
def post(self, app_model: App, action: Literal["enable", "disable"]):
"""Enable or disable annotation reply feature."""
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
return result, 200
@@ -82,23 +85,6 @@ class AnnotationReplyActionStatusApi(Resource):
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
# Define annotation list response model
annotation_list_fields = {
"data": fields.List(fields.Nested(annotation_fields)),
"has_more": fields.Boolean,
"limit": fields.Integer,
"total": fields.Integer,
"page": fields.Integer,
}
def build_annotation_list_model(api_or_ns: Namespace):
"""Build the annotation list model for the API or Namespace."""
copied_annotation_list_fields = annotation_list_fields.copy()
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
return api_or_ns.model("AnnotationList", copied_annotation_list_fields)
@service_api_ns.route("/apps/annotations")
class AnnotationListApi(Resource):
@service_api_ns.doc("list_annotations")
@@ -109,8 +95,12 @@ class AnnotationListApi(Resource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Annotations retrieved successfully",
service_api_ns.models[AnnotationList.__name__],
)
@validate_app_token
@service_api_ns.marshal_with(build_annotation_list_model(service_api_ns))
def get(self, app_model: App):
"""List annotations for the application."""
page = request.args.get("page", default=1, type=int)
@@ -118,13 +108,15 @@ class AnnotationListApi(Resource):
keyword = request.args.get("keyword", default="", type=str)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
return {
"data": annotation_list,
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response = AnnotationList(
data=annotation_models,
has_more=len(annotation_list) == limit,
limit=limit,
total=total,
page=page,
)
return response.model_dump(mode="json")
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
@service_api_ns.doc("create_annotation")
@@ -135,13 +127,18 @@ class AnnotationListApi(Resource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
HTTPStatus.CREATED,
"Annotation created successfully",
service_api_ns.models[Annotation.__name__],
)
@validate_app_token
@service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
def post(self, app_model: App):
"""Create a new annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
return annotation, 201
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json"), HTTPStatus.CREATED
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
@@ -158,14 +155,19 @@ class AnnotationUpdateDeleteApi(Resource):
404: "Annotation not found",
}
)
@service_api_ns.response(
200,
"Annotation updated successfully",
service_api_ns.models[Annotation.__name__],
)
@validate_app_token
@edit_permission_required
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")
@service_api_ns.doc("delete_annotation")
@service_api_ns.doc(description="Delete an annotation")

View File

@@ -17,7 +17,7 @@ from controllers.service_api.wraps import (
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields
from fields.tag_fields import DataSetTag
from libs.login import current_user
from models.account import Account
from models.dataset import DatasetPermissionEnum
@@ -114,6 +114,7 @@ register_schema_models(
TagBindingPayload,
TagUnbindingPayload,
DatasetListQuery,
DataSetTag,
)
@@ -480,15 +481,14 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def get(self, _):
"""Get all knowledge type tags."""
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
tags = TagService.get_tags("knowledge", cid)
return tags, 200
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
return [tag.model_dump(mode="json") for tag in tag_models], 200
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag")
@@ -500,7 +500,6 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def post(self, _):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
@@ -510,7 +509,9 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
).model_dump(mode="json")
return response, 200
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@@ -523,7 +524,6 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def patch(self, _):
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -536,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource):
binding_count = TagService.get_tag_binding_count(tag_id)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
).model_dump(mode="json")
return response, 200
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])

View File

@@ -168,10 +168,11 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
if action == "enable":
MetadataService.enable_built_in_field(dataset)
elif action == "disable":
MetadataService.disable_built_in_field(dataset)
match action:
case "enable":
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200

View File

@@ -73,14 +73,14 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
# If caller needs end-user context, attach EndUser to current_user
if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get("user")
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
user_id = request.get_json().get("user")
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get("user")
else:
user_id = None
user_id = None
match fetch_user_arg.fetch_from:
case WhereisUserArg.QUERY:
user_id = request.args.get("user")
case WhereisUserArg.JSON:
user_id = request.get_json().get("user")
case WhereisUserArg.FORM:
user_id = request.form.get("user")
if not user_id and fetch_user_arg.required:
raise ValueError("Arg user must be provided.")

View File

@@ -14,16 +14,17 @@ class AgentConfigManager:
agent_dict = config.get("agent_mode", {})
agent_strategy = agent_dict.get("strategy", "cot")
if agent_strategy == "function_call":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy in {"cot", "react"}:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
# old configs, try to detect default strategy
if config["model"]["provider"] == "openai":
match agent_strategy:
case "function_call":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
else:
case "cot" | "react":
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
case _:
# old configs, try to detect default strategy
if config["model"]["provider"] == "openai":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
else:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
agent_tools = []
for tool in agent_dict.get("tools", []):

View File

@@ -250,7 +250,7 @@ class WorkflowResponseConverter:
data=WorkflowFinishStreamResponse.Data(
id=run_id,
workflow_id=workflow_id,
status=status.value,
status=status,
outputs=encoded_outputs,
error=error,
elapsed_time=elapsed_time,
@@ -340,13 +340,13 @@ class WorkflowResponseConverter:
metadata = self._merge_metadata(event.execution_metadata, snapshot)
if isinstance(event, QueueNodeSucceededEvent):
status = WorkflowNodeExecutionStatus.SUCCEEDED.value
status = WorkflowNodeExecutionStatus.SUCCEEDED
error_message = event.error
elif isinstance(event, QueueNodeFailedEvent):
status = WorkflowNodeExecutionStatus.FAILED.value
status = WorkflowNodeExecutionStatus.FAILED
error_message = event.error
else:
status = WorkflowNodeExecutionStatus.EXCEPTION.value
status = WorkflowNodeExecutionStatus.EXCEPTION
error_message = event.error
return NodeFinishStreamResponse(
@@ -413,7 +413,7 @@ class WorkflowResponseConverter:
process_data_truncated=process_data_truncated,
outputs=outputs,
outputs_truncated=outputs_truncated,
status=WorkflowNodeExecutionStatus.RETRY.value,
status=WorkflowNodeExecutionStatus.RETRY,
error=event.error,
elapsed_time=elapsed_time,
execution_metadata=metadata,

View File

@@ -120,7 +120,7 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_type = DatasourceProviderType(args["datasource_type"])
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
)
@@ -660,7 +660,7 @@ class PipelineGenerator(BaseAppGenerator):
tenant_id: str,
dataset_id: str,
built_in_field_enabled: bool,
datasource_type: str,
datasource_type: DatasourceProviderType,
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
@@ -668,17 +668,17 @@ class PipelineGenerator(BaseAppGenerator):
batch: str,
document_form: str,
):
if datasource_type == "local_file":
name = datasource_info.get("name", "untitled")
elif datasource_type == "online_document":
name = datasource_info.get("page", {}).get("page_name", "untitled")
elif datasource_type == "website_crawl":
name = datasource_info.get("title", "untitled")
elif datasource_type == "online_drive":
name = datasource_info.get("name", "untitled")
else:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
match datasource_type:
case DatasourceProviderType.LOCAL_FILE:
name = datasource_info.get("name", "untitled")
case DatasourceProviderType.ONLINE_DOCUMENT:
name = datasource_info.get("page", {}).get("page_name", "untitled")
case DatasourceProviderType.WEBSITE_CRAWL:
name = datasource_info.get("title", "untitled")
case DatasourceProviderType.ONLINE_DRIVE:
name = datasource_info.get("name", "untitled")
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
@@ -706,7 +706,7 @@ class PipelineGenerator(BaseAppGenerator):
def _format_datasource_info_list(
self,
datasource_type: str,
datasource_type: DatasourceProviderType,
datasource_info_list: list[Mapping[str, Any]],
pipeline: Pipeline,
workflow: Workflow,
@@ -716,7 +716,7 @@ class PipelineGenerator(BaseAppGenerator):
"""
Format datasource info list.
"""
if datasource_type == "online_drive":
if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
all_files: list[Mapping[str, Any]] = []
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class AnnotationReplyAccount(BaseModel):
@@ -223,7 +223,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
id: str
workflow_id: str
status: str
status: WorkflowExecutionStatus
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
@@ -311,7 +311,7 @@ class NodeFinishStreamResponse(StreamResponse):
process_data_truncated: bool = False
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = True
status: str
status: WorkflowNodeExecutionStatus
error: str | None = None
elapsed_time: float
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
@@ -375,7 +375,7 @@ class NodeRetryStreamResponse(StreamResponse):
process_data_truncated: bool = False
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = False
status: str
status: WorkflowNodeExecutionStatus
error: str | None = None
elapsed_time: float
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
@@ -719,7 +719,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
id: str
workflow_id: str
status: str
status: WorkflowExecutionStatus
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float

View File

@@ -369,77 +369,78 @@ class IndexingRunner:
# Generate summary preview
summary_index_setting = tmp_processing_rule.get("summary_index_setting")
if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
preview_texts = index_processor.generate_summary_preview(
tenant_id, preview_texts, summary_index_setting, doc_language
)
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
) -> list[Document]:
# load file
if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
return []
data_source_info = dataset_document.data_source_info_dict
text_docs = []
if dataset_document.data_source_type == "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()
match dataset_document.data_source_type:
case "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()
if file_detail:
if file_detail:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "notion_import":
if (
not data_source_info
or "notion_workspace_id" not in data_source_info
or "notion_page_id" not in data_source_info
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
elif dataset_document.data_source_type == "notion_import":
if (
not data_source_info
or "notion_workspace_id" not in data_source_info
or "notion_page_id" not in data_source_info
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
elif dataset_document.data_source_type == "website_crawl":
if (
not data_source_info
or "provider" not in data_source_info
or "url" not in data_source_info
or "job_id" not in data_source_info
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "website_crawl":
if (
not data_source_info
or "provider" not in data_source_info
or "url" not in data_source_info
or "job_id" not in data_source_info
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case _:
return []
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,

View File

@@ -441,11 +441,13 @@ DEFAULT_GENERATOR_SUMMARY_PROMPT = (
Requirements:
1. Write a concise summary in plain text
2. Use the same language as the input content
2. You must write in {language}. No language other than {language} should be used.
3. Focus on important facts, concepts, and details
4. If images are included, describe their key information
5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
6. Write directly without extra words
7. If there is not enough content to generate a meaningful summary,
return an empty string without any explanation or prompt
Output only the summary text. Start summarizing now:

View File

@@ -48,12 +48,22 @@ class BaseIndexProcessor(ABC):
@abstractmethod
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
The summary can be stored in a new attribute, e.g., summary.
This method should be implemented by subclasses.
Args:
tenant_id: Tenant ID
preview_texts: List of preview details to generate summaries for
summary_index_setting: Summary index configuration
doc_language: Optional document language to ensure summary is generated in the correct language
"""
raise NotImplementedError

View File

@@ -275,7 +275,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
raise ValueError("Chunks is not a list")
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
For each segment, concurrently call generate_summary to generate a summary
@@ -298,11 +302,15 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
summary, _ = self.generate_summary(
tenant_id, preview.content, summary_index_setting, document_language=doc_language
)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
summary, _ = self.generate_summary(
tenant_id, preview.content, summary_index_setting, document_language=doc_language
)
preview.summary = summary
# Generate summaries concurrently using ThreadPoolExecutor
@@ -356,6 +364,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
text: str,
summary_index_setting: dict | None = None,
segment_id: str | None = None,
document_language: str | None = None,
) -> tuple[str, LLMUsage]:
"""
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt,
@@ -366,6 +375,8 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
text: Text content to summarize
summary_index_setting: Summary index configuration
segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table
document_language: Optional document language (e.g., "Chinese", "English")
to ensure summary is generated in the correct language
Returns:
Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
@@ -381,8 +392,22 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
raise ValueError("model_name and model_provider_name are required in summary_index_setting")
# Import default summary prompt
is_default_prompt = False
if not summary_prompt:
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
is_default_prompt = True
# Format prompt with document language only for default prompt
# Custom prompts are used as-is to avoid interfering with user-defined templates
# If document_language is provided, use it; otherwise, use "the same language as the input content"
# This is especially important for image-only chunks where text is empty or minimal
if is_default_prompt:
language_for_prompt = document_language or "the same language as the input content"
try:
summary_prompt = summary_prompt.format(language=language_for_prompt)
except KeyError:
# If default prompt doesn't have {language} placeholder, use it as-is
pass
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(

View File

@@ -358,7 +358,11 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
}
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
@@ -389,6 +393,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
preview.summary = summary
else:
@@ -397,6 +402,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
preview.summary = summary

View File

@@ -241,7 +241,11 @@ class QAIndexProcessor(BaseIndexProcessor):
}
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
QA model doesn't generate summaries, so this method returns preview_texts unchanged.

View File

@@ -192,32 +192,33 @@ class AgentNode(Node[AgentNodeData]):
result[parameter_name] = None
continue
agent_input = node_data.agent_parameters[parameter_name]
if agent_input.type == "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
elif agent_input.type in {"mixed", "constant"}:
# variable_pool.convert_template expects a string template,
# but if passing a dict, convert to JSON string first before rendering
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
case "mixed" | "constant":
# variable_pool.convert_template expects a string template,
# but if passing a dict, convert to JSON string first before rendering
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
# variable_pool.convert_template returns a string,
# so we need to convert it back to a dictionary
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
else:
raise AgentInputTypeError(agent_input.type)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
# variable_pool.convert_template returns a string,
# so we need to convert it back to a dictionary
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
case _:
raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
@@ -374,12 +375,13 @@ class AgentNode(Node[AgentNodeData]):
result: dict[str, Any] = {}
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
if input.type in ["mixed", "constant"]:
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
match input.type:
case "mixed" | "constant":
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()}

View File

@@ -270,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]):
if typed_node_data.datasource_parameters:
for parameter_name in typed_node_data.datasource_parameters:
input = typed_node_data.datasource_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
elif input.type == "constant":
pass
match input.type:
case "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
case "constant":
pass
case None:
pass
result = {node_id + "." + key: value for key, value in result.items()}
@@ -308,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]):
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
DatasourceMessage.MessageType.IMAGE_LINK,
DatasourceMessage.MessageType.BINARY_LINK,
DatasourceMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, DatasourceMessage.TextMessage)
match message.type:
case (
DatasourceMessage.MessageType.IMAGE_LINK
| DatasourceMessage.MessageType.BINARY_LINK
| DatasourceMessage.MessageType.IMAGE
):
assert isinstance(message.message, DatasourceMessage.TextMessage)
url = message.message.text
transfer_method = FileTransferMethod.TOOL_FILE
url = message.message.text
transfer_method = FileTransferMethod.TOOL_FILE
datasource_file_id = str(url).split("/")[-1].split(".")[0]
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
elif message.type == DatasourceMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceMessage.TextMessage)
assert message.meta
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"tool_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
elif message.type == DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
json.append(message.message.json_object)
elif message.type == DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
files.append(file)
case DatasourceMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceMessage.TextMessage)
assert message.meta
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"tool_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
case DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[self._node_id, variable_name],
chunk=variable_value,
selector=[self._node_id, "text"],
chunk=message.message.text,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
case DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
json.append(message.message.json_object)
case DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
case DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[self._node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
case DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
case (
DatasourceMessage.MessageType.BLOB_CHUNK
| DatasourceMessage.MessageType.LOG
| DatasourceMessage.MessageType.RETRIEVER_RESOURCES
):
pass
# mark the end of the stream
yield StreamChunkEvent(
selector=[self._node_id, "text"],

View File

@@ -78,12 +78,21 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
# Try to get document language if document_id is available
doc_language = None
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter_by(id=document_id.value).first()
if document and document.doc_language:
doc_language = document.doc_language
outputs = self._get_preview_output_with_summaries(
node_data.chunk_structure,
chunks,
dataset=dataset,
indexing_technique=indexing_technique,
summary_index_setting=summary_index_setting,
doc_language=doc_language,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -315,6 +324,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
dataset: Dataset,
indexing_technique: str | None = None,
summary_index_setting: dict | None = None,
doc_language: str | None = None,
) -> Mapping[str, Any]:
"""
Generate preview output with summaries for chunks in preview mode.
@@ -326,6 +336,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
dataset: Dataset object (for tenant_id)
indexing_technique: Indexing technique from node config or dataset
summary_index_setting: Summary index setting from node config or dataset
doc_language: Optional document language to ensure summary is generated in the correct language
"""
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
preview_output = index_processor.format_preview(chunks)
@@ -365,6 +376,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item["summary"] = summary
@@ -374,6 +386,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item["summary"] = summary

View File

@@ -303,33 +303,34 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
else:
match node_data.multiple_retrieval_config.reranking_mode:
case "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
else:
reranking_model = None
weights = None
case "weighted_score":
if node_data.multiple_retrieval_config.weights is None:
raise ValueError("weights is required")
reranking_model = None
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
if node_data.multiple_retrieval_config.weights is None:
raise ValueError("weights is required")
reranking_model = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
"vector_setting": {
"vector_weight": vector_setting.vector_weight,
"embedding_provider_name": vector_setting.embedding_provider_name,
"embedding_model_name": vector_setting.embedding_model_name,
},
"keyword_setting": {
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
},
}
else:
reranking_model = None
weights = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
"vector_setting": {
"vector_weight": vector_setting.vector_weight,
"embedding_provider_name": vector_setting.embedding_provider_name,
"embedding_model_name": vector_setting.embedding_model_name,
},
"keyword_setting": {
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
},
}
case _:
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve(
app_id=self.app_id,
tenant_id=self.tenant_id,
@@ -453,73 +454,74 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
)
filters: list[Any] = []
metadata_condition = None
if node_data.metadata_filtering_mode == "disabled":
return None, None, usage
elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters,
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
match node_data.metadata_filtering_mode:
case "disabled":
return None, None, usage
case "automatic":
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
elif node_data.metadata_filtering_mode == "manual":
if node_data.metadata_filtering_conditions:
conditions = []
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters,
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
)
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
case "manual":
if node_data.metadata_filtering_conditions:
conditions = []
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
else:
raise ValueError("Invalid metadata filtering mode")
case _:
raise ValueError("Invalid metadata filtering mode")
if filters:
if (
node_data.metadata_filtering_conditions

View File

@@ -482,16 +482,17 @@ class ToolNode(Node[ToolNodeData]):
result = {}
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
elif input.type == "constant":
pass
match input.type:
case "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
case "constant":
pass
result = {node_id + "." + key: value for key, value in result.items()}

View File

@@ -390,8 +390,7 @@ class ClickZettaVolumeStorage(BaseStorage):
"""
content = self.load_once(filename)
with Path(target_filepath).open("wb") as f:
f.write(content)
Path(target_filepath).write_bytes(content)
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)

View File

@@ -1,36 +1,69 @@
from flask_restx import Namespace, fields
from __future__ import annotations
from libs.helper import TimestampField
from datetime import datetime
annotation_fields = {
"id": fields.String,
"question": fields.String,
"answer": fields.Raw(attribute="content"),
"hit_count": fields.Integer,
"created_at": TimestampField,
# 'account': fields.Nested(simple_account_fields, allow_null=True)
}
from pydantic import BaseModel, ConfigDict, Field, field_validator
def build_annotation_model(api_or_ns: Namespace):
"""Build the annotation model for the API or Namespace."""
return api_or_ns.model("Annotation", annotation_fields)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
annotation_list_fields = {
"data": fields.List(fields.Nested(annotation_fields)),
}
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
annotation_hit_history_fields = {
"id": fields.String,
"source": fields.String,
"score": fields.Float,
"question": fields.String,
"created_at": TimestampField,
"match": fields.String(attribute="annotation_question"),
"response": fields.String(attribute="annotation_content"),
}
annotation_hit_history_list_fields = {
"data": fields.List(fields.Nested(annotation_hit_history_fields)),
}
class Annotation(ResponseModel):
id: str
question: str | None = None
answer: str | None = Field(default=None, validation_alias="content")
hit_count: int | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AnnotationList(ResponseModel):
data: list[Annotation]
has_more: bool
limit: int
total: int
page: int
class AnnotationExportList(ResponseModel):
data: list[Annotation]
class AnnotationHitHistory(ResponseModel):
id: str
source: str | None = None
score: float | None = None
question: str | None = None
created_at: int | None = None
match: str | None = Field(default=None, validation_alias="annotation_question")
response: str | None = Field(default=None, validation_alias="annotation_content")
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AnnotationHitHistoryList(ResponseModel):
data: list[AnnotationHitHistory]
has_more: bool
limit: int
total: int
page: int

View File

@@ -1,4 +1,7 @@
from flask_restx import Namespace, fields
from __future__ import annotations
from flask_restx import fields
from pydantic import BaseModel, ConfigDict
simple_end_user_fields = {
"id": fields.String,
@@ -8,5 +11,18 @@ simple_end_user_fields = {
}
def build_simple_end_user_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
class SimpleEndUser(ResponseModel):
id: str
type: str
is_anonymous: bool
session_id: str | None = None

View File

@@ -1,6 +1,11 @@
from flask_restx import Namespace, fields
from __future__ import annotations
from libs.helper import AvatarUrlField, TimestampField
from datetime import datetime
from flask_restx import fields
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
from core.file import helpers as file_helpers
simple_account_fields = {
"id": fields.String,
@@ -9,36 +14,78 @@ simple_account_fields = {
}
def build_simple_account_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleAccount", simple_account_fields)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
account_fields = {
"id": fields.String,
"name": fields.String,
"avatar": fields.String,
"avatar_url": AvatarUrlField,
"email": fields.String,
"is_password_set": fields.Boolean,
"interface_language": fields.String,
"interface_theme": fields.String,
"timezone": fields.String,
"last_login_at": TimestampField,
"last_login_ip": fields.String,
"created_at": TimestampField,
}
def _build_avatar_url(avatar: str | None) -> str | None:
if avatar is None:
return None
if avatar.startswith(("http://", "https://")):
return avatar
return file_helpers.get_signed_file_url(avatar)
account_with_role_fields = {
"id": fields.String,
"name": fields.String,
"avatar": fields.String,
"avatar_url": AvatarUrlField,
"email": fields.String,
"last_login_at": TimestampField,
"last_active_at": TimestampField,
"created_at": TimestampField,
"role": fields.String,
"status": fields.String,
}
account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))}
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
class SimpleAccount(ResponseModel):
id: str
name: str
email: str
class _AccountAvatar(ResponseModel):
avatar: str | None = None
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
@property
def avatar_url(self) -> str | None:
return _build_avatar_url(self.avatar)
class Account(_AccountAvatar):
id: str
name: str
email: str
is_password_set: bool
interface_language: str | None = None
interface_theme: str | None = None
timezone: str | None = None
last_login_at: int | None = None
last_login_ip: str | None = None
created_at: int | None = None
@field_validator("last_login_at", "created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AccountWithRole(_AccountAvatar):
id: str
name: str
email: str
last_login_at: int | None = None
last_active_at: int | None = None
created_at: int | None = None
role: str
status: str
@field_validator("last_login_at", "last_active_at", "created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AccountWithRoleList(ResponseModel):
accounts: list[AccountWithRole]

View File

@@ -1,12 +1,20 @@
from flask_restx import Namespace, fields
from __future__ import annotations
dataset_tag_fields = {
"id": fields.String,
"name": fields.String,
"type": fields.String,
"binding_count": fields.String,
}
from pydantic import BaseModel, ConfigDict
def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
class DataSetTag(ResponseModel):
id: str
name: str
type: str
binding_count: str | None = None

View File

@@ -1,7 +1,7 @@
from flask_restx import Namespace, fields
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
from fields.member_fields import build_simple_account_model, simple_account_fields
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
from fields.workflow_run_fields import (
build_workflow_run_for_archived_log_model,
build_workflow_run_for_log_model,
@@ -25,17 +25,9 @@ workflow_app_log_partial_fields = {
def build_workflow_app_log_partial_model(api_or_ns: Namespace):
"""Build the workflow app log partial model for the API or Namespace."""
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
simple_account_model = build_simple_account_model(api_or_ns)
simple_end_user_model = build_simple_end_user_model(api_or_ns)
copied_fields = workflow_app_log_partial_fields.copy()
copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True)
copied_fields["created_by_account"] = fields.Nested(
simple_account_model, attribute="created_by_account", allow_null=True
)
copied_fields["created_by_end_user"] = fields.Nested(
simple_end_user_model, attribute="created_by_end_user", allow_null=True
)
return api_or_ns.model("WorkflowAppLogPartial", copied_fields)
@@ -52,17 +44,9 @@ workflow_archived_log_partial_fields = {
def build_workflow_archived_log_partial_model(api_or_ns: Namespace):
"""Build the workflow archived log partial model for the API or Namespace."""
workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns)
simple_account_model = build_simple_account_model(api_or_ns)
simple_end_user_model = build_simple_end_user_model(api_or_ns)
copied_fields = workflow_archived_log_partial_fields.copy()
copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True)
copied_fields["created_by_account"] = fields.Nested(
simple_account_model, attribute="created_by_account", allow_null=True
)
copied_fields["created_by_end_user"] = fields.Nested(
simple_end_user_model, attribute="created_by_end_user", allow_null=True
)
return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields)

View File

@@ -145,7 +145,7 @@ dev = [
"types-openpyxl~=3.1.5",
"types-pexpect~=4.9.0",
"types-protobuf~=5.29.1",
"types-psutil~=7.0.0",
"types-psutil~=7.2.2",
"types-psycopg2~=2.9.21",
"types-pygments~=2.19.0",
"types-pymysql~=1.1.0",

View File

@@ -158,7 +158,7 @@ class AppAnnotationService:
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
)
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
return annotations.items, annotations.total
return annotations.items, annotations.total or 0
@classmethod
def export_annotation_list_by_app_id(cls, app_id: str):
@@ -524,7 +524,7 @@ class AppAnnotationService:
annotation_hit_histories = db.paginate(
select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False
)
return annotation_hit_histories.items, annotation_hit_histories.total
return annotation_hit_histories.items, annotation_hit_histories.total or 0
@classmethod
def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:

View File

@@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from core.db.session_factory import session_factory
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.file import helpers as file_helpers
from core.helper.name_generator import generate_incremental_name
@@ -1388,6 +1389,46 @@ class DocumentService:
).all()
return documents
@staticmethod
def update_documents_need_summary(dataset_id: str, document_ids: Sequence[str], need_summary: bool = True) -> int:
"""
Update need_summary field for multiple documents.
This method handles the case where documents were created when summary_index_setting was disabled,
and need to be updated when summary_index_setting is later enabled.
Args:
dataset_id: Dataset ID
document_ids: List of document IDs to update
need_summary: Value to set for need_summary field (default: True)
Returns:
Number of documents updated
"""
if not document_ids:
return 0
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
with session_factory.create_session() as session:
updated_count = (
session.query(Document)
.filter(
Document.id.in_(document_id_list),
Document.dataset_id == dataset_id,
Document.doc_form != "qa_model", # Skip qa_model documents
)
.update({Document.need_summary: need_summary}, synchronize_session=False)
)
session.commit()
logger.info(
"Updated need_summary to %s for %d documents in dataset %s",
need_summary,
updated_count,
dataset_id,
)
return updated_count
@staticmethod
def get_document_download_url(document: Document) -> str:
"""
@@ -2937,14 +2978,15 @@ class DocumentService:
"""
now = naive_utc_now()
if action == "enable":
return DocumentService._prepare_enable_update(document, now)
elif action == "disable":
return DocumentService._prepare_disable_update(document, user, now)
elif action == "archive":
return DocumentService._prepare_archive_update(document, user, now)
elif action == "un_archive":
return DocumentService._prepare_unarchive_update(document, now)
match action:
case "enable":
return DocumentService._prepare_enable_update(document, now)
case "disable":
return DocumentService._prepare_disable_update(document, user, now)
case "archive":
return DocumentService._prepare_archive_update(document, user, now)
case "un_archive":
return DocumentService._prepare_unarchive_update(document, now)
return None
@@ -3581,56 +3623,57 @@ class SegmentService:
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
if action == "enable":
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == False,
)
).all()
if not segments:
return
real_deal_segment_ids = []
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
continue
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
db.session.add(segment)
real_deal_segment_ids.append(segment.id)
db.session.commit()
match action:
case "enable":
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == False,
)
).all()
if not segments:
return
real_deal_segment_ids = []
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
continue
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
db.session.add(segment)
real_deal_segment_ids.append(segment.id)
db.session.commit()
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
elif action == "disable":
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == True,
)
).all()
if not segments:
return
real_deal_segment_ids = []
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
continue
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.disabled_by = current_user.id
db.session.add(segment)
real_deal_segment_ids.append(segment.id)
db.session.commit()
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
case "disable":
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == True,
)
).all()
if not segments:
return
real_deal_segment_ids = []
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
continue
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.disabled_by = current_user.id
db.session.add(segment)
real_deal_segment_ids.append(segment.id)
db.session.commit()
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
@classmethod
def create_child_chunk(

View File

@@ -174,6 +174,10 @@ class RagPipelineTransformService:
else:
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
# Copy summary_index_setting from dataset to knowledge_index node configuration
if dataset.summary_index_setting:
knowledge_configuration.summary_index_setting = dataset.summary_index_setting
knowledge_configuration_dict.update(knowledge_configuration.model_dump())
node["data"] = knowledge_configuration_dict
return node

View File

@@ -49,11 +49,18 @@ class SummaryIndexService:
# Use lazy import to avoid circular import
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
# Get document language to ensure summary is generated in the correct language
# This is especially important for image-only chunks where text is empty or minimal
document_language = None
if segment.document and segment.document.doc_language:
document_language = segment.document.doc_language
summary_content, usage = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=segment.content,
summary_index_setting=summary_index_setting,
segment_id=segment.id,
document_language=document_language,
)
if not summary_content:
@@ -558,6 +565,9 @@ class SummaryIndexService:
)
session.add(summary_record)
# Commit the batch created records
session.commit()
@staticmethod
def update_summary_record_error(
segment: DocumentSegment,
@@ -762,7 +772,6 @@ class SummaryIndexService:
dataset=dataset,
status="not_started",
)
session.commit() # Commit initial records
summary_records = []

View File

@@ -24,7 +24,7 @@ class TagService:
escaped_keyword = escape_like_pattern(keyword)
query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results = query.order_by(Tag.created_at.desc()).all()
results: list = query.order_by(Tag.created_at.desc()).all()
return results
@staticmethod

View File

@@ -1,291 +0,0 @@
import builtins
import contextlib
import importlib
import sys
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from flask.views import MethodView
from werkzeug.exceptions import Unauthorized
from extensions import ext_fastopenapi
from extensions.ext_database import db
from services.feature_service import FeatureModel, SystemFeatureModel
@pytest.fixture
def app():
"""
Creates a Flask application instance configured for testing.
"""
app = Flask(__name__)
app.config["TESTING"] = True
app.config["SECRET_KEY"] = "test-secret"
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
# Initialize the database with the app
db.init_app(app)
return app
@pytest.fixture(autouse=True)
def fix_method_view_issue(monkeypatch):
"""
Automatic fixture to patch 'builtins.MethodView'.
Why this is needed:
The official legacy codebase contains a global patch in its initialization logic:
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView
Some dependencies (like ext_fastopenapi or older Flask extensions) might implicitly
rely on 'MethodView' being available in the global builtins namespace.
Refactoring Note:
While patching builtins is generally discouraged due to global side effects,
this fixture reproduces the production environment's state to ensure tests are realistic.
We use 'monkeypatch' to ensure that this change is undone after the test finishes,
keeping other tests isolated.
"""
if not hasattr(builtins, "MethodView"):
# 'raising=False' allows us to set an attribute that doesn't exist yet
monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False)
# ------------------------------------------------------------------------------
# Helper Functions for Fixture Complexity Reduction
# ------------------------------------------------------------------------------
def _create_isolated_router():
"""
Creates a fresh, isolated router instance to prevent route pollution.
"""
import controllers.fastopenapi
# Dynamically get the class type (e.g., FlaskRouter) to avoid hardcoding dependencies
RouterClass = type(controllers.fastopenapi.console_router)
return RouterClass()
@contextlib.contextmanager
def _patch_auth_and_router(temp_router):
"""
Context manager that applies all necessary patches for:
1. The console_router (redirecting to our isolated temp_router)
2. Authentication decorators (disabling them with no-ops)
3. User/Account loaders (mocking authenticated state)
"""
def noop(f):
return f
# We patch the SOURCE of the decorators/functions, not the destination module.
# This ensures that when 'controllers.console.feature' imports them, it gets the mocks.
with (
patch("controllers.fastopenapi.console_router", temp_router),
patch("extensions.ext_fastopenapi.console_router", temp_router),
patch("controllers.console.wraps.setup_required", side_effect=noop),
patch("libs.login.login_required", side_effect=noop),
patch("controllers.console.wraps.account_initialization_required", side_effect=noop),
patch("controllers.console.wraps.cloud_utm_record", side_effect=noop),
patch("libs.login.current_account_with_tenant", return_value=(MagicMock(), "tenant-id")),
patch("libs.login.current_user", MagicMock(is_authenticated=True)),
):
# Explicitly reload ext_fastopenapi to ensure it uses the patched console_router
import extensions.ext_fastopenapi
importlib.reload(extensions.ext_fastopenapi)
yield
def _force_reload_module(target_module: str, alias_module: str):
"""
Forces a reload of the specified module and handles sys.modules aliasing.
Why reload?
Python decorators (like @route, @login_required) run at IMPORT time.
To apply our patches (mocks/no-ops) to these decorators, we must re-import
the module while the patches are active.
Why alias?
If 'ext_fastopenapi' imports the controller as 'api.controllers...', but we import
it as 'controllers...', Python treats them as two separate modules. This causes:
1. Double execution of decorators (registering routes twice -> AssertionError).
2. Type mismatch errors (Class A from module X is not Class A from module Y).
This function ensures both names point to the SAME loaded module instance.
"""
# 1. Clean existing entries to force re-import
if target_module in sys.modules:
del sys.modules[target_module]
if alias_module in sys.modules:
del sys.modules[alias_module]
# 2. Import the module (triggering decorators with active patches)
module = importlib.import_module(target_module)
# 3. Alias the module in sys.modules to prevent double loading
sys.modules[alias_module] = sys.modules[target_module]
return module
def _cleanup_modules(target_module: str, alias_module: str):
"""
Removes the module and its alias from sys.modules to prevent side effects
on other tests.
"""
if target_module in sys.modules:
del sys.modules[target_module]
if alias_module in sys.modules:
del sys.modules[alias_module]
@pytest.fixture
def mock_feature_module_env():
"""
Sets up a mocked environment for the feature module.
This fixture orchestrates:
1. Creating an isolated router.
2. Patching authentication and global dependencies.
3. Reloading the controller module to apply patches to decorators.
4. cleaning up sys.modules afterwards.
"""
target_module = "controllers.console.feature"
alias_module = "api.controllers.console.feature"
# 1. Prepare isolated router
temp_router = _create_isolated_router()
# 2. Apply patches
try:
with _patch_auth_and_router(temp_router):
# 3. Reload module to register routes on the temp_router
feature_module = _force_reload_module(target_module, alias_module)
yield feature_module
finally:
# 4. Teardown: Clean up sys.modules
_cleanup_modules(target_module, alias_module)
# ------------------------------------------------------------------------------
# Test Cases
# ------------------------------------------------------------------------------
@pytest.mark.parametrize(
("url", "service_mock_path", "mock_model_instance", "json_key"),
[
(
"/console/api/features",
"controllers.console.feature.FeatureService.get_features",
FeatureModel(can_replace_logo=True),
"features",
),
(
"/console/api/system-features",
"controllers.console.feature.FeatureService.get_system_features",
SystemFeatureModel(enable_marketplace=True),
"features",
),
],
)
def test_console_features_success(app, mock_feature_module_env, url, service_mock_path, mock_model_instance, json_key):
"""
Tests that the feature APIs return a 200 OK status and correct JSON structure.
"""
# Patch the service layer to return our mock model instance
with patch(service_mock_path, return_value=mock_model_instance):
# Initialize the API extension
ext_fastopenapi.init_app(app)
client = app.test_client()
response = client.get(url)
# Assertions
assert response.status_code == 200, f"Request failed with status {response.status_code}: {response.text}"
# Verify the JSON response matches the Pydantic model dump
expected_data = mock_model_instance.model_dump(mode="json")
assert response.get_json() == {json_key: expected_data}
@pytest.mark.parametrize(
("url", "service_mock_path"),
[
("/console/api/features", "controllers.console.feature.FeatureService.get_features"),
("/console/api/system-features", "controllers.console.feature.FeatureService.get_system_features"),
],
)
def test_console_features_service_error(app, mock_feature_module_env, url, service_mock_path):
"""
Tests how the application handles Service layer errors.
Note: When an exception occurs in the view, it is typically caught by the framework
(Flask or the OpenAPI wrapper) and converted to a 500 error response.
This test verifies that the application returns a 500 status code.
"""
# Simulate a service failure
with patch(service_mock_path, side_effect=ValueError("Service Failure")):
ext_fastopenapi.init_app(app)
client = app.test_client()
# When an exception occurs in the view, it is typically caught by the framework
# (Flask or the OpenAPI wrapper) and converted to a 500 error response.
response = client.get(url)
assert response.status_code == 500
# Check if the error details are exposed in the response (depends on error handler config)
# We accept either generic 500 or the specific error message
assert "Service Failure" in response.text or "Internal Server Error" in response.text
def test_system_features_unauthenticated(app, mock_feature_module_env):
"""
Tests that /console/api/system-features endpoint works without authentication.
This test verifies the try-except block in get_system_features that handles
unauthenticated requests by passing is_authenticated=False to the service layer.
"""
feature_module = mock_feature_module_env
# Override the behavior of the current_user mock
# The fixture patched 'libs.login.current_user', so 'controllers.console.feature.current_user'
# refers to that same Mock object.
mock_user = feature_module.current_user
# Simulate property access raising Unauthorized
# Note: We must reset side_effect if it was set, or set it here.
# The fixture initialized it as MagicMock(is_authenticated=True).
# We want type(mock_user).is_authenticated to raise Unauthorized.
type(mock_user).is_authenticated = PropertyMock(side_effect=Unauthorized)
# Patch the service layer for this specific test
with patch("controllers.console.feature.FeatureService.get_system_features") as mock_service:
# Setup mock service return value
mock_model = SystemFeatureModel(enable_marketplace=True)
mock_service.return_value = mock_model
# Initialize app
ext_fastopenapi.init_app(app)
client = app.test_client()
# Act
response = client.get("/console/api/system-features")
# Assert
assert response.status_code == 200, f"Request failed: {response.text}"
# Verify service was called with is_authenticated=False
mock_service.assert_called_once_with(is_authenticated=False)
# Verify response body
expected_data = mock_model.model_dump(mode="json")
assert response.get_json() == {"features": expected_data}

View File

@@ -1,222 +0,0 @@
import builtins
import contextlib
import importlib
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask.views import MethodView
from extensions import ext_fastopenapi
from extensions.ext_database import db
@pytest.fixture
def app():
app = Flask(__name__)
app.config["TESTING"] = True
app.config["SECRET_KEY"] = "test-secret"
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
db.init_app(app)
return app
@pytest.fixture(autouse=True)
def fix_method_view_issue(monkeypatch):
if not hasattr(builtins, "MethodView"):
monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False)
def _create_isolated_router():
import controllers.fastopenapi
router_class = type(controllers.fastopenapi.console_router)
return router_class()
@contextlib.contextmanager
def _patch_auth_and_router(temp_router):
def noop(func):
return func
default_user = MagicMock(has_edit_permission=True, is_dataset_editor=False)
with (
patch("controllers.fastopenapi.console_router", temp_router),
patch("extensions.ext_fastopenapi.console_router", temp_router),
patch("controllers.console.wraps.setup_required", side_effect=noop),
patch("libs.login.login_required", side_effect=noop),
patch("controllers.console.wraps.account_initialization_required", side_effect=noop),
patch("controllers.console.wraps.edit_permission_required", side_effect=noop),
patch("libs.login.current_account_with_tenant", return_value=(default_user, "tenant-id")),
patch("configs.dify_config.EDITION", "CLOUD"),
):
import extensions.ext_fastopenapi
importlib.reload(extensions.ext_fastopenapi)
yield
def _force_reload_module(target_module: str, alias_module: str):
if target_module in sys.modules:
del sys.modules[target_module]
if alias_module in sys.modules:
del sys.modules[alias_module]
module = importlib.import_module(target_module)
sys.modules[alias_module] = sys.modules[target_module]
return module
def _dedupe_routes(router):
seen = set()
unique_routes = []
for path, method, endpoint in reversed(router.get_routes()):
key = (path, method, endpoint.__name__)
if key in seen:
continue
seen.add(key)
unique_routes.append((path, method, endpoint))
router._routes = list(reversed(unique_routes))
def _cleanup_modules(target_module: str, alias_module: str):
if target_module in sys.modules:
del sys.modules[target_module]
if alias_module in sys.modules:
del sys.modules[alias_module]
@pytest.fixture
def mock_tags_module_env():
target_module = "controllers.console.tag.tags"
alias_module = "api.controllers.console.tag.tags"
temp_router = _create_isolated_router()
try:
with _patch_auth_and_router(temp_router):
tags_module = _force_reload_module(target_module, alias_module)
_dedupe_routes(temp_router)
yield tags_module
finally:
_cleanup_modules(target_module, alias_module)
def test_list_tags_success(app: Flask, mock_tags_module_env):
# Arrange
tag = SimpleNamespace(id="tag-1", name="Alpha", type="app", binding_count=2)
with patch("controllers.console.tag.tags.TagService.get_tags", return_value=[tag]):
ext_fastopenapi.init_app(app)
client = app.test_client()
# Act
response = client.get("/console/api/tags?type=app&keyword=Alpha")
# Assert
assert response.status_code == 200
assert response.get_json() == [
{"id": "tag-1", "name": "Alpha", "type": "app", "binding_count": 2},
]
def test_create_tag_success(app: Flask, mock_tags_module_env):
# Arrange
tag = SimpleNamespace(id="tag-2", name="Beta", type="app")
with patch("controllers.console.tag.tags.TagService.save_tags", return_value=tag) as mock_save:
ext_fastopenapi.init_app(app)
client = app.test_client()
# Act
response = client.post("/console/api/tags", json={"name": "Beta", "type": "app"})
# Assert
assert response.status_code == 200
assert response.get_json() == {
"id": "tag-2",
"name": "Beta",
"type": "app",
"binding_count": 0,
}
mock_save.assert_called_once_with({"name": "Beta", "type": "app"})
def test_update_tag_success(app: Flask, mock_tags_module_env):
# Arrange
tag = SimpleNamespace(id="tag-3", name="Gamma", type="app")
with (
patch("controllers.console.tag.tags.TagService.update_tags", return_value=tag) as mock_update,
patch("controllers.console.tag.tags.TagService.get_tag_binding_count", return_value=4),
):
ext_fastopenapi.init_app(app)
client = app.test_client()
# Act
response = client.patch(
"/console/api/tags/11111111-1111-1111-1111-111111111111",
json={"name": "Gamma", "type": "app"},
)
# Assert
assert response.status_code == 200
assert response.get_json() == {
"id": "tag-3",
"name": "Gamma",
"type": "app",
"binding_count": 4,
}
mock_update.assert_called_once_with(
{"name": "Gamma", "type": "app"},
"11111111-1111-1111-1111-111111111111",
)
def test_delete_tag_success(app: Flask, mock_tags_module_env):
# Arrange
with patch("controllers.console.tag.tags.TagService.delete_tag") as mock_delete:
ext_fastopenapi.init_app(app)
client = app.test_client()
# Act
response = client.delete("/console/api/tags/11111111-1111-1111-1111-111111111111")
# Assert
assert response.status_code == 204
mock_delete.assert_called_once_with("11111111-1111-1111-1111-111111111111")
def test_create_tag_binding_success(app: Flask, mock_tags_module_env):
# Arrange
payload = {"tag_ids": ["tag-1", "tag-2"], "target_id": "target-1", "type": "app"}
with patch("controllers.console.tag.tags.TagService.save_tag_binding") as mock_bind:
ext_fastopenapi.init_app(app)
client = app.test_client()
# Act
response = client.post("/console/api/tag-bindings/create", json=payload)
# Assert
assert response.status_code == 200
assert response.get_json() == {"result": "success"}
mock_bind.assert_called_once_with(payload)
def test_delete_tag_binding_success(app: Flask, mock_tags_module_env):
# Arrange
payload = {"tag_id": "tag-1", "target_id": "target-1", "type": "app"}
with patch("controllers.console.tag.tags.TagService.delete_tag_binding") as mock_unbind:
ext_fastopenapi.init_app(app)
client = app.test_client()
# Act
response = client.post("/console/api/tag-bindings/remove", json=payload)
# Assert
assert response.status_code == 200
assert response.get_json() == {"result": "success"}
mock_unbind.assert_called_once_with(payload)

View File

@@ -0,0 +1,364 @@
"""Endpoint tests for controllers.console.workspace.tool_providers."""
from __future__ import annotations
import builtins
import importlib
from contextlib import contextmanager
from types import ModuleType, SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask.views import MethodView
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
_CONTROLLER_MODULE: ModuleType | None = None
_WRAPS_MODULE: ModuleType | None = None
_CONTROLLER_PATCHERS: list[patch] = []
@contextmanager
def _mock_db():
mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
with patch("extensions.ext_database.db.session", mock_session):
yield
@pytest.fixture
def app() -> Flask:
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
@pytest.fixture
def controller_module(monkeypatch: pytest.MonkeyPatch):
module_name = "controllers.console.workspace.tool_providers"
global _CONTROLLER_MODULE
if _CONTROLLER_MODULE is None:
def _noop(func):
return func
patch_targets = [
("libs.login.login_required", _noop),
("controllers.console.wraps.setup_required", _noop),
("controllers.console.wraps.account_initialization_required", _noop),
("controllers.console.wraps.is_admin_or_owner_required", _noop),
("controllers.console.wraps.enterprise_license_required", _noop),
]
for target, value in patch_targets:
patcher = patch(target, value)
patcher.start()
_CONTROLLER_PATCHERS.append(patcher)
monkeypatch.setenv("DIFY_SETUP_READY", "true")
with _mock_db():
_CONTROLLER_MODULE = importlib.import_module(module_name)
module = _CONTROLLER_MODULE
monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload)
# Ensure decorators that consult deployment edition do not reach the database.
global _WRAPS_MODULE
wraps_module = importlib.import_module("controllers.console.wraps")
_WRAPS_MODULE = wraps_module
monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD")
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
login_module = importlib.import_module("libs.login")
monkeypatch.setattr(login_module, "check_csrf_token", lambda *args, **kwargs: None)
return module
def _mock_account(user_id: str = "user-123") -> SimpleNamespace:
return SimpleNamespace(id=user_id, status="active", is_authenticated=True, current_tenant_id=None)
def _set_current_account(
monkeypatch: pytest.MonkeyPatch,
controller_module: ModuleType,
user: SimpleNamespace,
tenant_id: str,
) -> None:
def _getter():
return user, tenant_id
user.current_tenant_id = tenant_id
monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter)
if _WRAPS_MODULE is not None:
monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter)
login_module = importlib.import_module("libs.login")
monkeypatch.setattr(login_module, "_get_user", lambda: user)
def test_tool_provider_list_calls_service_with_query(
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
service_mock = MagicMock(return_value=[{"provider": "builtin"}])
monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock)
with app.test_request_context("/workspaces/current/tool-providers?type=builtin"):
response = controller_module.ToolProviderListApi().get()
assert response == [{"provider": "builtin"}]
service_mock.assert_called_once_with(user.id, "tenant-456", "builtin")
def test_builtin_provider_add_passes_payload(
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
service_mock = MagicMock(return_value={"status": "ok"})
monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock)
payload = {
"credentials": {"api_key": "sk-test"},
"name": "MyTool",
"type": controller_module.CredentialType.API_KEY,
}
with app.test_request_context(
"/workspaces/current/tool-provider/builtin/openai/add",
method="POST",
json=payload,
):
response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai")
assert response == {"status": "ok"}
service_mock.assert_called_once_with(
user_id="user-123",
tenant_id="tenant-456",
provider="openai",
credentials={"api_key": "sk-test"},
name="MyTool",
api_type=controller_module.CredentialType.API_KEY,
)
def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-789")
_set_current_account(monkeypatch, controller_module, user, "tenant-789")
service_mock = MagicMock(return_value=[{"name": "tool-a"}])
monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock)
monkeypatch.setattr(controller_module, "jsonable_encoder", lambda payload: payload)
with app.test_request_context(
"/workspaces/current/tool-provider/builtin/my-provider/tools",
method="GET",
):
response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider")
assert response == [{"name": "tool-a"}]
service_mock.assert_called_once_with("tenant-789", "my-provider")
def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-9")
_set_current_account(monkeypatch, controller_module, user, "tenant-9")
service_mock = MagicMock(return_value={"info": True})
monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock)
with app.test_request_context("/info", method="GET"):
resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo")
assert resp == {"info": True}
service_mock.assert_called_once_with("tenant-9", "demo")
def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-cred")
_set_current_account(monkeypatch, controller_module, user, "tenant-cred")
service_mock = MagicMock(return_value=[{"cred": 1}])
monkeypatch.setattr(
controller_module.BuiltinToolManageService,
"get_builtin_tool_provider_credentials",
service_mock,
)
with app.test_request_context("/creds", method="GET"):
resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo")
assert resp == [{"cred": 1}]
service_mock.assert_called_once_with(tenant_id="tenant-cred", provider_name="demo")
def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-10")
service_mock = MagicMock(return_value={"schema": "ok"})
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock)
with app.test_request_context("/remote?url=https://example.com/"):
resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get()
assert resp == {"schema": "ok"}
service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/")
def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-11")
service_mock = MagicMock(return_value=[{"tool": "t"}])
monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock)
with app.test_request_context("/tools?provider=foo"):
resp = controller_module.ToolApiProviderListToolsApi().get()
assert resp == [{"tool": "t"}]
service_mock.assert_called_once_with(user.id, "tenant-11", "foo")
def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-12")
service_mock = MagicMock(return_value={"provider": "foo"})
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock)
with app.test_request_context("/get?provider=foo"):
resp = controller_module.ToolApiProviderGetApi().get()
assert resp == {"provider": "foo"}
service_mock.assert_called_once_with(user.id, "tenant-12", "foo")
def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-13")
_set_current_account(monkeypatch, controller_module, user, "tenant-13")
service_mock = MagicMock(return_value={"schema": True})
monkeypatch.setattr(
controller_module.BuiltinToolManageService,
"list_builtin_provider_credentials_schema",
service_mock,
)
with app.test_request_context("/schema", method="GET"):
resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get(
provider="demo", credential_type="api-key"
)
assert resp == {"schema": True}
service_mock.assert_called_once()
def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-wf")
tool_service = MagicMock(return_value={"wf": 1})
monkeypatch.setattr(
controller_module.WorkflowToolManageService,
"get_workflow_tool_by_tool_id",
tool_service,
)
tool_id = "00000000-0000-0000-0000-000000000001"
with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"):
resp = controller_module.ToolWorkflowProviderGetApi().get()
assert resp == {"wf": 1}
tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id)
def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-wf2")
service_mock = MagicMock(return_value={"app": 1})
monkeypatch.setattr(
controller_module.WorkflowToolManageService,
"get_workflow_tool_by_app_id",
service_mock,
)
app_id = "00000000-0000-0000-0000-000000000002"
with app.test_request_context(f"/workflow?workflow_app_id={app_id}"):
resp = controller_module.ToolWorkflowProviderGetApi().get()
assert resp == {"app": 1}
service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id)
def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-wf3")
service_mock = MagicMock(return_value=[{"id": 1}])
monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock)
tool_id = "00000000-0000-0000-0000-000000000003"
with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"):
resp = controller_module.ToolWorkflowProviderListToolApi().get()
assert resp == [{"id": 1}]
service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id)
def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-bt")
provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"})
monkeypatch.setattr(
controller_module.BuiltinToolManageService,
"list_builtin_tools",
MagicMock(return_value=[provider]),
)
with app.test_request_context("/tools/builtin"):
resp = controller_module.ToolBuiltinListApi().get()
assert resp == [{"name": "builtin"}]
def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-api")
_set_current_account(monkeypatch, controller_module, user, "tenant-api")
provider = SimpleNamespace(to_dict=lambda: {"name": "api"})
monkeypatch.setattr(
controller_module.ApiToolManageService,
"list_api_tools",
MagicMock(return_value=[provider]),
)
with app.test_request_context("/tools/api"):
resp = controller_module.ToolApiListApi().get()
assert resp == [{"name": "api"}]
def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-wf4")
provider = SimpleNamespace(to_dict=lambda: {"name": "wf"})
monkeypatch.setattr(
controller_module.WorkflowToolManageService,
"list_tenant_workflow_tools",
MagicMock(return_value=[provider]),
)
with app.test_request_context("/tools/workflow"):
resp = controller_module.ToolWorkflowListApi().get()
assert resp == [{"name": "wf"}]
def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-label")
_set_current_account(monkeypatch, controller_module, user, "tenant-labels")
monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"])
with app.test_request_context("/tool-labels"):
resp = controller_module.ToolLabelsApi().get()
assert resp == ["a", "b"]

8
api/uv.lock generated
View File

@@ -1707,7 +1707,7 @@ dev = [
{ name = "types-openpyxl", specifier = "~=3.1.5" },
{ name = "types-pexpect", specifier = "~=4.9.0" },
{ name = "types-protobuf", specifier = "~=5.29.1" },
{ name = "types-psutil", specifier = "~=7.0.0" },
{ name = "types-psutil", specifier = "~=7.2.2" },
{ name = "types-psycopg2", specifier = "~=2.9.21" },
{ name = "types-pygments", specifier = "~=2.19.0" },
{ name = "types-pymysql", specifier = "~=1.1.0" },
@@ -6508,11 +6508,11 @@ wheels = [
[[package]]
name = "types-psutil"
version = "7.0.0.20251116"
version = "7.2.2.20260130"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/47/ec/c1e9308b91582cad1d7e7d3007fd003ef45a62c2500f8219313df5fc3bba/types_psutil-7.0.0.20251116.tar.gz", hash = "sha256:92b5c78962e55ce1ed7b0189901a4409ece36ab9fd50c3029cca7e681c606c8a", size = 22192, upload-time = "2025-11-16T03:10:32.859Z" }
sdist = { url = "https://files.pythonhosted.org/packages/69/14/fc5fb0a6ddfadf68c27e254a02ececd4d5c7fdb0efcb7e7e917a183497fb/types_psutil-7.2.2.20260130.tar.gz", hash = "sha256:15b0ab69c52841cf9ce3c383e8480c620a4d13d6a8e22b16978ebddac5590950", size = 26535, upload-time = "2026-01-30T03:58:14.116Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c3/0e/11ba08a5375c21039ed5f8e6bba41e9452fb69f0e2f7ee05ed5cca2a2cdf/types_psutil-7.0.0.20251116-py3-none-any.whl", hash = "sha256:74c052de077c2024b85cd435e2cba971165fe92a5eace79cbeb821e776dbc047", size = 25376, upload-time = "2025-11-16T03:10:31.813Z" },
{ url = "https://files.pythonhosted.org/packages/17/d7/60974b7e31545d3768d1770c5fe6e093182c3bfd819429b33133ba6b3e89/types_psutil-7.2.2.20260130-py3-none-any.whl", hash = "sha256:15523a3caa7b3ff03ac7f9b78a6470a59f88f48df1d74a39e70e06d2a99107da", size = 32876, upload-time = "2026-01-30T03:58:13.172Z" },
]
[[package]]

View File

@@ -1,5 +1,5 @@
#!/bin/bash
set -x
set -euxo pipefail
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."

View File

@@ -662,13 +662,14 @@ services:
- "${IRIS_SUPER_SERVER_PORT:-1972}:1972"
- "${IRIS_WEB_SERVER_PORT:-52773}:52773"
volumes:
- ./volumes/iris:/opt/iris
- ./volumes/iris:/durable
- ./iris/iris-init.script:/iris-init.script
- ./iris/docker-entrypoint.sh:/custom-entrypoint.sh
entrypoint: ["/custom-entrypoint.sh"]
tty: true
environment:
TZ: ${IRIS_TIMEZONE:-UTC}
ISC_DATA_DIRECTORY: /durable/iris
# Oracle vector database
oracle:

View File

@@ -1348,13 +1348,14 @@ services:
- "${IRIS_SUPER_SERVER_PORT:-1972}:1972"
- "${IRIS_WEB_SERVER_PORT:-52773}:52773"
volumes:
- ./volumes/iris:/opt/iris
- ./volumes/iris:/durable
- ./iris/iris-init.script:/iris-init.script
- ./iris/docker-entrypoint.sh:/custom-entrypoint.sh
entrypoint: ["/custom-entrypoint.sh"]
tty: true
environment:
TZ: ${IRIS_TIMEZONE:-UTC}
ISC_DATA_DIRECTORY: /durable/iris
# Oracle vector database
oracle:

View File

@@ -1,15 +1,33 @@
#!/bin/bash
set -e
# IRIS configuration flag file
IRIS_CONFIG_DONE="/opt/iris/.iris-configured"
# IRIS configuration flag file (stored in durable directory to persist with data)
IRIS_CONFIG_DONE="/durable/.iris-configured"
# Function to wait for IRIS to be ready
wait_for_iris() {
echo "Waiting for IRIS to be ready..."
local max_attempts=30
local attempt=1
while [ "$attempt" -le "$max_attempts" ]; do
if iris qlist IRIS 2>/dev/null | grep -q "running"; then
echo "IRIS is ready."
return 0
fi
echo "Attempt $attempt/$max_attempts: IRIS not ready yet, waiting..."
sleep 2
attempt=$((attempt + 1))
done
echo "ERROR: IRIS failed to start within expected time." >&2
return 1
}
# Function to configure IRIS
configure_iris() {
echo "Configuring IRIS for first-time setup..."
# Wait for IRIS to be fully started
sleep 5
wait_for_iris
# Execute the initialization script
iris session IRIS < /iris-init.script

View File

@@ -3,7 +3,7 @@
import type { ReactNode } from 'react'
import Cookies from 'js-cookie'
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
import { parseAsString, useQueryState } from 'nuqs'
import { parseAsBoolean, useQueryState } from 'nuqs'
import { useCallback, useEffect, useState } from 'react'
import {
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
@@ -28,7 +28,7 @@ export const AppInitializer = ({
const [init, setInit] = useState(false)
const [oauthNewUser, setOauthNewUser] = useQueryState(
'oauth_new_user',
parseAsString.withOptions({ history: 'replace' }),
parseAsBoolean.withOptions({ history: 'replace' }),
)
const isSetupFinished = useCallback(async () => {
@@ -46,7 +46,7 @@ export const AppInitializer = ({
(async () => {
const action = searchParams.get('action')
if (oauthNewUser === 'true') {
if (oauthNewUser) {
let utmInfo = null
const utmInfoStr = Cookies.get('utm_info')
if (utmInfoStr) {

View File

@@ -62,19 +62,19 @@ const AppCard = ({
{app.description}
</div>
</div>
{canCreate && (
{(canCreate || isTrialApp) && (
<div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
<div className={cn('grid h-8 w-full grid-cols-1 items-center space-x-2', isTrialApp && 'grid-cols-2')}>
<Button variant="primary" onClick={() => onCreate()}>
<PlusIcon className="mr-1 h-4 w-4" />
<span className="text-xs">{t('newApp.useTemplate', { ns: 'app' })}</span>
</Button>
{isTrialApp && (
<Button onClick={showTryAPPPanel(app.app_id)}>
<RiInformation2Line className="mr-1 size-4" />
<span>{t('appCard.try', { ns: 'explore' })}</span>
<div className={cn('grid h-8 w-full grid-cols-1 items-center space-x-2', canCreate && 'grid-cols-2')}>
{canCreate && (
<Button variant="primary" onClick={() => onCreate()}>
<PlusIcon className="mr-1 h-4 w-4" />
<span className="text-xs">{t('newApp.useTemplate', { ns: 'app' })}</span>
</Button>
)}
<Button onClick={showTryAPPPanel(app.app_id)}>
<RiInformation2Line className="mr-1 size-4" />
<span>{t('appCard.try', { ns: 'explore' })}</span>
</Button>
</div>
</div>
)}

View File

@@ -124,7 +124,7 @@ describe('CreateAppModal', () => {
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
fireEvent.change(nameInput, { target: { value: 'My App' } })
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ }))
await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({
name: 'My App',
@@ -152,7 +152,7 @@ describe('CreateAppModal', () => {
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
fireEvent.change(nameInput, { target: { value: 'My App' } })
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ }))
await waitFor(() => expect(mockCreateApp).toHaveBeenCalled())
expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' })

View File

@@ -1,8 +1,9 @@
import type { ListChildComponentProps } from 'react-window'
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
import { RiArrowDownSLine, RiArrowRightSLine } from '@remixicon/react'
import { useVirtualizer } from '@tanstack/react-virtual'
import { memo, useCallback, useMemo, useRef, useState } from 'react'
import { memo, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { areEqual, FixedSizeList as List } from 'react-window'
import { cn } from '@/utils/classnames'
import Checkbox from '../../checkbox'
import NotionIcon from '../../notion-icon'
@@ -31,22 +32,6 @@ type NotionPageItem = {
depth: number
} & DataSourceNotionPage
type ItemProps = {
virtualStart: number
virtualSize: number
current: NotionPageItem
onToggle: (pageId: string) => void
checkedIds: Set<string>
disabledCheckedIds: Set<string>
onCheck: (pageId: string) => void
canPreview?: boolean
onPreview: (pageId: string) => void
listMapWithChildrenAndDescendants: NotionPageTreeMap
searchValue: string
previewPageId: string
pagesMap: DataSourceNotionPageMap
}
const recursivePushInParentDescendants = (
pagesMap: DataSourceNotionPageMap,
listTreeMap: NotionPageTreeMap,
@@ -84,22 +69,34 @@ const recursivePushInParentDescendants = (
}
}
const ItemComponent = ({
virtualStart,
virtualSize,
current,
onToggle,
checkedIds,
disabledCheckedIds,
onCheck,
canPreview,
onPreview,
listMapWithChildrenAndDescendants,
searchValue,
previewPageId,
pagesMap,
}: ItemProps) => {
const ItemComponent = ({ index, style, data }: ListChildComponentProps<{
dataList: NotionPageItem[]
handleToggle: (index: number) => void
checkedIds: Set<string>
disabledCheckedIds: Set<string>
handleCheck: (index: number) => void
canPreview?: boolean
handlePreview: (index: number) => void
listMapWithChildrenAndDescendants: NotionPageTreeMap
searchValue: string
previewPageId: string
pagesMap: DataSourceNotionPageMap
}>) => {
const { t } = useTranslation()
const {
dataList,
handleToggle,
checkedIds,
disabledCheckedIds,
handleCheck,
canPreview,
handlePreview,
listMapWithChildrenAndDescendants,
searchValue,
previewPageId,
pagesMap,
} = data
const current = dataList[index]
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[current.page_id]
const hasChild = currentWithChildrenAndDescendants.descendants.size > 0
const ancestors = currentWithChildrenAndDescendants.ancestors
@@ -112,7 +109,7 @@ const ItemComponent = ({
<div
className="mr-1 flex h-5 w-5 shrink-0 items-center justify-center rounded-md hover:bg-components-button-ghost-bg-hover"
style={{ marginLeft: current.depth * 8 }}
onClick={() => onToggle(current.page_id)}
onClick={() => handleToggle(index)}
>
{
current.expand
@@ -135,21 +132,15 @@ const ItemComponent = ({
return (
<div
className={cn('group flex cursor-pointer items-center rounded-md pl-2 pr-[2px] hover:bg-state-base-hover', previewPageId === current.page_id && 'bg-state-base-hover')}
style={{
position: 'absolute',
top: 0,
left: 8,
right: 8,
width: 'calc(100% - 16px)',
height: virtualSize,
transform: `translateY(${virtualStart + 8}px)`,
}}
style={{ ...style, top: style.top as number + 8, left: 8, right: 8, width: 'calc(100% - 16px)' }}
>
<Checkbox
className="mr-2 shrink-0"
checked={checkedIds.has(current.page_id)}
disabled={disabled}
onCheck={() => onCheck(current.page_id)}
onCheck={() => {
handleCheck(index)
}}
/>
{!searchValue && renderArrow()}
<NotionIcon
@@ -169,7 +160,7 @@ const ItemComponent = ({
className="ml-1 hidden h-6 shrink-0 cursor-pointer items-center rounded-md border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-2 text-xs
font-medium leading-4 text-components-button-secondary-text shadow-xs shadow-shadow-shadow-3 backdrop-blur-[10px]
hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover group-hover:flex"
onClick={() => onPreview(current.page_id)}
onClick={() => handlePreview(index)}
>
{t('dataSource.notion.selector.preview', { ns: 'common' })}
</div>
@@ -188,7 +179,7 @@ const ItemComponent = ({
</div>
)
}
const Item = memo(ItemComponent)
const Item = memo(ItemComponent, areEqual)
const PageSelector = ({
value,
@@ -202,10 +193,31 @@ const PageSelector = ({
onPreview,
}: PageSelectorProps) => {
const { t } = useTranslation()
const parentRef = useRef<HTMLDivElement>(null)
const [expandedIds, setExpandedIds] = useState<Set<string>>(() => new Set())
const [dataList, setDataList] = useState<NotionPageItem[]>([])
const [localPreviewPageId, setLocalPreviewPageId] = useState('')
useEffect(() => {
setDataList(list.filter(item => item.parent_id === 'root' || !pagesMap[item.parent_id]).map((item) => {
return {
...item,
expand: false,
depth: 0,
}
}))
}, [list])
const searchDataList = list.filter((item) => {
return item.page_name.includes(searchValue)
}).map((item) => {
return {
...item,
expand: false,
depth: 0,
}
})
const currentDataList = searchValue ? searchDataList : dataList
const currentPreviewPageId = previewPageId === undefined ? localPreviewPageId : previewPageId
const listMapWithChildrenAndDescendants = useMemo(() => {
return list.reduce((prev: NotionPageTreeMap, next: DataSourceNotionPage) => {
const pageId = next.page_id
@@ -217,89 +229,47 @@ const PageSelector = ({
}, {})
}, [list, pagesMap])
const childrenByParent = useMemo(() => {
const map = new Map<string | null, DataSourceNotionPage[]>()
for (const item of list) {
const isRoot = item.parent_id === 'root' || !pagesMap[item.parent_id]
const parentKey = isRoot ? null : item.parent_id
const children = map.get(parentKey) || []
children.push(item)
map.set(parentKey, children)
}
return map
}, [list, pagesMap])
const dataList = useMemo(() => {
const result: NotionPageItem[] = []
const buildVisibleList = (parentId: string | null, depth: number) => {
const items = childrenByParent.get(parentId) || []
for (const item of items) {
const isExpanded = expandedIds.has(item.page_id)
result.push({
...item,
expand: isExpanded,
depth,
})
if (isExpanded) {
buildVisibleList(item.page_id, depth + 1)
}
}
}
buildVisibleList(null, 0)
return result
}, [childrenByParent, expandedIds])
const searchDataList = useMemo(() => list.filter((item) => {
return item.page_name.includes(searchValue)
}).map((item) => {
return {
...item,
expand: false,
depth: 0,
}
}), [list, searchValue])
const currentDataList = searchValue ? searchDataList : dataList
const currentPreviewPageId = previewPageId === undefined ? localPreviewPageId : previewPageId
const virtualizer = useVirtualizer({
count: currentDataList.length,
getScrollElement: () => parentRef.current,
estimateSize: () => 28,
overscan: 5,
getItemKey: index => currentDataList[index].page_id,
})
const handleToggle = useCallback((pageId: string) => {
setExpandedIds((prev) => {
const next = new Set(prev)
if (prev.has(pageId)) {
next.delete(pageId)
const descendants = listMapWithChildrenAndDescendants[pageId]?.descendants
if (descendants) {
for (const descendantId of descendants)
next.delete(descendantId)
}
}
else {
next.add(pageId)
}
return next
})
}, [listMapWithChildrenAndDescendants])
const handleCheck = useCallback((pageId: string) => {
const handleToggle = (index: number) => {
const current = dataList[index]
const pageId = current.page_id
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
const descendantsIds = Array.from(currentWithChildrenAndDescendants.descendants)
const childrenIds = Array.from(currentWithChildrenAndDescendants.children)
let newDataList = []
if (current.expand) {
current.expand = false
newDataList = dataList.filter(item => !descendantsIds.includes(item.page_id))
}
else {
current.expand = true
newDataList = [
...dataList.slice(0, index + 1),
...childrenIds.map(item => ({
...pagesMap[item],
expand: false,
depth: listMapWithChildrenAndDescendants[item].depth,
})),
...dataList.slice(index + 1),
]
}
setDataList(newDataList)
}
const copyValue = new Set(value)
const handleCheck = (index: number) => {
const current = currentDataList[index]
const pageId = current.page_id
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
const copyValue = new Set(value)
if (copyValue.has(pageId)) {
if (!searchValue) {
for (const item of currentWithChildrenAndDescendants.descendants)
copyValue.delete(item)
}
copyValue.delete(pageId)
}
else {
@@ -307,17 +277,22 @@ const PageSelector = ({
for (const item of currentWithChildrenAndDescendants.descendants)
copyValue.add(item)
}
copyValue.add(pageId)
}
onSelect(copyValue)
}, [listMapWithChildrenAndDescendants, onSelect, searchValue, value])
onSelect(new Set(copyValue))
}
const handlePreview = (index: number) => {
const current = currentDataList[index]
const pageId = current.page_id
const handlePreview = useCallback((pageId: string) => {
setLocalPreviewPageId(pageId)
if (onPreview)
onPreview(pageId)
}, [onPreview])
}
if (!currentDataList.length) {
return (
@@ -328,41 +303,29 @@ const PageSelector = ({
}
return (
<div
ref={parentRef}
<List
className="py-2"
style={{ height: 296, width: '100%', overflow: 'auto' }}
height={296}
itemCount={currentDataList.length}
itemSize={28}
width="100%"
itemKey={(index, data) => data.dataList[index].page_id}
itemData={{
dataList: currentDataList,
handleToggle,
checkedIds: value,
disabledCheckedIds: disabledValue,
handleCheck,
canPreview,
handlePreview,
listMapWithChildrenAndDescendants,
searchValue,
previewPageId: currentPreviewPageId,
pagesMap,
}}
>
<div
style={{
height: virtualizer.getTotalSize(),
width: '100%',
position: 'relative',
}}
>
{virtualizer.getVirtualItems().map((virtualRow) => {
const current = currentDataList[virtualRow.index]
return (
<Item
key={virtualRow.key}
virtualStart={virtualRow.start}
virtualSize={virtualRow.size}
current={current}
onToggle={handleToggle}
checkedIds={value}
disabledCheckedIds={disabledValue}
onCheck={handleCheck}
canPreview={canPreview}
onPreview={handlePreview}
listMapWithChildrenAndDescendants={listMapWithChildrenAndDescendants}
searchValue={searchValue}
previewPageId={currentPreviewPageId}
pagesMap={pagesMap}
/>
)
})}
</div>
</div>
{Item}
</List>
)
}

View File

@@ -11,18 +11,21 @@ import { recursivePushInParentDescendants } from './utils'
// Note: react-i18next uses global mock from web/vitest.setup.ts
// Mock @tanstack/react-virtual useVirtualizer hook - renders items directly for testing
vi.mock('@tanstack/react-virtual', () => ({
useVirtualizer: ({ count, getItemKey }: { count: number, getItemKey?: (index: number) => string }) => ({
getVirtualItems: () =>
Array.from({ length: count }).map((_, index) => ({
index,
key: getItemKey ? getItemKey(index) : index,
start: index * 28,
size: 28,
})),
getTotalSize: () => count * 28,
}),
// Mock react-window FixedSizeList - renders items directly for testing
vi.mock('react-window', () => ({
FixedSizeList: ({ children: ItemComponent, itemCount, itemData, itemKey }: any) => (
<div data-testid="virtual-list">
{Array.from({ length: itemCount }).map((_, index) => (
<ItemComponent
key={itemKey?.(index, itemData) || index}
index={index}
style={{ top: index * 28, left: 0, right: 0, width: '100%', position: 'absolute' }}
data={itemData}
/>
))}
</div>
),
areEqual: (prevProps: any, nextProps: any) => prevProps === nextProps,
}))
// Note: NotionIcon from @/app/components/base/ is NOT mocked - using real component per testing guidelines
@@ -116,7 +119,7 @@ describe('PageSelector', () => {
render(<PageSelector {...props} />)
// Assert
expect(screen.getByText('Test Page')).toBeInTheDocument()
expect(screen.getByTestId('virtual-list')).toBeInTheDocument()
})
it('should render empty state when list is empty', () => {
@@ -131,7 +134,7 @@ describe('PageSelector', () => {
// Assert
expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument()
expect(screen.queryByText('Test Page')).not.toBeInTheDocument()
expect(screen.queryByTestId('virtual-list')).not.toBeInTheDocument()
})
it('should render items using FixedSizeList', () => {
@@ -1163,7 +1166,7 @@ describe('PageSelector', () => {
render(<PageSelector {...props} />)
// Assert
expect(screen.getByText('Test Page')).toBeInTheDocument()
expect(screen.getByTestId('virtual-list')).toBeInTheDocument()
})
it('should handle special characters in page name', () => {
@@ -1337,7 +1340,7 @@ describe('PageSelector', () => {
render(<PageSelector {...props} />)
// Assert
expect(screen.getByText('Test Page')).toBeInTheDocument()
expect(screen.getByTestId('virtual-list')).toBeInTheDocument()
if (propVariation.canPreview)
expect(screen.getByText('common.dataSource.notion.selector.preview')).toBeInTheDocument()
else

View File

@@ -1,7 +1,7 @@
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
import { useVirtualizer } from '@tanstack/react-virtual'
import { useCallback, useMemo, useRef, useState } from 'react'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { FixedSizeList as List } from 'react-window'
import Item from './item'
import { recursivePushInParentDescendants } from './utils'
@@ -45,16 +45,29 @@ const PageSelector = ({
currentCredentialId,
}: PageSelectorProps) => {
const { t } = useTranslation()
const parentRef = useRef<HTMLDivElement>(null)
const [expandedIds, setExpandedIds] = useState<Set<string>>(() => new Set())
const [dataList, setDataList] = useState<NotionPageItem[]>([])
const [currentPreviewPageId, setCurrentPreviewPageId] = useState('')
const prevCredentialIdRef = useRef(currentCredentialId)
// Reset expanded state when credential changes (render-time detection)
if (prevCredentialIdRef.current !== currentCredentialId) {
prevCredentialIdRef.current = currentCredentialId
setExpandedIds(new Set())
}
useEffect(() => {
setDataList(list.filter(item => item.parent_id === 'root' || !pagesMap[item.parent_id]).map((item) => {
return {
...item,
expand: false,
depth: 0,
}
}))
}, [currentCredentialId])
const searchDataList = list.filter((item) => {
return item.page_name.includes(searchValue)
}).map((item) => {
return {
...item,
expand: false,
depth: 0,
}
})
const currentDataList = searchValue ? searchDataList : dataList
const listMapWithChildrenAndDescendants = useMemo(() => {
return list.reduce((prev: NotionPageTreeMap, next: DataSourceNotionPage) => {
@@ -67,86 +80,39 @@ const PageSelector = ({
}, {})
}, [list, pagesMap])
// Pre-build children index for O(1) lookup instead of O(n) filter
const childrenByParent = useMemo(() => {
const map = new Map<string | null, DataSourceNotionPage[]>()
for (const item of list) {
const isRoot = item.parent_id === 'root' || !pagesMap[item.parent_id]
const parentKey = isRoot ? null : item.parent_id
const children = map.get(parentKey) || []
children.push(item)
map.set(parentKey, children)
const handleToggle = useCallback((index: number) => {
const current = dataList[index]
const pageId = current.page_id
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
const descendantsIds = Array.from(currentWithChildrenAndDescendants.descendants)
const childrenIds = Array.from(currentWithChildrenAndDescendants.children)
let newDataList = []
if (current.expand) {
current.expand = false
newDataList = dataList.filter(item => !descendantsIds.includes(item.page_id))
}
return map
}, [list, pagesMap])
else {
current.expand = true
// Compute visible data list based on expanded state
const dataList = useMemo(() => {
const result: NotionPageItem[] = []
const buildVisibleList = (parentId: string | null, depth: number) => {
const items = childrenByParent.get(parentId) || []
for (const item of items) {
const isExpanded = expandedIds.has(item.page_id)
result.push({
...item,
expand: isExpanded,
depth,
})
if (isExpanded) {
buildVisibleList(item.page_id, depth + 1)
}
}
newDataList = [
...dataList.slice(0, index + 1),
...childrenIds.map(item => ({
...pagesMap[item],
expand: false,
depth: listMapWithChildrenAndDescendants[item].depth,
})),
...dataList.slice(index + 1),
]
}
setDataList(newDataList)
}, [dataList, listMapWithChildrenAndDescendants, pagesMap])
buildVisibleList(null, 0)
return result
}, [childrenByParent, expandedIds])
const searchDataList = useMemo(() => list.filter((item) => {
return item.page_name.includes(searchValue)
}).map((item) => {
return {
...item,
expand: false,
depth: 0,
}
}), [list, searchValue])
const currentDataList = searchValue ? searchDataList : dataList
const virtualizer = useVirtualizer({
count: currentDataList.length,
getScrollElement: () => parentRef.current,
estimateSize: () => 28,
overscan: 5,
getItemKey: index => currentDataList[index].page_id,
})
// Stable callback - no dependencies on dataList
const handleToggle = useCallback((pageId: string) => {
setExpandedIds((prev) => {
const next = new Set(prev)
if (prev.has(pageId)) {
// Collapse: remove current and all descendants
next.delete(pageId)
const descendants = listMapWithChildrenAndDescendants[pageId]?.descendants
if (descendants) {
for (const descendantId of descendants)
next.delete(descendantId)
}
}
else {
next.add(pageId)
}
return next
})
}, [listMapWithChildrenAndDescendants])
// Stable callback - uses pageId parameter instead of index
const handleCheck = useCallback((pageId: string) => {
const handleCheck = useCallback((index: number) => {
const copyValue = new Set(checkedIds)
const current = currentDataList[index]
const pageId = current.page_id
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
if (copyValue.has(pageId)) {
@@ -154,6 +120,7 @@ const PageSelector = ({
for (const item of currentWithChildrenAndDescendants.descendants)
copyValue.delete(item)
}
copyValue.delete(pageId)
}
else {
@@ -171,15 +138,18 @@ const PageSelector = ({
}
}
onSelect(copyValue)
}, [checkedIds, isMultipleChoice, listMapWithChildrenAndDescendants, onSelect, searchValue])
onSelect(new Set(copyValue))
}, [currentDataList, isMultipleChoice, listMapWithChildrenAndDescendants, onSelect, searchValue, checkedIds])
const handlePreview = useCallback((index: number) => {
const current = currentDataList[index]
const pageId = current.page_id
// Stable callback
const handlePreview = useCallback((pageId: string) => {
setCurrentPreviewPageId(pageId)
if (onPreview)
onPreview(pageId)
}, [onPreview])
}, [currentDataList, onPreview])
if (!currentDataList.length) {
return (
@@ -190,42 +160,30 @@ const PageSelector = ({
}
return (
<div
ref={parentRef}
<List
className="py-2"
style={{ height: 296, width: '100%', overflow: 'auto' }}
height={296}
itemCount={currentDataList.length}
itemSize={28}
width="100%"
itemKey={(index, data) => data.dataList[index].page_id}
itemData={{
dataList: currentDataList,
handleToggle,
checkedIds,
disabledCheckedIds: disabledValue,
handleCheck,
canPreview,
handlePreview,
listMapWithChildrenAndDescendants,
searchValue,
previewPageId: currentPreviewPageId,
pagesMap,
isMultipleChoice,
}}
>
<div
style={{
height: virtualizer.getTotalSize(),
width: '100%',
position: 'relative',
}}
>
{virtualizer.getVirtualItems().map((virtualRow) => {
const current = currentDataList[virtualRow.index]
return (
<Item
key={virtualRow.key}
virtualStart={virtualRow.start}
virtualSize={virtualRow.size}
current={current}
onToggle={handleToggle}
checkedIds={checkedIds}
disabledCheckedIds={disabledValue}
onCheck={handleCheck}
canPreview={canPreview}
onPreview={handlePreview}
listMapWithChildrenAndDescendants={listMapWithChildrenAndDescendants}
searchValue={searchValue}
previewPageId={currentPreviewPageId}
pagesMap={pagesMap}
isMultipleChoice={isMultipleChoice}
/>
)
})}
</div>
</div>
{Item}
</List>
)
}

View File

@@ -1,7 +1,9 @@
import type { ListChildComponentProps } from 'react-window'
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
import { RiArrowDownSLine, RiArrowRightSLine } from '@remixicon/react'
import { memo } from 'react'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { areEqual } from 'react-window'
import Checkbox from '@/app/components/base/checkbox'
import NotionIcon from '@/app/components/base/notion-icon'
import Radio from '@/app/components/base/radio/ui'
@@ -21,40 +23,36 @@ type NotionPageItem = {
depth: number
} & DataSourceNotionPage
type ItemProps = {
virtualStart: number
virtualSize: number
current: NotionPageItem
onToggle: (pageId: string) => void
const Item = ({ index, style, data }: ListChildComponentProps<{
dataList: NotionPageItem[]
handleToggle: (index: number) => void
checkedIds: Set<string>
disabledCheckedIds: Set<string>
onCheck: (pageId: string) => void
handleCheck: (index: number) => void
canPreview?: boolean
onPreview: (pageId: string) => void
handlePreview: (index: number) => void
listMapWithChildrenAndDescendants: NotionPageTreeMap
searchValue: string
previewPageId: string
pagesMap: DataSourceNotionPageMap
isMultipleChoice?: boolean
}
const Item = ({
virtualStart,
virtualSize,
current,
onToggle,
checkedIds,
disabledCheckedIds,
onCheck,
canPreview,
onPreview,
listMapWithChildrenAndDescendants,
searchValue,
previewPageId,
pagesMap,
isMultipleChoice,
}: ItemProps) => {
}>) => {
const { t } = useTranslation()
const {
dataList,
handleToggle,
checkedIds,
disabledCheckedIds,
handleCheck,
canPreview,
handlePreview,
listMapWithChildrenAndDescendants,
searchValue,
previewPageId,
pagesMap,
isMultipleChoice,
} = data
const current = dataList[index]
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[current.page_id]
const hasChild = currentWithChildrenAndDescendants.descendants.size > 0
const ancestors = currentWithChildrenAndDescendants.ancestors
@@ -67,7 +65,7 @@ const Item = ({
<div
className="mr-1 flex h-5 w-5 shrink-0 items-center justify-center rounded-md hover:bg-components-button-ghost-bg-hover"
style={{ marginLeft: current.depth * 8 }}
onClick={() => onToggle(current.page_id)}
onClick={() => handleToggle(index)}
>
{
current.expand
@@ -90,15 +88,7 @@ const Item = ({
return (
<div
className={cn('group flex cursor-pointer items-center rounded-md pl-2 pr-[2px] hover:bg-state-base-hover', previewPageId === current.page_id && 'bg-state-base-hover')}
style={{
position: 'absolute',
top: 0,
left: 8,
right: 8,
width: 'calc(100% - 16px)',
height: virtualSize,
transform: `translateY(${virtualStart + 8}px)`,
}}
style={{ ...style, top: style.top as number + 8, left: 8, right: 8, width: 'calc(100% - 16px)' }}
>
{isMultipleChoice
? (
@@ -106,7 +96,9 @@ const Item = ({
className="mr-2 shrink-0"
checked={checkedIds.has(current.page_id)}
disabled={disabled}
onCheck={() => onCheck(current.page_id)}
onCheck={() => {
handleCheck(index)
}}
/>
)
: (
@@ -114,7 +106,9 @@ const Item = ({
className="mr-2 shrink-0"
isChecked={checkedIds.has(current.page_id)}
disabled={disabled}
onCheck={() => onCheck(current.page_id)}
onCheck={() => {
handleCheck(index)
}}
/>
)}
{!searchValue && renderArrow()}
@@ -135,7 +129,7 @@ const Item = ({
className="ml-1 hidden h-6 shrink-0 cursor-pointer items-center rounded-md border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-2 text-xs
font-medium leading-4 text-components-button-secondary-text shadow-xs shadow-shadow-shadow-3 backdrop-blur-[10px]
hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover group-hover:flex"
onClick={() => onPreview(current.page_id)}
onClick={() => handlePreview(index)}
>
{t('dataSource.notion.selector.preview', { ns: 'common' })}
</div>
@@ -155,4 +149,4 @@ const Item = ({
)
}
export default memo(Item)
export default React.memo(Item, areEqual)

View File

@@ -3,8 +3,6 @@ import type { FC, ReactNode } from 'react'
import type { SliceProps } from './type'
import { autoUpdate, flip, FloatingFocusManager, offset, shift, useDismiss, useFloating, useHover, useInteractions, useRole } from '@floating-ui/react'
import { RiDeleteBinLine } from '@remixicon/react'
// @ts-expect-error no types available
import lineClamp from 'line-clamp'
import { useState } from 'react'
import ActionButton, { ActionButtonState } from '@/app/components/base/action-button'
import { cn } from '@/utils/classnames'
@@ -58,12 +56,8 @@ export const EditSlice: FC<EditSliceProps> = (props) => {
<>
<SliceContainer
{...rest}
className={cn('mr-0 block', className)}
ref={(ref) => {
refs.setReference(ref)
if (ref)
lineClamp(ref, 4)
}}
className={cn('mr-0 line-clamp-4 block', className)}
ref={refs.setReference}
{...getReferenceProps()}
>
<SliceLabel

View File

@@ -74,11 +74,15 @@ const AppCard = ({
</div>
{isExplore && (canCreate || isTrialApp) && (
<div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
<div className={cn('grid h-8 w-full grid-cols-2 space-x-2')}>
<Button variant="primary" className="h-7" onClick={() => onCreate()}>
<PlusIcon className="mr-1 h-4 w-4" />
<span className="text-xs">{t('appCard.addToWorkspace', { ns: 'explore' })}</span>
</Button>
<div className={cn('grid h-8 w-full grid-cols-1 space-x-2', canCreate && 'grid-cols-2')}>
{
canCreate && (
<Button variant="primary" className="h-7" onClick={() => onCreate()}>
<PlusIcon className="mr-1 h-4 w-4" />
<span className="text-xs">{t('appCard.addToWorkspace', { ns: 'explore' })}</span>
</Button>
)
}
<Button className="h-7" onClick={showTryAPPPanel(app.app_id)}>
<RiInformation2Line className="mr-1 size-4" />
<span>{t('appCard.try', { ns: 'explore' })}</span>

View File

@@ -138,7 +138,7 @@ describe('CreateAppModal', () => {
setup({ appName: 'My App', isEditModal: false })
expect(screen.getByText('explore.appCustomize.title:{"name":"My App"}')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeInTheDocument()
expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument()
})
@@ -146,7 +146,7 @@ describe('CreateAppModal', () => {
setup({ isEditModal: true, appMode: AppModeEnum.CHAT, max_active_requests: 5 })
expect(screen.getByText('app.editAppTitle')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
expect(screen.getByRole('button', { name: /common\.operation\.save/ })).toBeInTheDocument()
expect(screen.getByRole('switch')).toBeInTheDocument()
expect((screen.getByRole('spinbutton') as HTMLInputElement).value).toBe('5')
})
@@ -166,7 +166,7 @@ describe('CreateAppModal', () => {
it('should not render modal content when hidden', () => {
setup({ show: false })
expect(screen.queryByRole('button', { name: 'common.operation.create' })).not.toBeInTheDocument()
expect(screen.queryByRole('button', { name: /common\.operation\.create/ })).not.toBeInTheDocument()
})
})
@@ -175,13 +175,13 @@ describe('CreateAppModal', () => {
it('should disable confirm action when confirmDisabled is true', () => {
setup({ confirmDisabled: true })
expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled()
expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled()
})
it('should disable confirm action when appName is empty', () => {
setup({ appName: ' ' })
expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled()
expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled()
})
})
@@ -245,7 +245,7 @@ describe('CreateAppModal', () => {
setup({ isEditModal: false })
expect(screen.getByText('billing.apps.fullTip2')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled()
expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled()
})
it('should allow saving when apps quota is reached in edit mode', () => {
@@ -257,7 +257,7 @@ describe('CreateAppModal', () => {
setup({ isEditModal: true })
expect(screen.queryByText('billing.apps.fullTip2')).not.toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeEnabled()
expect(screen.getByRole('button', { name: /common\.operation\.save/ })).toBeEnabled()
})
})
@@ -384,7 +384,7 @@ describe('CreateAppModal', () => {
fireEvent.click(screen.getByRole('button', { name: 'app.iconPicker.ok' }))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
act(() => {
vi.advanceTimersByTime(300)
})
@@ -433,7 +433,7 @@ describe('CreateAppModal', () => {
expect(screen.queryByRole('button', { name: 'app.iconPicker.cancel' })).not.toBeInTheDocument()
// Submit and verify the payload uses the original icon (cancel reverts to props)
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
act(() => {
vi.advanceTimersByTime(300)
})
@@ -471,7 +471,7 @@ describe('CreateAppModal', () => {
appIconBackground: '#000000',
})
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
act(() => {
vi.advanceTimersByTime(300)
})
@@ -495,7 +495,7 @@ describe('CreateAppModal', () => {
const { onConfirm } = setup({ appDescription: 'Old description' })
fireEvent.change(screen.getByPlaceholderText('app.newApp.appDescriptionPlaceholder'), { target: { value: 'Updated description' } })
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
act(() => {
vi.advanceTimersByTime(300)
})
@@ -512,7 +512,7 @@ describe('CreateAppModal', () => {
appIconBackground: null,
})
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
act(() => {
vi.advanceTimersByTime(300)
})
@@ -536,7 +536,7 @@ describe('CreateAppModal', () => {
fireEvent.click(screen.getByRole('switch'))
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '12' } })
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ }))
act(() => {
vi.advanceTimersByTime(300)
})
@@ -551,7 +551,7 @@ describe('CreateAppModal', () => {
it('should omit max_active_requests when input is empty', () => {
const { onConfirm } = setup({ isEditModal: true, max_active_requests: null })
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ }))
act(() => {
vi.advanceTimersByTime(300)
})
@@ -564,7 +564,7 @@ describe('CreateAppModal', () => {
const { onConfirm } = setup({ isEditModal: true, max_active_requests: null })
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: 'abc' } })
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ }))
act(() => {
vi.advanceTimersByTime(300)
})
@@ -576,7 +576,7 @@ describe('CreateAppModal', () => {
it('should show toast error and not submit when name becomes empty before debounced submit runs', () => {
const { onConfirm, onHide } = setup({ appName: 'My App' })
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
fireEvent.change(screen.getByPlaceholderText('app.newApp.appNamePlaceholder'), { target: { value: ' ' } })
act(() => {

View File

@@ -16,6 +16,14 @@ vi.mock('react-i18next', () => ({
}),
}))
vi.mock('@/config', async (importOriginal) => {
const actual = await importOriginal() as object
return {
...actual,
IS_CLOUD_EDITION: true,
}
})
const mockUseGetTryAppInfo = vi.fn()
vi.mock('@/service/use-try-app', () => ({

View File

@@ -14,6 +14,14 @@ vi.mock('react-i18next', () => ({
}),
}))
vi.mock('@/config', async (importOriginal) => {
const actual = await importOriginal() as object
return {
...actual,
IS_CLOUD_EDITION: true,
}
})
describe('Tab', () => {
afterEach(() => {
cleanup()

View File

@@ -81,4 +81,205 @@ describe('CommandSelector', () => {
expect(onSelect).toHaveBeenCalledWith('/zen')
})
it('should show all slash commands when no filter provided', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="/"
/>
</Command>,
)
// Should show the zen command from mock
expect(screen.getByText('/zen')).toBeInTheDocument()
})
it('should exclude slash action when in @ mode', () => {
const actions = {
...createActions(),
slash: {
key: '/',
shortcut: '/',
title: 'Slash',
search: vi.fn(),
description: '',
} as ActionItem,
}
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="@"
/>
</Command>,
)
// Should show @ commands but not /
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.queryByText('/')).not.toBeInTheDocument()
})
it('should show all actions when no filter in @ mode', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="@"
/>
</Command>,
)
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.getByText('@plugin')).toBeInTheDocument()
})
it('should set default command value when items exist but value does not', () => {
const actions = createActions()
const onSelect = vi.fn()
const onCommandValueChange = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="@"
commandValue="non-existent"
onCommandValueChange={onCommandValueChange}
/>
</Command>,
)
expect(onCommandValueChange).toHaveBeenCalledWith('@app')
})
it('should NOT set command value when value already exists in items', () => {
const actions = createActions()
const onSelect = vi.fn()
const onCommandValueChange = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="@"
commandValue="@app"
onCommandValueChange={onCommandValueChange}
/>
</Command>,
)
expect(onCommandValueChange).not.toHaveBeenCalled()
})
it('should show no matching commands message when filter has no results', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter="nonexistent"
originalQuery="@nonexistent"
/>
</Command>,
)
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument()
})
it('should show no matching commands for slash mode with no results', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter="nonexistentcommand"
originalQuery="/nonexistentcommand"
/>
</Command>,
)
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
})
it('should render description for @ commands', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="@"
/>
</Command>,
)
expect(screen.getByText('app.gotoAnything.actions.searchApplicationsDesc')).toBeInTheDocument()
expect(screen.getByText('app.gotoAnything.actions.searchPluginsDesc')).toBeInTheDocument()
})
it('should render group header for @ mode', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="@"
/>
</Command>,
)
expect(screen.getByText('app.gotoAnything.selectSearchType')).toBeInTheDocument()
})
it('should render group header for slash mode', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="/"
/>
</Command>,
)
expect(screen.getByText('app.gotoAnything.groups.commands')).toBeInTheDocument()
})
})

View File

@@ -0,0 +1,157 @@
import { render, screen } from '@testing-library/react'
import EmptyState from './empty-state'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string, options?: { ns?: string, shortcuts?: string }) => {
if (options?.shortcuts !== undefined)
return `${key}:${options.shortcuts}`
return `${options?.ns || 'common'}.${key}`
},
}),
}))
describe('EmptyState', () => {
describe('loading variant', () => {
it('should render loading spinner', () => {
render(<EmptyState variant="loading" />)
expect(screen.getByText('app.gotoAnything.searching')).toBeInTheDocument()
})
it('should have spinner animation class', () => {
const { container } = render(<EmptyState variant="loading" />)
const spinner = container.querySelector('.animate-spin')
expect(spinner).toBeInTheDocument()
})
})
describe('error variant', () => {
it('should render error message when error has message', () => {
const error = new Error('Connection failed')
render(<EmptyState variant="error" error={error} />)
expect(screen.getByText('app.gotoAnything.searchFailed')).toBeInTheDocument()
expect(screen.getByText('Connection failed')).toBeInTheDocument()
})
it('should render generic error when error has no message', () => {
render(<EmptyState variant="error" error={null} />)
expect(screen.getByText('app.gotoAnything.searchTemporarilyUnavailable')).toBeInTheDocument()
expect(screen.getByText('app.gotoAnything.servicesUnavailableMessage')).toBeInTheDocument()
})
it('should render generic error when error is undefined', () => {
render(<EmptyState variant="error" />)
expect(screen.getByText('app.gotoAnything.searchTemporarilyUnavailable')).toBeInTheDocument()
})
it('should have red error text styling', () => {
const error = new Error('Test error')
const { container } = render(<EmptyState variant="error" error={error} />)
const errorText = container.querySelector('.text-red-500')
expect(errorText).toBeInTheDocument()
})
})
describe('default variant', () => {
it('should render search title', () => {
render(<EmptyState variant="default" />)
expect(screen.getByText('app.gotoAnything.searchTitle')).toBeInTheDocument()
})
it('should render all hint messages', () => {
render(<EmptyState variant="default" />)
expect(screen.getByText('app.gotoAnything.searchHint')).toBeInTheDocument()
expect(screen.getByText('app.gotoAnything.commandHint')).toBeInTheDocument()
expect(screen.getByText('app.gotoAnything.slashHint')).toBeInTheDocument()
})
})
describe('no-results variant', () => {
describe('general search mode', () => {
it('should render generic no results message', () => {
render(<EmptyState variant="no-results" searchMode="general" />)
expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument()
})
it('should show specific search hint with shortcuts', () => {
const Actions = {
app: { key: '@app', shortcut: '@app' },
plugin: { key: '@plugin', shortcut: '@plugin' },
} as unknown as Record<string, import('../actions/types').ActionItem>
render(<EmptyState variant="no-results" searchMode="general" Actions={Actions} />)
expect(screen.getByText('gotoAnything.emptyState.trySpecificSearch:@app, @plugin')).toBeInTheDocument()
})
})
describe('app search mode', () => {
it('should render no apps found message', () => {
render(<EmptyState variant="no-results" searchMode="@app" />)
expect(screen.getByText('app.gotoAnything.emptyState.noAppsFound')).toBeInTheDocument()
})
it('should show try different term hint', () => {
render(<EmptyState variant="no-results" searchMode="@app" />)
expect(screen.getByText('app.gotoAnything.emptyState.tryDifferentTerm')).toBeInTheDocument()
})
})
describe('plugin search mode', () => {
it('should render no plugins found message', () => {
render(<EmptyState variant="no-results" searchMode="@plugin" />)
expect(screen.getByText('app.gotoAnything.emptyState.noPluginsFound')).toBeInTheDocument()
})
})
describe('knowledge search mode', () => {
it('should render no knowledge bases found message', () => {
render(<EmptyState variant="no-results" searchMode="@knowledge" />)
expect(screen.getByText('app.gotoAnything.emptyState.noKnowledgeBasesFound')).toBeInTheDocument()
})
})
describe('node search mode', () => {
it('should render no workflow nodes found message', () => {
render(<EmptyState variant="no-results" searchMode="@node" />)
expect(screen.getByText('app.gotoAnything.emptyState.noWorkflowNodesFound')).toBeInTheDocument()
})
})
describe('unknown search mode', () => {
it('should fallback to generic no results message', () => {
render(<EmptyState variant="no-results" searchMode="@unknown" />)
expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument()
})
})
})
describe('default props', () => {
it('should use general as default searchMode', () => {
render(<EmptyState variant="no-results" />)
expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument()
})
it('should use empty object as default Actions', () => {
render(<EmptyState variant="no-results" searchMode="general" />)
// Should show empty shortcuts
expect(screen.getByText('gotoAnything.emptyState.trySpecificSearch:')).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,105 @@
'use client'
import type { FC } from 'react'
import type { ActionItem } from '../actions/types'
import { useTranslation } from 'react-i18next'
export type EmptyStateVariant = 'no-results' | 'error' | 'default' | 'loading'
export type EmptyStateProps = {
variant: EmptyStateVariant
searchMode?: string
error?: Error | null
Actions?: Record<string, ActionItem>
}
const EmptyState: FC<EmptyStateProps> = ({
variant,
searchMode = 'general',
error,
Actions = {},
}) => {
const { t } = useTranslation()
if (variant === 'loading') {
return (
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div className="flex items-center gap-2">
<div className="h-4 w-4 animate-spin rounded-full border-2 border-gray-300 border-t-gray-600"></div>
<span className="text-sm">{t('gotoAnything.searching', { ns: 'app' })}</span>
</div>
</div>
)
}
if (variant === 'error') {
return (
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div>
<div className="text-sm font-medium text-red-500">
{error?.message
? t('gotoAnything.searchFailed', { ns: 'app' })
: t('gotoAnything.searchTemporarilyUnavailable', { ns: 'app' })}
</div>
<div className="mt-1 text-xs text-text-quaternary">
{error?.message || t('gotoAnything.servicesUnavailableMessage', { ns: 'app' })}
</div>
</div>
</div>
)
}
if (variant === 'default') {
return (
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div>
<div className="text-sm font-medium">{t('gotoAnything.searchTitle', { ns: 'app' })}</div>
<div className="mt-3 space-y-1 text-xs text-text-quaternary">
<div>{t('gotoAnything.searchHint', { ns: 'app' })}</div>
<div>{t('gotoAnything.commandHint', { ns: 'app' })}</div>
<div>{t('gotoAnything.slashHint', { ns: 'app' })}</div>
</div>
</div>
</div>
)
}
// variant === 'no-results'
const isCommandSearch = searchMode !== 'general'
const commandType = isCommandSearch ? searchMode.replace('@', '') : ''
const getNoResultsMessage = () => {
if (!isCommandSearch) {
return t('gotoAnything.noResults', { ns: 'app' })
}
const keyMap = {
app: 'gotoAnything.emptyState.noAppsFound',
plugin: 'gotoAnything.emptyState.noPluginsFound',
knowledge: 'gotoAnything.emptyState.noKnowledgeBasesFound',
node: 'gotoAnything.emptyState.noWorkflowNodesFound',
} as const
return t(keyMap[commandType as keyof typeof keyMap] || 'gotoAnything.noResults', { ns: 'app' })
}
const getHintMessage = () => {
if (isCommandSearch) {
return t('gotoAnything.emptyState.tryDifferentTerm', { ns: 'app' })
}
const shortcuts = Object.values(Actions).map(action => action.shortcut).join(', ')
return t('gotoAnything.emptyState.trySpecificSearch', { ns: 'app', shortcuts })
}
return (
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div>
<div className="text-sm font-medium">{getNoResultsMessage()}</div>
<div className="mt-1 text-xs text-text-quaternary">{getHintMessage()}</div>
</div>
</div>
)
}
export default EmptyState

View File

@@ -0,0 +1,273 @@
import { render, screen } from '@testing-library/react'
import Footer from './footer'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string, options?: { ns?: string, count?: number, scope?: string }) => {
if (options?.count !== undefined)
return `${key}:${options.count}`
if (options?.scope)
return `${key}:${options.scope}`
return `${options?.ns || 'common'}.${key}`
},
}),
}))
describe('Footer', () => {
describe('left content', () => {
describe('when there are results', () => {
it('should show result count', () => {
render(
<Footer
resultCount={5}
searchMode="general"
isError={false}
isCommandsMode={false}
hasQuery={true}
/>,
)
expect(screen.getByText('gotoAnything.resultCount:5')).toBeInTheDocument()
})
it('should show scope when not in general mode', () => {
render(
<Footer
resultCount={3}
searchMode="@app"
isError={false}
isCommandsMode={false}
hasQuery={true}
/>,
)
expect(screen.getByText('gotoAnything.inScope:app')).toBeInTheDocument()
})
it('should NOT show scope when in general mode', () => {
render(
<Footer
resultCount={3}
searchMode="general"
isError={false}
isCommandsMode={false}
hasQuery={true}
/>,
)
expect(screen.queryByText(/inScope/)).not.toBeInTheDocument()
})
})
describe('when there is an error', () => {
it('should show error message', () => {
render(
<Footer
resultCount={0}
searchMode="general"
isError={true}
isCommandsMode={false}
hasQuery={true}
/>,
)
expect(screen.getByText('app.gotoAnything.someServicesUnavailable')).toBeInTheDocument()
})
it('should have red text styling', () => {
const { container } = render(
<Footer
resultCount={0}
searchMode="general"
isError={true}
isCommandsMode={false}
hasQuery={true}
/>,
)
const errorText = container.querySelector('.text-red-500')
expect(errorText).toBeInTheDocument()
})
it('should show error even with results', () => {
render(
<Footer
resultCount={5}
searchMode="general"
isError={true}
isCommandsMode={false}
hasQuery={true}
/>,
)
expect(screen.getByText('app.gotoAnything.someServicesUnavailable')).toBeInTheDocument()
})
})
describe('when no results and no error', () => {
it('should show select to navigate in commands mode', () => {
render(
<Footer
resultCount={0}
searchMode="general"
isError={false}
isCommandsMode={true}
hasQuery={false}
/>,
)
expect(screen.getByText('app.gotoAnything.selectToNavigate')).toBeInTheDocument()
})
it('should show searching when has query', () => {
render(
<Footer
resultCount={0}
searchMode="general"
isError={false}
isCommandsMode={false}
hasQuery={true}
/>,
)
expect(screen.getByText('app.gotoAnything.searching')).toBeInTheDocument()
})
it('should show start typing when no query', () => {
render(
<Footer
resultCount={0}
searchMode="general"
isError={false}
isCommandsMode={false}
hasQuery={false}
/>,
)
expect(screen.getByText('app.gotoAnything.startTyping')).toBeInTheDocument()
})
})
})
describe('right content', () => {
describe('when there are results or error', () => {
it('should show clear to search all when in specific mode', () => {
render(
<Footer
resultCount={5}
searchMode="@app"
isError={false}
isCommandsMode={false}
hasQuery={true}
/>,
)
expect(screen.getByText('app.gotoAnything.clearToSearchAll')).toBeInTheDocument()
})
it('should show use @ for specific when in general mode', () => {
render(
<Footer
resultCount={5}
searchMode="general"
isError={false}
isCommandsMode={false}
hasQuery={true}
/>,
)
expect(screen.getByText('app.gotoAnything.useAtForSpecific')).toBeInTheDocument()
})
it('should show same hint when error', () => {
render(
<Footer
resultCount={0}
searchMode="general"
isError={true}
isCommandsMode={false}
hasQuery={true}
/>,
)
expect(screen.getByText('app.gotoAnything.useAtForSpecific')).toBeInTheDocument()
})
})
describe('when no results and no error', () => {
it('should show tips when has query', () => {
render(
<Footer
resultCount={0}
searchMode="general"
isError={false}
isCommandsMode={false}
hasQuery={true}
/>,
)
expect(screen.getByText('app.gotoAnything.tips')).toBeInTheDocument()
})
it('should show tips when in commands mode', () => {
render(
<Footer
resultCount={0}
searchMode="general"
isError={false}
isCommandsMode={true}
hasQuery={false}
/>,
)
expect(screen.getByText('app.gotoAnything.tips')).toBeInTheDocument()
})
it('should show press ESC to close when no query and not in commands mode', () => {
render(
<Footer
resultCount={0}
searchMode="general"
isError={false}
isCommandsMode={false}
hasQuery={false}
/>,
)
expect(screen.getByText('app.gotoAnything.pressEscToClose')).toBeInTheDocument()
})
})
})
describe('styling', () => {
it('should have border and background classes', () => {
const { container } = render(
<Footer
resultCount={0}
searchMode="general"
isError={false}
isCommandsMode={false}
hasQuery={false}
/>,
)
const footer = container.firstChild
expect(footer).toHaveClass('border-t', 'border-divider-subtle', 'bg-components-panel-bg-blur')
})
it('should have flex layout for content', () => {
const { container } = render(
<Footer
resultCount={0}
searchMode="general"
isError={false}
isCommandsMode={false}
hasQuery={false}
/>,
)
const flexContainer = container.querySelector('.flex.items-center.justify-between')
expect(flexContainer).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,90 @@
'use client'
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
export type FooterProps = {
resultCount: number
searchMode: string
isError: boolean
isCommandsMode: boolean
hasQuery: boolean
}
const Footer: FC<FooterProps> = ({
resultCount,
searchMode,
isError,
isCommandsMode,
hasQuery,
}) => {
const { t } = useTranslation()
const renderLeftContent = () => {
if (resultCount > 0 || isError) {
if (isError) {
return (
<span className="text-red-500">
{t('gotoAnything.someServicesUnavailable', { ns: 'app' })}
</span>
)
}
return (
<>
{t('gotoAnything.resultCount', { ns: 'app', count: resultCount })}
{searchMode !== 'general' && (
<span className="ml-2 opacity-60">
{t('gotoAnything.inScope', { ns: 'app', scope: searchMode.replace('@', '') })}
</span>
)}
</>
)
}
return (
<span className="opacity-60">
{(() => {
if (isCommandsMode)
return t('gotoAnything.selectToNavigate', { ns: 'app' })
if (hasQuery)
return t('gotoAnything.searching', { ns: 'app' })
return t('gotoAnything.startTyping', { ns: 'app' })
})()}
</span>
)
}
const renderRightContent = () => {
if (resultCount > 0 || isError) {
return (
<span className="opacity-60">
{searchMode !== 'general'
? t('gotoAnything.clearToSearchAll', { ns: 'app' })
: t('gotoAnything.useAtForSpecific', { ns: 'app' })}
</span>
)
}
return (
<span className="opacity-60">
{hasQuery || isCommandsMode
? t('gotoAnything.tips', { ns: 'app' })
: t('gotoAnything.pressEscToClose', { ns: 'app' })}
</span>
)
}
return (
<div className="border-t border-divider-subtle bg-components-panel-bg-blur px-4 py-2 text-xs text-text-tertiary">
<div className="flex min-h-[16px] items-center justify-between">
<span>{renderLeftContent()}</span>
{renderRightContent()}
</div>
</div>
)
}
export default Footer

View File

@@ -0,0 +1,14 @@
export { default as EmptyState } from './empty-state'
export type { EmptyStateProps, EmptyStateVariant } from './empty-state'
export { default as Footer } from './footer'
export type { FooterProps } from './footer'
export { default as ResultItem } from './result-item'
export type { ResultItemProps } from './result-item'
export { default as ResultList } from './result-list'
export type { ResultListProps } from './result-list'
export { default as SearchInput } from './search-input'
export type { SearchInputProps } from './search-input'

View File

@@ -0,0 +1,38 @@
'use client'
import type { FC } from 'react'
import type { SearchResult } from '../actions/types'
import { Command } from 'cmdk'
export type ResultItemProps = {
result: SearchResult
onSelect: () => void
}
const ResultItem: FC<ResultItemProps> = ({ result, onSelect }) => {
return (
<Command.Item
key={`${result.type}-${result.id}`}
value={`${result.type}-${result.id}`}
className="flex cursor-pointer items-center gap-3 rounded-md p-3 will-change-[background-color] hover:bg-state-base-hover aria-[selected=true]:bg-state-base-hover-alt data-[selected=true]:bg-state-base-hover-alt"
onSelect={onSelect}
>
{result.icon}
<div className="min-w-0 flex-1">
<div className="truncate font-medium text-text-secondary">
{result.title}
</div>
{result.description && (
<div className="mt-0.5 truncate text-xs text-text-quaternary">
{result.description}
</div>
)}
</div>
<div className="text-xs capitalize text-text-quaternary">
{result.type}
</div>
</Command.Item>
)
}
export default ResultItem

View File

@@ -0,0 +1,49 @@
'use client'
import type { FC } from 'react'
import type { SearchResult } from '../actions/types'
import { Command } from 'cmdk'
import { useTranslation } from 'react-i18next'
import ResultItem from './result-item'
export type ResultListProps = {
groupedResults: Record<string, SearchResult[]>
onSelect: (result: SearchResult) => void
}
const ResultList: FC<ResultListProps> = ({ groupedResults, onSelect }) => {
const { t } = useTranslation()
const getGroupHeading = (type: string) => {
const typeMap = {
'app': 'gotoAnything.groups.apps',
'plugin': 'gotoAnything.groups.plugins',
'knowledge': 'gotoAnything.groups.knowledgeBases',
'workflow-node': 'gotoAnything.groups.workflowNodes',
'command': 'gotoAnything.groups.commands',
} as const
return t(typeMap[type as keyof typeof typeMap] || `${type}s`, { ns: 'app' })
}
return (
<>
{Object.entries(groupedResults).map(([type, results]) => (
<Command.Group
key={type}
heading={getGroupHeading(type)}
className="p-2 capitalize text-text-secondary"
>
{results.map(result => (
<ResultItem
key={`${result.type}-${result.id}`}
result={result}
onSelect={() => onSelect(result)}
/>
))}
</Command.Group>
))}
</>
)
}
export default ResultList

View File

@@ -0,0 +1,206 @@
import type { ChangeEvent, KeyboardEvent, RefObject } from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import SearchInput from './search-input'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string, options?: { ns?: string }) => `${options?.ns || 'common'}.${key}`,
}),
}))
vi.mock('@remixicon/react', () => ({
RiSearchLine: ({ className }: { className?: string }) => (
<svg data-testid="search-icon" className={className} />
),
}))
vi.mock('@/app/components/workflow/shortcuts-name', () => ({
default: ({ keys, textColor }: { keys: string[], textColor: string }) => (
<div data-testid="shortcuts-name" data-keys={keys.join(',')} data-color={textColor}>
{keys.join('+')}
</div>
),
}))
vi.mock('@/app/components/base/input', async () => {
const { forwardRef } = await import('react')
type MockInputProps = {
value?: string
placeholder?: string
onChange?: (e: ChangeEvent<HTMLInputElement>) => void
onKeyDown?: (e: KeyboardEvent<HTMLInputElement>) => void
className?: string
wrapperClassName?: string
autoFocus?: boolean
}
const MockInput = forwardRef<HTMLInputElement, MockInputProps>(
({ value, placeholder, onChange, onKeyDown, className, wrapperClassName, autoFocus }, ref) => (
<input
ref={ref}
value={value}
placeholder={placeholder}
onChange={onChange}
onKeyDown={onKeyDown}
className={className}
data-wrapper-class={wrapperClassName}
autoFocus={autoFocus}
data-testid="search-input"
/>
),
)
MockInput.displayName = 'MockInput'
return { default: MockInput }
})
describe('SearchInput', () => {
const defaultProps = {
inputRef: { current: null } as RefObject<HTMLInputElement | null>,
value: '',
onChange: vi.fn(),
searchMode: 'general',
}
beforeEach(() => {
vi.clearAllMocks()
})
describe('rendering', () => {
it('should render search icon', () => {
render(<SearchInput {...defaultProps} />)
expect(screen.getByTestId('search-icon')).toBeInTheDocument()
})
it('should render input field', () => {
render(<SearchInput {...defaultProps} />)
expect(screen.getByTestId('search-input')).toBeInTheDocument()
})
it('should render shortcuts name', () => {
render(<SearchInput {...defaultProps} />)
const shortcuts = screen.getByTestId('shortcuts-name')
expect(shortcuts).toBeInTheDocument()
expect(shortcuts).toHaveAttribute('data-keys', 'ctrl,K')
expect(shortcuts).toHaveAttribute('data-color', 'secondary')
})
it('should use provided placeholder', () => {
render(<SearchInput {...defaultProps} placeholder="Custom placeholder" />)
expect(screen.getByPlaceholderText('Custom placeholder')).toBeInTheDocument()
})
it('should use default placeholder from translation', () => {
render(<SearchInput {...defaultProps} />)
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
})
describe('mode label', () => {
it('should NOT show mode badge in general mode', () => {
render(<SearchInput {...defaultProps} searchMode="general" />)
expect(screen.queryByText('GENERAL')).not.toBeInTheDocument()
})
it('should show SCOPES label in scopes mode', () => {
render(<SearchInput {...defaultProps} searchMode="scopes" />)
expect(screen.getByText('SCOPES')).toBeInTheDocument()
})
it('should show COMMANDS label in commands mode', () => {
render(<SearchInput {...defaultProps} searchMode="commands" />)
expect(screen.getByText('COMMANDS')).toBeInTheDocument()
})
it('should show APP label in @app mode', () => {
render(<SearchInput {...defaultProps} searchMode="@app" />)
expect(screen.getByText('APP')).toBeInTheDocument()
})
it('should show PLUGIN label in @plugin mode', () => {
render(<SearchInput {...defaultProps} searchMode="@plugin" />)
expect(screen.getByText('PLUGIN')).toBeInTheDocument()
})
it('should show KNOWLEDGE label in @knowledge mode', () => {
render(<SearchInput {...defaultProps} searchMode="@knowledge" />)
expect(screen.getByText('KNOWLEDGE')).toBeInTheDocument()
})
it('should show NODE label in @node mode', () => {
render(<SearchInput {...defaultProps} searchMode="@node" />)
expect(screen.getByText('NODE')).toBeInTheDocument()
})
it('should uppercase custom mode label', () => {
render(<SearchInput {...defaultProps} searchMode="@custom" />)
expect(screen.getByText('CUSTOM')).toBeInTheDocument()
})
})
describe('input interactions', () => {
it('should call onChange when typing', () => {
const onChange = vi.fn()
render(<SearchInput {...defaultProps} onChange={onChange} />)
const input = screen.getByTestId('search-input')
fireEvent.change(input, { target: { value: 'test query' } })
expect(onChange).toHaveBeenCalledWith('test query')
})
it('should call onKeyDown when pressing keys', () => {
const onKeyDown = vi.fn()
render(<SearchInput {...defaultProps} onKeyDown={onKeyDown} />)
const input = screen.getByTestId('search-input')
fireEvent.keyDown(input, { key: 'Enter' })
expect(onKeyDown).toHaveBeenCalled()
})
it('should render with provided value', () => {
render(<SearchInput {...defaultProps} value="existing query" />)
expect(screen.getByDisplayValue('existing query')).toBeInTheDocument()
})
it('should NOT throw when onKeyDown is undefined', () => {
render(<SearchInput {...defaultProps} onKeyDown={undefined} />)
const input = screen.getByTestId('search-input')
expect(() => fireEvent.keyDown(input, { key: 'Enter' })).not.toThrow()
})
})
describe('styling', () => {
it('should have search icon styling', () => {
render(<SearchInput {...defaultProps} />)
const icon = screen.getByTestId('search-icon')
expect(icon).toHaveClass('h-4', 'w-4', 'text-text-quaternary')
})
it('should have mode badge styling when visible', () => {
const { container } = render(<SearchInput {...defaultProps} searchMode="@app" />)
const badge = container.querySelector('.bg-gray-100')
expect(badge).toBeInTheDocument()
expect(badge).toHaveClass('rounded', 'px-2', 'text-xs', 'font-medium')
})
})
})

View File

@@ -0,0 +1,62 @@
'use client'
import type { FC, KeyboardEvent, RefObject } from 'react'
import { RiSearchLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import Input from '@/app/components/base/input'
import ShortcutsName from '@/app/components/workflow/shortcuts-name'
export type SearchInputProps = {
inputRef: RefObject<HTMLInputElement | null>
value: string
onChange: (value: string) => void
onKeyDown?: (e: KeyboardEvent<HTMLInputElement>) => void
searchMode: string
placeholder?: string
}
const SearchInput: FC<SearchInputProps> = ({
inputRef,
value,
onChange,
onKeyDown,
searchMode,
placeholder,
}) => {
const { t } = useTranslation()
const getModeLabel = () => {
if (searchMode === 'scopes')
return 'SCOPES'
else if (searchMode === 'commands')
return 'COMMANDS'
else
return searchMode.replace('@', '').toUpperCase()
}
return (
<div className="flex items-center gap-3 border-b border-divider-subtle bg-components-panel-bg-blur px-4 py-3">
<RiSearchLine className="h-4 w-4 text-text-quaternary" />
<div className="flex flex-1 items-center gap-2">
<Input
ref={inputRef}
value={value}
placeholder={placeholder || t('gotoAnything.searchPlaceholder', { ns: 'app' })}
onChange={e => onChange(e.target.value)}
onKeyDown={onKeyDown}
className="flex-1 !border-0 !bg-transparent !shadow-none"
wrapperClassName="flex-1 !border-0 !bg-transparent"
autoFocus
/>
{searchMode !== 'general' && (
<div className="flex items-center gap-1 rounded bg-gray-100 px-2 py-[2px] text-xs font-medium text-gray-700 dark:bg-gray-800 dark:text-gray-300">
<span>{getModeLabel()}</span>
</div>
)}
</div>
<ShortcutsName keys={['ctrl', 'K']} textColor="secondary" />
</div>
)
}
export default SearchInput

View File

@@ -2,7 +2,7 @@ import { render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { GotoAnythingProvider, useGotoAnythingContext } from './context'
let pathnameMock = '/'
let pathnameMock: string | null | undefined = '/'
vi.mock('next/navigation', () => ({
usePathname: () => pathnameMock,
}))
@@ -57,4 +57,79 @@ describe('GotoAnythingProvider', () => {
expect(screen.getByTestId('status')).toHaveTextContent('false|true')
})
})
it('should set both flags to false when pathname is null', async () => {
pathnameMock = null
render(
<GotoAnythingProvider>
<ContextConsumer />
</GotoAnythingProvider>,
)
await waitFor(() => {
expect(screen.getByTestId('status')).toHaveTextContent('false|false')
})
})
it('should set both flags to false when pathname is undefined', async () => {
pathnameMock = undefined
render(
<GotoAnythingProvider>
<ContextConsumer />
</GotoAnythingProvider>,
)
await waitFor(() => {
expect(screen.getByTestId('status')).toHaveTextContent('false|false')
})
})
it('should set both flags to false for regular paths', async () => {
pathnameMock = '/apps'
render(
<GotoAnythingProvider>
<ContextConsumer />
</GotoAnythingProvider>,
)
await waitFor(() => {
expect(screen.getByTestId('status')).toHaveTextContent('false|false')
})
})
it('should NOT match non-pipeline dataset paths', async () => {
pathnameMock = '/datasets/abc/documents'
render(
<GotoAnythingProvider>
<ContextConsumer />
</GotoAnythingProvider>,
)
await waitFor(() => {
expect(screen.getByTestId('status')).toHaveTextContent('false|false')
})
})
})
describe('useGotoAnythingContext', () => {
it('should return default values when used outside provider', () => {
const TestComponent = () => {
const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext()
return (
<div data-testid="context">
{String(isWorkflowPage)}
|
{String(isRagPipelinePage)}
</div>
)
}
render(<TestComponent />)
expect(screen.getByTestId('context')).toHaveTextContent('false|false')
})
})

View File

@@ -0,0 +1,11 @@
export { useGotoAnythingModal } from './use-goto-anything-modal'
export type { UseGotoAnythingModalReturn } from './use-goto-anything-modal'
export { useGotoAnythingNavigation } from './use-goto-anything-navigation'
export type { UseGotoAnythingNavigationOptions, UseGotoAnythingNavigationReturn } from './use-goto-anything-navigation'
export { useGotoAnythingResults } from './use-goto-anything-results'
export type { UseGotoAnythingResultsOptions, UseGotoAnythingResultsReturn } from './use-goto-anything-results'
export { useGotoAnythingSearch } from './use-goto-anything-search'
export type { UseGotoAnythingSearchReturn } from './use-goto-anything-search'

View File

@@ -0,0 +1,291 @@
import { act, renderHook } from '@testing-library/react'
import { useGotoAnythingModal } from './use-goto-anything-modal'
type KeyPressEvent = {
preventDefault: () => void
target?: EventTarget
}
const keyPressHandlers: Record<string, (event: KeyPressEvent) => void> = {}
let mockIsEventTargetInputArea = false
vi.mock('ahooks', () => ({
useKeyPress: (keys: string | string[], handler: (event: KeyPressEvent) => void) => {
const keyList = Array.isArray(keys) ? keys : [keys]
keyList.forEach((key) => {
keyPressHandlers[key] = handler
})
},
}))
vi.mock('@/app/components/workflow/utils/common', () => ({
getKeyboardKeyCodeBySystem: () => 'ctrl',
isEventTargetInputArea: () => mockIsEventTargetInputArea,
}))
describe('useGotoAnythingModal', () => {
beforeEach(() => {
Object.keys(keyPressHandlers).forEach(key => delete keyPressHandlers[key])
mockIsEventTargetInputArea = false
vi.useFakeTimers()
})
afterEach(() => {
vi.useRealTimers()
})
describe('initialization', () => {
it('should initialize with show=false', () => {
const { result } = renderHook(() => useGotoAnythingModal())
expect(result.current.show).toBe(false)
})
it('should provide inputRef initialized to null', () => {
const { result } = renderHook(() => useGotoAnythingModal())
expect(result.current.inputRef).toBeDefined()
expect(result.current.inputRef.current).toBe(null)
})
it('should provide setShow function', () => {
const { result } = renderHook(() => useGotoAnythingModal())
expect(typeof result.current.setShow).toBe('function')
})
it('should provide handleClose function', () => {
const { result } = renderHook(() => useGotoAnythingModal())
expect(typeof result.current.handleClose).toBe('function')
})
})
describe('keyboard shortcuts', () => {
it('should toggle show state when Ctrl+K is triggered', () => {
const { result } = renderHook(() => useGotoAnythingModal())
expect(result.current.show).toBe(false)
act(() => {
keyPressHandlers['ctrl.k']?.({ preventDefault: vi.fn(), target: document.body })
})
expect(result.current.show).toBe(true)
})
it('should toggle back to closed when Ctrl+K is triggered twice', () => {
const { result } = renderHook(() => useGotoAnythingModal())
act(() => {
keyPressHandlers['ctrl.k']?.({ preventDefault: vi.fn(), target: document.body })
})
expect(result.current.show).toBe(true)
act(() => {
keyPressHandlers['ctrl.k']?.({ preventDefault: vi.fn(), target: document.body })
})
expect(result.current.show).toBe(false)
})
it('should NOT toggle when focus is in input area and modal is closed', () => {
mockIsEventTargetInputArea = true
const { result } = renderHook(() => useGotoAnythingModal())
expect(result.current.show).toBe(false)
act(() => {
keyPressHandlers['ctrl.k']?.({ preventDefault: vi.fn(), target: document.body })
})
// Should remain closed because focus is in input area
expect(result.current.show).toBe(false)
})
it('should close modal when escape is pressed and modal is open', () => {
const { result } = renderHook(() => useGotoAnythingModal())
// Open modal first
act(() => {
result.current.setShow(true)
})
expect(result.current.show).toBe(true)
// Press escape
act(() => {
keyPressHandlers.esc?.({ preventDefault: vi.fn() })
})
expect(result.current.show).toBe(false)
})
it('should NOT do anything when escape is pressed and modal is already closed', () => {
const { result } = renderHook(() => useGotoAnythingModal())
expect(result.current.show).toBe(false)
const preventDefaultMock = vi.fn()
act(() => {
keyPressHandlers.esc?.({ preventDefault: preventDefaultMock })
})
// Should remain closed, and preventDefault should not be called
expect(result.current.show).toBe(false)
expect(preventDefaultMock).not.toHaveBeenCalled()
})
it('should call preventDefault when Ctrl+K is triggered', () => {
renderHook(() => useGotoAnythingModal())
const preventDefaultMock = vi.fn()
act(() => {
keyPressHandlers['ctrl.k']?.({ preventDefault: preventDefaultMock, target: document.body })
})
expect(preventDefaultMock).toHaveBeenCalled()
})
})
describe('handleClose', () => {
it('should close modal when handleClose is called', () => {
const { result } = renderHook(() => useGotoAnythingModal())
// Open modal first
act(() => {
result.current.setShow(true)
})
expect(result.current.show).toBe(true)
// Close via handleClose
act(() => {
result.current.handleClose()
})
expect(result.current.show).toBe(false)
})
it('should be safe to call handleClose when modal is already closed', () => {
const { result } = renderHook(() => useGotoAnythingModal())
expect(result.current.show).toBe(false)
act(() => {
result.current.handleClose()
})
expect(result.current.show).toBe(false)
})
})
describe('setShow', () => {
it('should accept boolean value', () => {
const { result } = renderHook(() => useGotoAnythingModal())
act(() => {
result.current.setShow(true)
})
expect(result.current.show).toBe(true)
act(() => {
result.current.setShow(false)
})
expect(result.current.show).toBe(false)
})
it('should accept function value', () => {
const { result } = renderHook(() => useGotoAnythingModal())
act(() => {
result.current.setShow(prev => !prev)
})
expect(result.current.show).toBe(true)
act(() => {
result.current.setShow(prev => !prev)
})
expect(result.current.show).toBe(false)
})
})
describe('focus management', () => {
it('should call requestAnimationFrame when modal opens', () => {
const rafSpy = vi.spyOn(window, 'requestAnimationFrame')
const { result } = renderHook(() => useGotoAnythingModal())
act(() => {
result.current.setShow(true)
})
expect(rafSpy).toHaveBeenCalled()
rafSpy.mockRestore()
})
it('should not call requestAnimationFrame when modal closes', () => {
const { result } = renderHook(() => useGotoAnythingModal())
// First open
act(() => {
result.current.setShow(true)
})
const rafSpy = vi.spyOn(window, 'requestAnimationFrame')
// Then close
act(() => {
result.current.setShow(false)
})
expect(rafSpy).not.toHaveBeenCalled()
rafSpy.mockRestore()
})
it('should focus input when modal opens and inputRef.current exists', () => {
// Mock requestAnimationFrame to execute callback immediately
const originalRAF = window.requestAnimationFrame
window.requestAnimationFrame = (callback: FrameRequestCallback) => {
callback(0)
return 0
}
const { result } = renderHook(() => useGotoAnythingModal())
// Create a mock input element with focus method
const mockFocus = vi.fn()
const mockInput = { focus: mockFocus } as unknown as HTMLInputElement
// Manually set the inputRef
Object.defineProperty(result.current.inputRef, 'current', {
value: mockInput,
writable: true,
})
act(() => {
result.current.setShow(true)
})
expect(mockFocus).toHaveBeenCalled()
// Restore original requestAnimationFrame
window.requestAnimationFrame = originalRAF
})
it('should not throw when inputRef.current is null when modal opens', () => {
// Mock requestAnimationFrame to execute callback immediately
const originalRAF = window.requestAnimationFrame
window.requestAnimationFrame = (callback: FrameRequestCallback) => {
callback(0)
return 0
}
const { result } = renderHook(() => useGotoAnythingModal())
// inputRef.current is already null by default
// Should not throw
act(() => {
result.current.setShow(true)
})
expect(result.current.show).toBe(true)
// Restore original requestAnimationFrame
window.requestAnimationFrame = originalRAF
})
})
})

View File

@@ -0,0 +1,59 @@
'use client'
import type { RefObject } from 'react'
import { useKeyPress } from 'ahooks'
import { useCallback, useEffect, useRef, useState } from 'react'
import { getKeyboardKeyCodeBySystem, isEventTargetInputArea } from '@/app/components/workflow/utils/common'
export type UseGotoAnythingModalReturn = {
show: boolean
setShow: (show: boolean | ((prev: boolean) => boolean)) => void
inputRef: RefObject<HTMLInputElement | null>
handleClose: () => void
}
export const useGotoAnythingModal = (): UseGotoAnythingModalReturn => {
const [show, setShow] = useState<boolean>(false)
const inputRef = useRef<HTMLInputElement>(null)
// Handle keyboard shortcuts
const handleToggleModal = useCallback((e: KeyboardEvent) => {
// Allow closing when modal is open, even if focus is in the search input
if (!show && isEventTargetInputArea(e.target as HTMLElement))
return
e.preventDefault()
setShow(prev => !prev)
}, [show])
useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.k`, handleToggleModal, {
exactMatch: true,
useCapture: true,
})
useKeyPress(['esc'], (e) => {
if (show) {
e.preventDefault()
setShow(false)
}
})
const handleClose = useCallback(() => {
setShow(false)
}, [])
// Focus input when modal opens
useEffect(() => {
if (show) {
requestAnimationFrame(() => {
inputRef.current?.focus()
})
}
}, [show])
return {
show,
setShow,
inputRef,
handleClose,
}
}

View File

@@ -0,0 +1,391 @@
import type * as React from 'react'
import type { Plugin } from '../../plugins/types'
import type { CommonNodeType } from '../../workflow/types'
import type { DataSet } from '@/models/datasets'
import type { App } from '@/types/app'
import { act, renderHook } from '@testing-library/react'
import { useGotoAnythingNavigation } from './use-goto-anything-navigation'
const mockRouterPush = vi.fn()
const mockSelectWorkflowNode = vi.fn()
type MockCommandResult = {
mode: string
execute?: () => void
} | null
let mockFindCommandResult: MockCommandResult = null
vi.mock('next/navigation', () => ({
useRouter: () => ({
push: mockRouterPush,
}),
}))
vi.mock('@/app/components/workflow/utils/node-navigation', () => ({
selectWorkflowNode: (...args: unknown[]) => mockSelectWorkflowNode(...args),
}))
vi.mock('../actions/commands/registry', () => ({
slashCommandRegistry: {
findCommand: () => mockFindCommandResult,
},
}))
const createMockActionItem = (
key: '@app' | '@knowledge' | '@plugin' | '@node' | '/',
extra: Record<string, unknown> = {},
) => ({
key,
shortcut: key,
title: `${key} title`,
description: `${key} description`,
search: vi.fn().mockResolvedValue([]),
...extra,
})
const createMockOptions = (overrides = {}) => ({
Actions: {
slash: createMockActionItem('/', { action: vi.fn() }),
app: createMockActionItem('@app'),
},
setSearchQuery: vi.fn(),
clearSelection: vi.fn(),
inputRef: { current: { focus: vi.fn() } } as unknown as React.RefObject<HTMLInputElement>,
onClose: vi.fn(),
...overrides,
})
describe('useGotoAnythingNavigation', () => {
beforeEach(() => {
vi.clearAllMocks()
mockFindCommandResult = null
vi.useFakeTimers()
})
afterEach(() => {
vi.useRealTimers()
})
describe('initialization', () => {
it('should return handleCommandSelect function', () => {
const { result } = renderHook(() => useGotoAnythingNavigation(createMockOptions()))
expect(typeof result.current.handleCommandSelect).toBe('function')
})
it('should return handleNavigate function', () => {
const { result } = renderHook(() => useGotoAnythingNavigation(createMockOptions()))
expect(typeof result.current.handleNavigate).toBe('function')
})
it('should initialize activePlugin as undefined', () => {
const { result } = renderHook(() => useGotoAnythingNavigation(createMockOptions()))
expect(result.current.activePlugin).toBeUndefined()
})
it('should return setActivePlugin function', () => {
const { result } = renderHook(() => useGotoAnythingNavigation(createMockOptions()))
expect(typeof result.current.setActivePlugin).toBe('function')
})
})
describe('handleCommandSelect', () => {
it('should execute direct mode slash command immediately', () => {
const execute = vi.fn()
mockFindCommandResult = { mode: 'direct', execute }
const options = createMockOptions()
const { result } = renderHook(() => useGotoAnythingNavigation(options))
act(() => {
result.current.handleCommandSelect('/theme')
})
expect(execute).toHaveBeenCalled()
expect(options.onClose).toHaveBeenCalled()
expect(options.setSearchQuery).toHaveBeenCalledWith('')
})
it('should NOT execute when handler has no execute function', () => {
mockFindCommandResult = { mode: 'direct', execute: undefined }
const options = createMockOptions()
const { result } = renderHook(() => useGotoAnythingNavigation(options))
act(() => {
result.current.handleCommandSelect('/theme')
})
expect(options.onClose).not.toHaveBeenCalled()
// Should proceed with submenu mode
expect(options.setSearchQuery).toHaveBeenCalledWith('/theme ')
})
it('should proceed with submenu mode for non-direct commands', () => {
mockFindCommandResult = { mode: 'submenu' }
const options = createMockOptions()
const { result } = renderHook(() => useGotoAnythingNavigation(options))
act(() => {
result.current.handleCommandSelect('/language')
})
expect(options.setSearchQuery).toHaveBeenCalledWith('/language ')
expect(options.clearSelection).toHaveBeenCalled()
})
it('should handle @ commands (scopes)', () => {
const options = createMockOptions()
const { result } = renderHook(() => useGotoAnythingNavigation(options))
act(() => {
result.current.handleCommandSelect('@app')
})
expect(options.setSearchQuery).toHaveBeenCalledWith('@app ')
expect(options.clearSelection).toHaveBeenCalled()
})
it('should focus input after setting search query', () => {
const focusMock = vi.fn()
const options = createMockOptions({
inputRef: { current: { focus: focusMock } },
})
const { result } = renderHook(() => useGotoAnythingNavigation(options))
act(() => {
result.current.handleCommandSelect('@app')
})
act(() => {
vi.runAllTimers()
})
expect(focusMock).toHaveBeenCalled()
})
it('should handle null handler from registry', () => {
mockFindCommandResult = null
const options = createMockOptions()
const { result } = renderHook(() => useGotoAnythingNavigation(options))
act(() => {
result.current.handleCommandSelect('/unknown')
})
// Should proceed with submenu mode
expect(options.setSearchQuery).toHaveBeenCalledWith('/unknown ')
})
})
describe('handleNavigate', () => {
it('should navigate to path for default result types', () => {
const options = createMockOptions()
const { result } = renderHook(() => useGotoAnythingNavigation(options))
act(() => {
result.current.handleNavigate({
id: '1',
type: 'app' as const,
title: 'My App',
path: '/apps/1',
data: { id: '1', name: 'My App' } as unknown as App,
})
})
expect(options.onClose).toHaveBeenCalled()
expect(options.setSearchQuery).toHaveBeenCalledWith('')
expect(mockRouterPush).toHaveBeenCalledWith('/apps/1')
})
it('should NOT call router.push when path is empty', () => {
const options = createMockOptions()
const { result } = renderHook(() => useGotoAnythingNavigation(options))
act(() => {
result.current.handleNavigate({
id: '1',
type: 'app' as const,
title: 'My App',
path: '',
data: { id: '1', name: 'My App' } as unknown as App,
})
})
expect(mockRouterPush).not.toHaveBeenCalled()
})
it('should execute slash command action for command type', () => {
const actionMock = vi.fn()
const options = createMockOptions({
Actions: {
slash: { key: '/', shortcut: '/', action: actionMock },
},
})
const { result } = renderHook(() => useGotoAnythingNavigation(options))
const commandResult = {
id: 'cmd-1',
type: 'command' as const,
title: 'Theme Dark',
data: { command: 'theme.set', args: { theme: 'dark' } },
}
act(() => {
result.current.handleNavigate(commandResult)
})
expect(actionMock).toHaveBeenCalledWith(commandResult)
})
it('should set activePlugin for plugin type', () => {
const options = createMockOptions()
const pluginData = { name: 'My Plugin', latest_package_identifier: 'pkg' } as unknown as Plugin
const { result } = renderHook(() => useGotoAnythingNavigation(options))
act(() => {
result.current.handleNavigate({
id: 'plugin-1',
type: 'plugin' as const,
title: 'My Plugin',
data: pluginData,
})
})
expect(result.current.activePlugin).toEqual(pluginData)
})
it('should select workflow node for workflow-node type', () => {
const options = createMockOptions()
const { result } = renderHook(() => useGotoAnythingNavigation(options))
act(() => {
result.current.handleNavigate({
id: 'node-1',
type: 'workflow-node' as const,
title: 'Start Node',
metadata: { nodeId: 'node-123', nodeData: {} as CommonNodeType },
data: { id: 'node-1' } as unknown as CommonNodeType,
})
})
expect(mockSelectWorkflowNode).toHaveBeenCalledWith('node-123', true)
})
it('should NOT select workflow node when metadata.nodeId is missing', () => {
const options = createMockOptions()
const { result } = renderHook(() => useGotoAnythingNavigation(options))
act(() => {
result.current.handleNavigate({
id: 'node-1',
type: 'workflow-node' as const,
title: 'Start Node',
metadata: undefined,
data: { id: 'node-1' } as unknown as CommonNodeType,
})
})
expect(mockSelectWorkflowNode).not.toHaveBeenCalled()
})
it('should handle knowledge type (default case with path)', () => {
const options = createMockOptions()
const { result } = renderHook(() => useGotoAnythingNavigation(options))
act(() => {
result.current.handleNavigate({
id: 'kb-1',
type: 'knowledge' as const,
title: 'My Knowledge Base',
path: '/datasets/kb-1',
data: { id: 'kb-1', name: 'My Knowledge Base' } as unknown as DataSet,
})
})
expect(mockRouterPush).toHaveBeenCalledWith('/datasets/kb-1')
})
})
describe('setActivePlugin', () => {
it('should update activePlugin state', () => {
const { result } = renderHook(() => useGotoAnythingNavigation(createMockOptions()))
const plugin = { name: 'Test Plugin', latest_package_identifier: 'test-pkg' } as unknown as Plugin
act(() => {
result.current.setActivePlugin(plugin)
})
expect(result.current.activePlugin).toEqual(plugin)
})
it('should clear activePlugin when set to undefined', () => {
const { result } = renderHook(() => useGotoAnythingNavigation(createMockOptions()))
// First set a plugin
act(() => {
result.current.setActivePlugin({ name: 'Plugin', latest_package_identifier: 'pkg' } as unknown as Plugin)
})
expect(result.current.activePlugin).toBeDefined()
// Then clear it
act(() => {
result.current.setActivePlugin(undefined)
})
expect(result.current.activePlugin).toBeUndefined()
})
})
describe('edge cases', () => {
it('should handle undefined inputRef.current', () => {
const options = createMockOptions({
inputRef: { current: null },
})
const { result } = renderHook(() => useGotoAnythingNavigation(options))
// Should not throw
act(() => {
result.current.handleCommandSelect('@app')
})
act(() => {
vi.runAllTimers()
})
// No error should occur
})
it('should handle missing slash action', () => {
const options = createMockOptions({
Actions: {},
})
const { result } = renderHook(() => useGotoAnythingNavigation(options))
// Should not throw
act(() => {
result.current.handleNavigate({
id: 'cmd-1',
type: 'command' as const,
title: 'Command',
data: { command: 'test-command' },
})
})
// No error should occur
})
})
})

View File

@@ -0,0 +1,96 @@
'use client'
import type { RefObject } from 'react'
import type { Plugin } from '../../plugins/types'
import type { ActionItem, SearchResult } from '../actions/types'
import { useRouter } from 'next/navigation'
import { useCallback, useState } from 'react'
import { selectWorkflowNode } from '@/app/components/workflow/utils/node-navigation'
import { slashCommandRegistry } from '../actions/commands/registry'
export type UseGotoAnythingNavigationReturn = {
handleCommandSelect: (commandKey: string) => void
handleNavigate: (result: SearchResult) => void
activePlugin: Plugin | undefined
setActivePlugin: (plugin: Plugin | undefined) => void
}
export type UseGotoAnythingNavigationOptions = {
Actions: Record<string, ActionItem>
setSearchQuery: (query: string) => void
clearSelection: () => void
inputRef: RefObject<HTMLInputElement | null>
onClose: () => void
}
export const useGotoAnythingNavigation = (
options: UseGotoAnythingNavigationOptions,
): UseGotoAnythingNavigationReturn => {
const {
Actions,
setSearchQuery,
clearSelection,
inputRef,
onClose,
} = options
const router = useRouter()
const [activePlugin, setActivePlugin] = useState<Plugin>()
const handleCommandSelect = useCallback((commandKey: string) => {
// Check if it's a slash command
if (commandKey.startsWith('/')) {
const commandName = commandKey.substring(1)
const handler = slashCommandRegistry.findCommand(commandName)
// If it's a direct mode command, execute immediately
if (handler?.mode === 'direct' && handler.execute) {
handler.execute()
onClose()
setSearchQuery('')
return
}
}
// Otherwise, proceed with the normal flow (submenu mode)
setSearchQuery(`${commandKey} `)
clearSelection()
setTimeout(() => {
inputRef.current?.focus()
}, 0)
}, [onClose, setSearchQuery, clearSelection, inputRef])
// Handle navigation to selected result
const handleNavigate = useCallback((result: SearchResult) => {
onClose()
setSearchQuery('')
switch (result.type) {
case 'command': {
// Execute slash commands
const action = Actions.slash
action?.action?.(result)
break
}
case 'plugin':
setActivePlugin(result.data)
break
case 'workflow-node':
// Handle workflow node selection and navigation
if (result.metadata?.nodeId)
selectWorkflowNode(result.metadata.nodeId, true)
break
default:
if (result.path)
router.push(result.path)
}
}, [router, Actions, onClose, setSearchQuery])
return {
handleCommandSelect,
handleNavigate,
activePlugin,
setActivePlugin,
}
}

View File

@@ -0,0 +1,354 @@
import type { SearchResult } from '../actions/types'
import { renderHook } from '@testing-library/react'
import { useGotoAnythingResults } from './use-goto-anything-results'
type MockQueryResult = {
data: Array<{ id: string, type: string, title: string }> | undefined
isLoading: boolean
isError: boolean
error: Error | null
}
type UseQueryOptions = {
queryFn: () => Promise<SearchResult[]>
}
let mockQueryResult: MockQueryResult = { data: [], isLoading: false, isError: false, error: null }
let capturedQueryFn: (() => Promise<SearchResult[]>) | null = null
vi.mock('@tanstack/react-query', () => ({
useQuery: (options: UseQueryOptions) => {
capturedQueryFn = options.queryFn
return mockQueryResult
},
}))
vi.mock('@/context/i18n', () => ({
useGetLanguage: () => 'en_US',
}))
const mockMatchAction = vi.fn()
const mockSearchAnything = vi.fn()
vi.mock('../actions', () => ({
matchAction: (...args: unknown[]) => mockMatchAction(...args),
searchAnything: (...args: unknown[]) => mockSearchAnything(...args),
}))
const createMockActionItem = (key: '@app' | '@knowledge' | '@plugin' | '@node' | '/') => ({
key,
shortcut: key,
title: `${key} title`,
description: `${key} description`,
search: vi.fn().mockResolvedValue([]),
})
const createMockOptions = (overrides = {}) => ({
searchQueryDebouncedValue: '',
searchMode: 'general',
isCommandsMode: false,
Actions: { app: createMockActionItem('@app') },
isWorkflowPage: false,
isRagPipelinePage: false,
cmdVal: '_',
setCmdVal: vi.fn(),
...overrides,
})
describe('useGotoAnythingResults', () => {
beforeEach(() => {
mockQueryResult = { data: [], isLoading: false, isError: false, error: null }
capturedQueryFn = null
mockMatchAction.mockReset()
mockSearchAnything.mockReset()
})
describe('initialization', () => {
it('should return empty arrays when no results', () => {
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
expect(result.current.searchResults).toEqual([])
expect(result.current.dedupedResults).toEqual([])
expect(result.current.groupedResults).toEqual({})
})
it('should return loading state', () => {
mockQueryResult = { data: [], isLoading: true, isError: false, error: null }
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
expect(result.current.isLoading).toBe(true)
})
it('should return error state', () => {
const error = new Error('Test error')
mockQueryResult = { data: [], isLoading: false, isError: true, error }
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
expect(result.current.isError).toBe(true)
expect(result.current.error).toBe(error)
})
})
describe('dedupedResults', () => {
it('should remove duplicate results', () => {
mockQueryResult = {
data: [
{ id: '1', type: 'app', title: 'App 1' },
{ id: '1', type: 'app', title: 'App 1 Duplicate' },
{ id: '2', type: 'app', title: 'App 2' },
],
isLoading: false,
isError: false,
error: null,
}
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
expect(result.current.dedupedResults).toHaveLength(2)
expect(result.current.dedupedResults[0].id).toBe('1')
expect(result.current.dedupedResults[1].id).toBe('2')
})
it('should keep first occurrence when duplicates exist', () => {
mockQueryResult = {
data: [
{ id: '1', type: 'app', title: 'First' },
{ id: '1', type: 'app', title: 'Second' },
],
isLoading: false,
isError: false,
error: null,
}
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
expect(result.current.dedupedResults).toHaveLength(1)
expect(result.current.dedupedResults[0].title).toBe('First')
})
it('should handle different types with same id', () => {
mockQueryResult = {
data: [
{ id: '1', type: 'app', title: 'App' },
{ id: '1', type: 'plugin', title: 'Plugin' },
],
isLoading: false,
isError: false,
error: null,
}
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
// Different types, same id = different keys, so both should remain
expect(result.current.dedupedResults).toHaveLength(2)
})
})
describe('groupedResults', () => {
it('should group results by type', () => {
mockQueryResult = {
data: [
{ id: '1', type: 'app', title: 'App 1' },
{ id: '2', type: 'app', title: 'App 2' },
{ id: '3', type: 'plugin', title: 'Plugin 1' },
],
isLoading: false,
isError: false,
error: null,
}
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
expect(result.current.groupedResults.app).toHaveLength(2)
expect(result.current.groupedResults.plugin).toHaveLength(1)
})
it('should handle single type', () => {
mockQueryResult = {
data: [
{ id: '1', type: 'knowledge', title: 'KB 1' },
{ id: '2', type: 'knowledge', title: 'KB 2' },
],
isLoading: false,
isError: false,
error: null,
}
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
expect(Object.keys(result.current.groupedResults)).toEqual(['knowledge'])
expect(result.current.groupedResults.knowledge).toHaveLength(2)
})
it('should return empty object when no results', () => {
mockQueryResult = { data: [], isLoading: false, isError: false, error: null }
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
expect(result.current.groupedResults).toEqual({})
})
})
describe('auto-select first result', () => {
it('should call setCmdVal when results change and current value does not exist', () => {
const setCmdVal = vi.fn()
mockQueryResult = {
data: [{ id: '1', type: 'app', title: 'App 1' }],
isLoading: false,
isError: false,
error: null,
}
renderHook(() => useGotoAnythingResults(createMockOptions({
cmdVal: 'non-existent',
setCmdVal,
})))
expect(setCmdVal).toHaveBeenCalledWith('app-1')
})
it('should NOT call setCmdVal when in commands mode', () => {
const setCmdVal = vi.fn()
mockQueryResult = {
data: [{ id: '1', type: 'app', title: 'App 1' }],
isLoading: false,
isError: false,
error: null,
}
renderHook(() => useGotoAnythingResults(createMockOptions({
isCommandsMode: true,
setCmdVal,
})))
expect(setCmdVal).not.toHaveBeenCalled()
})
it('should NOT call setCmdVal when results are empty', () => {
const setCmdVal = vi.fn()
mockQueryResult = { data: [], isLoading: false, isError: false, error: null }
renderHook(() => useGotoAnythingResults(createMockOptions({
setCmdVal,
})))
expect(setCmdVal).not.toHaveBeenCalled()
})
it('should NOT call setCmdVal when current value exists in results', () => {
const setCmdVal = vi.fn()
mockQueryResult = {
data: [
{ id: '1', type: 'app', title: 'App 1' },
{ id: '2', type: 'app', title: 'App 2' },
],
isLoading: false,
isError: false,
error: null,
}
renderHook(() => useGotoAnythingResults(createMockOptions({
cmdVal: 'app-2',
setCmdVal,
})))
expect(setCmdVal).not.toHaveBeenCalled()
})
})
describe('error handling', () => {
it('should return error as Error | null', () => {
const error = new Error('Search failed')
mockQueryResult = { data: [], isLoading: false, isError: true, error }
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
expect(result.current.error).toBeInstanceOf(Error)
expect(result.current.error?.message).toBe('Search failed')
})
it('should return null error when no error', () => {
mockQueryResult = { data: [], isLoading: false, isError: false, error: null }
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
expect(result.current.error).toBeNull()
})
})
describe('searchResults', () => {
it('should return raw search results', () => {
const mockData = [
{ id: '1', type: 'app', title: 'App 1' },
{ id: '2', type: 'plugin', title: 'Plugin 1' },
]
mockQueryResult = { data: mockData, isLoading: false, isError: false, error: null }
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
expect(result.current.searchResults).toEqual(mockData)
})
it('should default to empty array when data is undefined', () => {
mockQueryResult = { data: undefined, isLoading: false, isError: false, error: null }
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
expect(result.current.searchResults).toEqual([])
})
})
describe('queryFn execution', () => {
it('should call matchAction with lowercased query', async () => {
const mockActions = { app: createMockActionItem('@app') }
mockMatchAction.mockReturnValue({ key: '@app' })
mockSearchAnything.mockResolvedValue([])
renderHook(() => useGotoAnythingResults(createMockOptions({
searchQueryDebouncedValue: 'TEST QUERY',
Actions: mockActions,
})))
expect(capturedQueryFn).toBeDefined()
await capturedQueryFn!()
expect(mockMatchAction).toHaveBeenCalledWith('test query', mockActions)
})
it('should call searchAnything with correct parameters', async () => {
const mockActions = { app: createMockActionItem('@app') }
const mockAction = { key: '@app' }
mockMatchAction.mockReturnValue(mockAction)
mockSearchAnything.mockResolvedValue([{ id: '1', type: 'app', title: 'Result' }])
renderHook(() => useGotoAnythingResults(createMockOptions({
searchQueryDebouncedValue: 'My Query',
Actions: mockActions,
})))
expect(capturedQueryFn).toBeDefined()
const result = await capturedQueryFn!()
expect(mockSearchAnything).toHaveBeenCalledWith('en_US', 'my query', mockAction, mockActions)
expect(result).toEqual([{ id: '1', type: 'app', title: 'Result' }])
})
it('should handle searchAnything returning results', async () => {
const expectedResults = [
{ id: '1', type: 'app', title: 'App 1' },
{ id: '2', type: 'plugin', title: 'Plugin 1' },
]
mockMatchAction.mockReturnValue(null)
mockSearchAnything.mockResolvedValue(expectedResults)
renderHook(() => useGotoAnythingResults(createMockOptions({
searchQueryDebouncedValue: 'search term',
})))
expect(capturedQueryFn).toBeDefined()
const result = await capturedQueryFn!()
expect(result).toEqual(expectedResults)
})
})
})

View File

@@ -0,0 +1,115 @@
'use client'
import type { ActionItem, SearchResult } from '../actions/types'
import { useQuery } from '@tanstack/react-query'
import { useEffect, useMemo } from 'react'
import { useGetLanguage } from '@/context/i18n'
import { matchAction, searchAnything } from '../actions'
export type UseGotoAnythingResultsReturn = {
searchResults: SearchResult[]
dedupedResults: SearchResult[]
groupedResults: Record<string, SearchResult[]>
isLoading: boolean
isError: boolean
error: Error | null
}
export type UseGotoAnythingResultsOptions = {
searchQueryDebouncedValue: string
searchMode: string
isCommandsMode: boolean
Actions: Record<string, ActionItem>
isWorkflowPage: boolean
isRagPipelinePage: boolean
cmdVal: string
setCmdVal: (val: string) => void
}
export const useGotoAnythingResults = (
options: UseGotoAnythingResultsOptions,
): UseGotoAnythingResultsReturn => {
const {
searchQueryDebouncedValue,
searchMode,
isCommandsMode,
Actions,
isWorkflowPage,
isRagPipelinePage,
cmdVal,
setCmdVal,
} = options
const defaultLocale = useGetLanguage()
// Use action keys as stable cache key instead of the full Actions object
// (Actions contains functions which are not serializable)
const actionKeys = useMemo(() => Object.keys(Actions).sort(), [Actions])
const { data: searchResults = [], isLoading, isError, error } = useQuery(
{
// eslint-disable-next-line @tanstack/query/exhaustive-deps -- Actions intentionally excluded: contains non-serializable functions; actionKeys provides stable representation
queryKey: [
'goto-anything',
'search-result',
searchQueryDebouncedValue,
searchMode,
isWorkflowPage,
isRagPipelinePage,
defaultLocale,
actionKeys,
],
queryFn: async () => {
const query = searchQueryDebouncedValue.toLowerCase()
const action = matchAction(query, Actions)
return await searchAnything(defaultLocale, query, action, Actions)
},
enabled: !!searchQueryDebouncedValue && !isCommandsMode,
staleTime: 30000,
gcTime: 300000,
},
)
const dedupedResults = useMemo(() => {
const seen = new Set<string>()
return searchResults.filter((result) => {
const key = `${result.type}-${result.id}`
if (seen.has(key))
return false
seen.add(key)
return true
})
}, [searchResults])
// Group results by type
const groupedResults = useMemo(() => dedupedResults.reduce((acc, result) => {
if (!acc[result.type])
acc[result.type] = []
acc[result.type].push(result)
return acc
}, {} as Record<string, SearchResult[]>), [dedupedResults])
// Auto-select first result when results change
useEffect(() => {
if (isCommandsMode)
return
if (!dedupedResults.length)
return
const currentValueExists = dedupedResults.some(result => `${result.type}-${result.id}` === cmdVal)
if (!currentValueExists)
setCmdVal(`${dedupedResults[0].type}-${dedupedResults[0].id}`)
}, [isCommandsMode, dedupedResults, cmdVal, setCmdVal])
return {
searchResults,
dedupedResults,
groupedResults,
isLoading,
isError,
error: error as Error | null,
}
}

View File

@@ -0,0 +1,301 @@
import type { ActionItem } from '../actions/types'
import { act, renderHook } from '@testing-library/react'
import { useGotoAnythingSearch } from './use-goto-anything-search'
let mockContextValue = { isWorkflowPage: false, isRagPipelinePage: false }
let mockMatchActionResult: Partial<ActionItem> | undefined
vi.mock('ahooks', () => ({
useDebounce: <T>(value: T) => value,
}))
vi.mock('../context', () => ({
useGotoAnythingContext: () => mockContextValue,
}))
vi.mock('../actions', () => ({
createActions: (isWorkflowPage: boolean, isRagPipelinePage: boolean) => {
const base = {
slash: { key: '/', shortcut: '/' },
app: { key: '@app', shortcut: '@app' },
knowledge: { key: '@knowledge', shortcut: '@kb' },
}
if (isWorkflowPage) {
return { ...base, node: { key: '@node', shortcut: '@node' } }
}
if (isRagPipelinePage) {
return { ...base, ragNode: { key: '@node', shortcut: '@node' } }
}
return base
},
matchAction: () => mockMatchActionResult,
}))
describe('useGotoAnythingSearch', () => {
beforeEach(() => {
mockContextValue = { isWorkflowPage: false, isRagPipelinePage: false }
mockMatchActionResult = undefined
})
describe('initialization', () => {
it('should initialize with empty search query', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.searchQuery).toBe('')
})
it('should initialize cmdVal with "_"', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.cmdVal).toBe('_')
})
it('should initialize searchMode as "general"', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.searchMode).toBe('general')
})
it('should initialize isCommandsMode as false', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.isCommandsMode).toBe(false)
})
it('should provide setSearchQuery function', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
expect(typeof result.current.setSearchQuery).toBe('function')
})
it('should provide setCmdVal function', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
expect(typeof result.current.setCmdVal).toBe('function')
})
it('should provide clearSelection function', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
expect(typeof result.current.clearSelection).toBe('function')
})
})
describe('Actions', () => {
it('should provide Actions based on context', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.Actions).toBeDefined()
expect(typeof result.current.Actions).toBe('object')
})
it('should include node action when on workflow page', () => {
mockContextValue = { isWorkflowPage: true, isRagPipelinePage: false }
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.Actions.node).toBeDefined()
})
it('should include ragNode action when on RAG pipeline page', () => {
mockContextValue = { isWorkflowPage: false, isRagPipelinePage: true }
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.Actions.ragNode).toBeDefined()
})
it('should not include node actions when on regular page', () => {
mockContextValue = { isWorkflowPage: false, isRagPipelinePage: false }
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.Actions.node).toBeUndefined()
expect(result.current.Actions.ragNode).toBeUndefined()
})
})
describe('isCommandsMode', () => {
it('should return true when query is exactly "@"', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery('@')
})
expect(result.current.isCommandsMode).toBe(true)
})
it('should return true when query is exactly "/"', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery('/')
})
expect(result.current.isCommandsMode).toBe(true)
})
it('should return true when query starts with "@" and no action matches', () => {
mockMatchActionResult = undefined
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery('@unknown')
})
expect(result.current.isCommandsMode).toBe(true)
})
it('should return true when query starts with "/" and no action matches', () => {
mockMatchActionResult = undefined
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery('/unknown')
})
expect(result.current.isCommandsMode).toBe(true)
})
it('should return false when query starts with "@" and action matches', () => {
mockMatchActionResult = { key: '@app', shortcut: '@app' }
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery('@app test')
})
expect(result.current.isCommandsMode).toBe(false)
})
it('should return false for regular search query', () => {
mockMatchActionResult = undefined
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery('hello world')
})
expect(result.current.isCommandsMode).toBe(false)
})
})
describe('searchMode', () => {
it('should return "general" when query is empty', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.searchMode).toBe('general')
})
it('should return "scopes" when in commands mode and query starts with "@"', () => {
mockMatchActionResult = undefined
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery('@')
})
expect(result.current.searchMode).toBe('scopes')
})
it('should return "commands" when in commands mode and query starts with "/"', () => {
mockMatchActionResult = undefined
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery('/')
})
expect(result.current.searchMode).toBe('commands')
})
it('should return "general" when no action matches', () => {
mockMatchActionResult = undefined
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery('hello')
})
expect(result.current.searchMode).toBe('general')
})
it('should return action key when action matches', () => {
mockMatchActionResult = { key: '@app', shortcut: '@app' }
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery('@app test')
})
expect(result.current.searchMode).toBe('@app')
})
it('should return "@command" when action key is "/"', () => {
mockMatchActionResult = { key: '/', shortcut: '/' }
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery('/theme dark')
})
expect(result.current.searchMode).toBe('@command')
})
})
describe('clearSelection', () => {
it('should reset cmdVal to "_"', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
// First change cmdVal
act(() => {
result.current.setCmdVal('app-1')
})
expect(result.current.cmdVal).toBe('app-1')
// Then clear
act(() => {
result.current.clearSelection()
})
expect(result.current.cmdVal).toBe('_')
})
})
describe('setSearchQuery', () => {
it('should update search query', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery('test query')
})
expect(result.current.searchQuery).toBe('test query')
})
it('should handle empty string', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery('test')
})
expect(result.current.searchQuery).toBe('test')
act(() => {
result.current.setSearchQuery('')
})
expect(result.current.searchQuery).toBe('')
})
})
describe('setCmdVal', () => {
it('should update cmdVal', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setCmdVal('plugin-2')
})
expect(result.current.cmdVal).toBe('plugin-2')
})
})
describe('searchQueryDebouncedValue', () => {
it('should return trimmed debounced value', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
result.current.setSearchQuery(' test ')
})
// Since we mock useDebounce to return value directly
expect(result.current.searchQueryDebouncedValue).toBe('test')
})
})
})

View File

@@ -0,0 +1,77 @@
'use client'
import type { ActionItem } from '../actions/types'
import { useDebounce } from 'ahooks'
import { useCallback, useMemo, useState } from 'react'
import { createActions, matchAction } from '../actions'
import { useGotoAnythingContext } from '../context'
export type UseGotoAnythingSearchReturn = {
searchQuery: string
setSearchQuery: (query: string) => void
searchQueryDebouncedValue: string
searchMode: string
isCommandsMode: boolean
cmdVal: string
setCmdVal: (val: string) => void
clearSelection: () => void
Actions: Record<string, ActionItem>
}
export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => {
const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext()
const [searchQuery, setSearchQuery] = useState<string>('')
const [cmdVal, setCmdVal] = useState<string>('_')
// Filter actions based on context
const Actions = useMemo(() => {
return createActions(isWorkflowPage, isRagPipelinePage)
}, [isWorkflowPage, isRagPipelinePage])
const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), {
wait: 300,
})
const isCommandsMode = useMemo(() => {
const trimmed = searchQuery.trim()
return trimmed === '@' || trimmed === '/'
|| (trimmed.startsWith('@') && !matchAction(trimmed, Actions))
|| (trimmed.startsWith('/') && !matchAction(trimmed, Actions))
}, [searchQuery, Actions])
const searchMode = useMemo(() => {
if (isCommandsMode) {
// Distinguish between @ (scopes) and / (commands) mode
if (searchQuery.trim().startsWith('@'))
return 'scopes'
else if (searchQuery.trim().startsWith('/'))
return 'commands'
return 'commands' // default fallback
}
const query = searchQueryDebouncedValue.toLowerCase()
const action = matchAction(query, Actions)
if (!action)
return 'general'
return action.key === '/' ? '@command' : action.key
}, [searchQueryDebouncedValue, Actions, isCommandsMode, searchQuery])
// Prevent automatic selection of the first option when cmdVal is not set
const clearSelection = useCallback(() => {
setCmdVal('_')
}, [])
return {
searchQuery,
setSearchQuery,
searchQueryDebouncedValue,
searchMode,
isCommandsMode,
cmdVal,
setCmdVal,
clearSelection,
Actions,
}
}

View File

@@ -1,9 +1,27 @@
import type { ReactNode } from 'react'
import type { ActionItem, SearchResult } from './actions/types'
import { act, render, screen } from '@testing-library/react'
import { act, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import GotoAnything from './index'
// Test helper type that matches SearchResult but allows ReactNode for icon and flexible data
type TestSearchResult = Omit<SearchResult, 'icon' | 'data'> & {
icon?: ReactNode
data?: Record<string, unknown>
}
// Mock react-i18next to return namespace.key format
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string, options?: { ns?: string }) => {
const ns = options?.ns || 'common'
return `${ns}.${key}`
},
i18n: { language: 'en' },
}),
}))
const routerPush = vi.fn()
vi.mock('next/navigation', () => ({
useRouter: () => ({
@@ -12,10 +30,15 @@ vi.mock('next/navigation', () => ({
usePathname: () => '/',
}))
const keyPressHandlers: Record<string, (event: any) => void> = {}
type KeyPressEvent = {
preventDefault: () => void
target?: EventTarget
}
const keyPressHandlers: Record<string, (event: KeyPressEvent) => void> = {}
vi.mock('ahooks', () => ({
useDebounce: (value: any) => value,
useKeyPress: (keys: string | string[], handler: (event: any) => void) => {
useDebounce: <T,>(value: T) => value,
useKeyPress: (keys: string | string[], handler: (event: KeyPressEvent) => void) => {
const keyList = Array.isArray(keys) ? keys : [keys]
keyList.forEach((key) => {
keyPressHandlers[key] = handler
@@ -32,7 +55,7 @@ const triggerKeyPress = (combo: string) => {
}
}
let mockQueryResult = { data: [] as SearchResult[], isLoading: false, isError: false, error: null as Error | null }
let mockQueryResult = { data: [] as TestSearchResult[], isLoading: false, isError: false, error: null as Error | null }
vi.mock('@tanstack/react-query', () => ({
useQuery: () => mockQueryResult,
}))
@@ -76,9 +99,16 @@ vi.mock('./actions/commands', () => ({
SlashCommandProvider: () => null,
}))
type MockSlashCommand = {
mode: string
execute?: () => void
isAvailable?: () => boolean
} | null
let mockFindCommand: MockSlashCommand = null
vi.mock('./actions/commands/registry', () => ({
slashCommandRegistry: {
findCommand: () => null,
findCommand: () => mockFindCommand,
getAvailableCommands: () => [],
getAllCommands: () => [],
},
@@ -86,6 +116,7 @@ vi.mock('./actions/commands/registry', () => ({
vi.mock('@/app/components/workflow/utils/common', () => ({
getKeyboardKeyCodeBySystem: () => 'ctrl',
getKeyboardKeyNameBySystem: (key: string) => key,
isEventTargetInputArea: () => false,
isMac: () => false,
}))
@@ -95,10 +126,11 @@ vi.mock('@/app/components/workflow/utils/node-navigation', () => ({
}))
vi.mock('../plugins/install-plugin/install-from-marketplace', () => ({
default: (props: { manifest?: { name?: string }, onClose: () => void }) => (
default: (props: { manifest?: { name?: string }, onClose: () => void, onSuccess: () => void }) => (
<div data-testid="install-modal">
<span>{props.manifest?.name}</span>
<button onClick={props.onClose}>close</button>
<button onClick={props.onClose} data-testid="close-install">close</button>
<button onClick={props.onSuccess} data-testid="success-install">success</button>
</div>
),
}))
@@ -110,65 +142,504 @@ describe('GotoAnything', () => {
mockQueryResult = { data: [], isLoading: false, isError: false, error: null }
matchActionMock.mockReset()
searchAnythingMock.mockClear()
mockFindCommand = null
})
it('should open modal via shortcut and navigate to selected result', async () => {
mockQueryResult = {
data: [{
id: 'app-1',
type: 'app',
title: 'Sample App',
description: 'desc',
path: '/apps/1',
icon: <div data-testid="icon">🧩</div>,
data: {},
} as any],
isLoading: false,
isError: false,
error: null,
}
describe('modal behavior', () => {
it('should open modal via Ctrl+K shortcut', async () => {
render(<GotoAnything />)
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
})
const input = await screen.findByPlaceholderText('app.gotoAnything.searchPlaceholder')
await userEvent.type(input, 'app')
it('should close modal via ESC key', async () => {
render(<GotoAnything />)
const result = await screen.findByText('Sample App')
await userEvent.click(result)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
expect(routerPush).toHaveBeenCalledWith('/apps/1')
triggerKeyPress('esc')
await waitFor(() => {
expect(screen.queryByPlaceholderText('app.gotoAnything.searchPlaceholder')).not.toBeInTheDocument()
})
})
it('should toggle modal when pressing Ctrl+K twice', async () => {
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.queryByPlaceholderText('app.gotoAnything.searchPlaceholder')).not.toBeInTheDocument()
})
})
it('should call onHide when modal closes', async () => {
const onHide = vi.fn()
render(<GotoAnything onHide={onHide} />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
triggerKeyPress('esc')
await waitFor(() => {
expect(onHide).toHaveBeenCalled()
})
})
it('should reset search query when modal opens', async () => {
const user = userEvent.setup()
render(<GotoAnything />)
// Open modal first time
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
// Type something
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, 'test')
// Close modal
triggerKeyPress('esc')
await waitFor(() => {
expect(screen.queryByPlaceholderText('app.gotoAnything.searchPlaceholder')).not.toBeInTheDocument()
})
// Open modal again - should be empty
triggerKeyPress('ctrl.k')
await waitFor(() => {
const newInput = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
expect(newInput).toHaveValue('')
})
})
})
it('should open plugin installer when selecting plugin result', async () => {
mockQueryResult = {
data: [{
id: 'plugin-1',
type: 'plugin',
title: 'Plugin Item',
description: 'desc',
path: '',
icon: <div />,
data: {
name: 'Plugin Item',
latest_package_identifier: 'pkg',
},
} as any],
isLoading: false,
isError: false,
error: null,
}
describe('search functionality', () => {
it('should navigate to selected result', async () => {
const user = userEvent.setup()
mockQueryResult = {
data: [{
id: 'app-1',
type: 'app',
title: 'Sample App',
description: 'desc',
path: '/apps/1',
icon: <div data-testid="icon">🧩</div>,
data: {},
}],
isLoading: false,
isError: false,
error: null,
}
render(<GotoAnything />)
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
triggerKeyPress('ctrl.k')
const input = await screen.findByPlaceholderText('app.gotoAnything.searchPlaceholder')
await userEvent.type(input, 'plugin')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const pluginItem = await screen.findByText('Plugin Item')
await userEvent.click(pluginItem)
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, 'app')
expect(await screen.findByTestId('install-modal')).toHaveTextContent('Plugin Item')
const result = await screen.findByText('Sample App')
await user.click(result)
expect(routerPush).toHaveBeenCalledWith('/apps/1')
})
it('should clear selection when typing without prefix', async () => {
const user = userEvent.setup()
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, 'test query')
// Should not throw and input should have value
expect(input).toHaveValue('test query')
})
})
describe('empty states', () => {
it('should show loading state', async () => {
const user = userEvent.setup()
mockQueryResult = {
data: [],
isLoading: true,
isError: false,
error: null,
}
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, 'search')
// Loading state shows in both EmptyState (spinner) and Footer
const searchingTexts = screen.getAllByText('app.gotoAnything.searching')
expect(searchingTexts.length).toBeGreaterThanOrEqual(1)
})
it('should show error state', async () => {
const user = userEvent.setup()
const testError = new Error('Search failed')
mockQueryResult = {
data: [],
isLoading: false,
isError: true,
error: testError,
}
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, 'search')
expect(screen.getByText('app.gotoAnything.searchFailed')).toBeInTheDocument()
})
it('should show default state when no query', async () => {
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
expect(screen.getByText('app.gotoAnything.searchTitle')).toBeInTheDocument()
})
it('should show no results state when search returns empty', async () => {
const user = userEvent.setup()
mockQueryResult = {
data: [],
isLoading: false,
isError: false,
error: null,
}
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, 'nonexistent')
expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument()
})
})
describe('plugin installation', () => {
it('should open plugin installer when selecting plugin result', async () => {
const user = userEvent.setup()
mockQueryResult = {
data: [{
id: 'plugin-1',
type: 'plugin',
title: 'Plugin Item',
description: 'desc',
path: '',
icon: <div />,
data: {
name: 'Plugin Item',
latest_package_identifier: 'pkg',
},
}],
isLoading: false,
isError: false,
error: null,
}
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, 'plugin')
const pluginItem = await screen.findByText('Plugin Item')
await user.click(pluginItem)
expect(await screen.findByTestId('install-modal')).toHaveTextContent('Plugin Item')
})
it('should close plugin installer via close button', async () => {
const user = userEvent.setup()
mockQueryResult = {
data: [{
id: 'plugin-1',
type: 'plugin',
title: 'Plugin Item',
description: 'desc',
path: '',
icon: <div />,
data: {
name: 'Plugin Item',
latest_package_identifier: 'pkg',
},
}],
isLoading: false,
isError: false,
error: null,
}
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, 'plugin')
const pluginItem = await screen.findByText('Plugin Item')
await user.click(pluginItem)
const closeBtn = await screen.findByTestId('close-install')
await user.click(closeBtn)
await waitFor(() => {
expect(screen.queryByTestId('install-modal')).not.toBeInTheDocument()
})
})
it('should close plugin installer on success', async () => {
const user = userEvent.setup()
mockQueryResult = {
data: [{
id: 'plugin-1',
type: 'plugin',
title: 'Plugin Item',
description: 'desc',
path: '',
icon: <div />,
data: {
name: 'Plugin Item',
latest_package_identifier: 'pkg',
},
}],
isLoading: false,
isError: false,
error: null,
}
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, 'plugin')
const pluginItem = await screen.findByText('Plugin Item')
await user.click(pluginItem)
const successBtn = await screen.findByTestId('success-install')
await user.click(successBtn)
await waitFor(() => {
expect(screen.queryByTestId('install-modal')).not.toBeInTheDocument()
})
})
})
describe('slash command handling', () => {
it('should execute direct slash command on Enter', async () => {
const user = userEvent.setup()
const executeMock = vi.fn()
mockFindCommand = {
mode: 'direct',
execute: executeMock,
isAvailable: () => true,
}
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, '/theme')
await user.keyboard('{Enter}')
expect(executeMock).toHaveBeenCalled()
})
it('should NOT execute unavailable slash command', async () => {
const user = userEvent.setup()
const executeMock = vi.fn()
mockFindCommand = {
mode: 'direct',
execute: executeMock,
isAvailable: () => false,
}
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, '/theme')
await user.keyboard('{Enter}')
expect(executeMock).not.toHaveBeenCalled()
})
it('should NOT execute non-direct mode slash command on Enter', async () => {
const user = userEvent.setup()
const executeMock = vi.fn()
mockFindCommand = {
mode: 'submenu',
execute: executeMock,
}
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, '/language')
await user.keyboard('{Enter}')
expect(executeMock).not.toHaveBeenCalled()
})
it('should close modal after executing direct slash command', async () => {
const user = userEvent.setup()
mockFindCommand = {
mode: 'direct',
execute: vi.fn(),
isAvailable: () => true,
}
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, '/theme')
await user.keyboard('{Enter}')
await waitFor(() => {
expect(screen.queryByPlaceholderText('app.gotoAnything.searchPlaceholder')).not.toBeInTheDocument()
})
})
})
describe('result navigation', () => {
it('should handle knowledge result navigation', async () => {
const user = userEvent.setup()
mockQueryResult = {
data: [{
id: 'kb-1',
type: 'knowledge',
title: 'Knowledge Base',
description: 'desc',
path: '/datasets/kb-1',
icon: <div />,
data: {},
}],
isLoading: false,
isError: false,
error: null,
}
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, 'knowledge')
const result = await screen.findByText('Knowledge Base')
await user.click(result)
expect(routerPush).toHaveBeenCalledWith('/datasets/kb-1')
})
it('should NOT navigate when result has no path', async () => {
const user = userEvent.setup()
mockQueryResult = {
data: [{
id: 'item-1',
type: 'app',
title: 'No Path Item',
description: 'desc',
path: '',
icon: <div />,
data: {},
}],
isLoading: false,
isError: false,
error: null,
}
render(<GotoAnything />)
triggerKeyPress('ctrl.k')
await waitFor(() => {
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
})
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
await user.type(input, 'no path')
const result = await screen.findByText('No Path Item')
await user.click(result)
expect(routerPush).not.toHaveBeenCalled()
})
})
})

View File

@@ -1,300 +1,149 @@
'use client'
import type { FC } from 'react'
import type { Plugin } from '../plugins/types'
import type { SearchResult } from './actions'
import { RiSearchLine } from '@remixicon/react'
import { useQuery } from '@tanstack/react-query'
import { useDebounce, useKeyPress } from 'ahooks'
import type { FC, KeyboardEvent } from 'react'
import { Command } from 'cmdk'
import { useRouter } from 'next/navigation'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useCallback, useEffect, useMemo, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import Input from '@/app/components/base/input'
import Modal from '@/app/components/base/modal'
import ShortcutsName from '@/app/components/workflow/shortcuts-name'
import { getKeyboardKeyCodeBySystem, isEventTargetInputArea } from '@/app/components/workflow/utils/common'
import { selectWorkflowNode } from '@/app/components/workflow/utils/node-navigation'
import { useGetLanguage } from '@/context/i18n'
import InstallFromMarketplace from '../plugins/install-plugin/install-from-marketplace'
import { createActions, matchAction, searchAnything } from './actions'
import { SlashCommandProvider } from './actions/commands'
import { slashCommandRegistry } from './actions/commands/registry'
import CommandSelector from './command-selector'
import { EmptyState, Footer, ResultList, SearchInput } from './components'
import { GotoAnythingProvider, useGotoAnythingContext } from './context'
import {
useGotoAnythingModal,
useGotoAnythingNavigation,
useGotoAnythingResults,
useGotoAnythingSearch,
} from './hooks'
type Props = {
onHide?: () => void
}
const GotoAnything: FC<Props> = ({
onHide,
}) => {
const router = useRouter()
const defaultLocale = useGetLanguage()
const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext()
const { t } = useTranslation()
const [show, setShow] = useState<boolean>(false)
const [searchQuery, setSearchQuery] = useState<string>('')
const [cmdVal, setCmdVal] = useState<string>('_')
const inputRef = useRef<HTMLInputElement>(null)
const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext()
const prevShowRef = useRef(false)
// Filter actions based on context
const Actions = useMemo(() => {
// Create actions based on current page context
return createActions(isWorkflowPage, isRagPipelinePage)
}, [isWorkflowPage, isRagPipelinePage])
// Search state management (called first so setSearchQuery is available)
const {
searchQuery,
setSearchQuery,
searchQueryDebouncedValue,
searchMode,
isCommandsMode,
cmdVal,
setCmdVal,
clearSelection,
Actions,
} = useGotoAnythingSearch()
const [activePlugin, setActivePlugin] = useState<Plugin>()
// Modal state management
const {
show,
setShow,
inputRef,
handleClose: modalClose,
} = useGotoAnythingModal()
// Handle keyboard shortcuts
const handleToggleModal = useCallback((e: KeyboardEvent) => {
// Allow closing when modal is open, even if focus is in the search input
if (!show && isEventTargetInputArea(e.target as HTMLElement))
return
e.preventDefault()
setShow((prev) => {
if (!prev) {
// Opening modal - reset search state
setSearchQuery('')
}
return !prev
})
}, [show])
useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.k`, handleToggleModal, {
exactMatch: true,
useCapture: true,
})
useKeyPress(['esc'], (e) => {
if (show) {
e.preventDefault()
setShow(false)
// Reset state when modal opens/closes
useEffect(() => {
if (show && !prevShowRef.current) {
// Modal just opened - reset search
setSearchQuery('')
}
else if (!show && prevShowRef.current) {
// Modal just closed
setSearchQuery('')
clearSelection()
onHide?.()
}
prevShowRef.current = show
}, [show, setSearchQuery, clearSelection, onHide])
// Results fetching and processing
const {
dedupedResults,
groupedResults,
isLoading,
isError,
error,
} = useGotoAnythingResults({
searchQueryDebouncedValue,
searchMode,
isCommandsMode,
Actions,
isWorkflowPage,
isRagPipelinePage,
cmdVal,
setCmdVal,
})
const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), {
wait: 300,
// Navigation handlers
const {
handleCommandSelect,
handleNavigate,
activePlugin,
setActivePlugin,
} = useGotoAnythingNavigation({
Actions,
setSearchQuery,
clearSelection,
inputRef,
onClose: () => setShow(false),
})
const isCommandsMode = searchQuery.trim() === '@' || searchQuery.trim() === '/'
|| (searchQuery.trim().startsWith('@') && !matchAction(searchQuery.trim(), Actions))
|| (searchQuery.trim().startsWith('/') && !matchAction(searchQuery.trim(), Actions))
// Handle search input change
const handleSearchChange = useCallback((value: string) => {
setSearchQuery(value)
if (!value.startsWith('@') && !value.startsWith('/'))
clearSelection()
}, [setSearchQuery, clearSelection])
const searchMode = useMemo(() => {
if (isCommandsMode) {
// Distinguish between @ (scopes) and / (commands) mode
if (searchQuery.trim().startsWith('@'))
return 'scopes'
else if (searchQuery.trim().startsWith('/'))
return 'commands'
return 'commands' // default fallback
}
// Handle search input keydown for slash commands
const handleSearchKeyDown = useCallback((e: KeyboardEvent<HTMLInputElement>) => {
if (e.key === 'Enter') {
const query = searchQuery.trim()
// Check if it's a complete slash command
if (query.startsWith('/')) {
const commandName = query.substring(1).split(' ')[0]
const handler = slashCommandRegistry.findCommand(commandName)
const query = searchQueryDebouncedValue.toLowerCase()
const action = matchAction(query, Actions)
if (!action)
return 'general'
return action.key === '/' ? '@command' : action.key
}, [searchQueryDebouncedValue, Actions, isCommandsMode, searchQuery])
const { data: searchResults = [], isLoading, isError, error } = useQuery(
{
queryKey: [
'goto-anything',
'search-result',
searchQueryDebouncedValue,
searchMode,
isWorkflowPage,
isRagPipelinePage,
defaultLocale,
Actions,
],
queryFn: async () => {
const query = searchQueryDebouncedValue.toLowerCase()
const action = matchAction(query, Actions)
return await searchAnything(defaultLocale, query, action, Actions)
},
enabled: !!searchQueryDebouncedValue && !isCommandsMode,
staleTime: 30000,
gcTime: 300000,
},
)
// Prevent automatic selection of the first option when cmdVal is not set
const clearSelection = () => {
setCmdVal('_')
}
const handleCommandSelect = useCallback((commandKey: string) => {
// Check if it's a slash command
if (commandKey.startsWith('/')) {
const commandName = commandKey.substring(1)
const handler = slashCommandRegistry.findCommand(commandName)
// If it's a direct mode command, execute immediately
if (handler?.mode === 'direct' && handler.execute) {
handler.execute()
setShow(false)
setSearchQuery('')
return
// If it's a direct mode command, execute immediately
const isAvailable = handler?.isAvailable?.() ?? true
if (handler?.mode === 'direct' && handler.execute && isAvailable) {
e.preventDefault()
handler.execute()
setShow(false)
setSearchQuery('')
}
}
}
}, [searchQuery, setShow, setSearchQuery])
// Otherwise, proceed with the normal flow (submenu mode)
setSearchQuery(`${commandKey} `)
clearSelection()
setTimeout(() => {
inputRef.current?.focus()
}, 0)
}, [])
// Handle navigation to selected result
const handleNavigate = useCallback((result: SearchResult) => {
setShow(false)
setSearchQuery('')
switch (result.type) {
case 'command': {
// Execute slash commands
const action = Actions.slash
action?.action?.(result)
break
}
case 'plugin':
setActivePlugin(result.data)
break
case 'workflow-node':
// Handle workflow node selection and navigation
if (result.metadata?.nodeId)
selectWorkflowNode(result.metadata.nodeId, true)
break
default:
if (result.path)
router.push(result.path)
}
}, [router])
const dedupedResults = useMemo(() => {
const seen = new Set<string>()
return searchResults.filter((result) => {
const key = `${result.type}-${result.id}`
if (seen.has(key))
return false
seen.add(key)
return true
})
}, [searchResults])
// Group results by type
const groupedResults = useMemo(() => dedupedResults.reduce((acc, result) => {
if (!acc[result.type])
acc[result.type] = []
acc[result.type].push(result)
return acc
}, {} as { [key: string]: SearchResult[] }), [dedupedResults])
useEffect(() => {
if (isCommandsMode)
return
if (!dedupedResults.length)
return
const currentValueExists = dedupedResults.some(result => `${result.type}-${result.id}` === cmdVal)
if (!currentValueExists)
setCmdVal(`${dedupedResults[0].type}-${dedupedResults[0].id}`)
}, [isCommandsMode, dedupedResults, cmdVal])
const emptyResult = useMemo(() => {
if (dedupedResults.length || !searchQuery.trim() || isLoading || isCommandsMode)
return null
const isCommandSearch = searchMode !== 'general'
const commandType = isCommandSearch ? searchMode.replace('@', '') : ''
if (isError) {
return (
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div>
<div className="text-sm font-medium text-red-500">{t('gotoAnything.searchTemporarilyUnavailable', { ns: 'app' })}</div>
<div className="mt-1 text-xs text-text-quaternary">
{t('gotoAnything.servicesUnavailableMessage', { ns: 'app' })}
</div>
</div>
</div>
)
}
return (
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div>
<div className="text-sm font-medium">
{isCommandSearch
? (() => {
const keyMap = {
app: 'gotoAnything.emptyState.noAppsFound',
plugin: 'gotoAnything.emptyState.noPluginsFound',
knowledge: 'gotoAnything.emptyState.noKnowledgeBasesFound',
node: 'gotoAnything.emptyState.noWorkflowNodesFound',
} as const
return t(keyMap[commandType as keyof typeof keyMap] || 'gotoAnything.noResults', { ns: 'app' })
})()
: t('gotoAnything.noResults', { ns: 'app' })}
</div>
<div className="mt-1 text-xs text-text-quaternary">
{isCommandSearch
? t('gotoAnything.emptyState.tryDifferentTerm', { ns: 'app' })
: t('gotoAnything.emptyState.trySpecificSearch', { ns: 'app', shortcuts: Object.values(Actions).map(action => action.shortcut).join(', ') })}
</div>
</div>
</div>
)
}, [dedupedResults, searchQuery, Actions, searchMode, isLoading, isError, isCommandsMode])
const defaultUI = useMemo(() => {
if (searchQuery.trim())
return null
return (
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div>
<div className="text-sm font-medium">{t('gotoAnything.searchTitle', { ns: 'app' })}</div>
<div className="mt-3 space-y-1 text-xs text-text-quaternary">
<div>{t('gotoAnything.searchHint', { ns: 'app' })}</div>
<div>{t('gotoAnything.commandHint', { ns: 'app' })}</div>
<div>{t('gotoAnything.slashHint', { ns: 'app' })}</div>
</div>
</div>
</div>
)
}, [searchQuery, Actions])
useEffect(() => {
if (show) {
requestAnimationFrame(() => {
inputRef.current?.focus()
})
}
}, [show])
// Determine which empty state to show
const emptyStateVariant = useMemo(() => {
if (isLoading)
return 'loading'
if (isError)
return 'error'
if (!searchQuery.trim())
return 'default'
if (dedupedResults.length === 0 && !isCommandsMode)
return 'no-results'
return null
}, [isLoading, isError, searchQuery, dedupedResults.length, isCommandsMode])
return (
<>
<SlashCommandProvider />
<Modal
isShow={show}
onClose={() => {
setShow(false)
setSearchQuery('')
clearSelection()
onHide?.()
}}
onClose={modalClose}
closable={false}
className="!w-[480px] !p-0"
highPriority={true}
@@ -307,78 +156,24 @@ const GotoAnything: FC<Props> = ({
disablePointerSelection
loop
>
<div className="flex items-center gap-3 border-b border-divider-subtle bg-components-panel-bg-blur px-4 py-3">
<RiSearchLine className="h-4 w-4 text-text-quaternary" />
<div className="flex flex-1 items-center gap-2">
<Input
ref={inputRef}
value={searchQuery}
placeholder={t('gotoAnything.searchPlaceholder', { ns: 'app' })}
onChange={(e) => {
setSearchQuery(e.target.value)
if (!e.target.value.startsWith('@') && !e.target.value.startsWith('/'))
clearSelection()
}}
onKeyDown={(e) => {
if (e.key === 'Enter') {
const query = searchQuery.trim()
// Check if it's a complete slash command
if (query.startsWith('/')) {
const commandName = query.substring(1).split(' ')[0]
const handler = slashCommandRegistry.findCommand(commandName)
// If it's a direct mode command, execute immediately
const isAvailable = handler?.isAvailable?.() ?? true
if (handler?.mode === 'direct' && handler.execute && isAvailable) {
e.preventDefault()
handler.execute()
setShow(false)
setSearchQuery('')
}
}
}
}}
className="flex-1 !border-0 !bg-transparent !shadow-none"
wrapperClassName="flex-1 !border-0 !bg-transparent"
autoFocus
/>
{searchMode !== 'general' && (
<div className="flex items-center gap-1 rounded bg-gray-100 px-2 py-[2px] text-xs font-medium text-gray-700 dark:bg-gray-800 dark:text-gray-300">
<span>
{(() => {
if (searchMode === 'scopes')
return 'SCOPES'
else if (searchMode === 'commands')
return 'COMMANDS'
else
return searchMode.replace('@', '').toUpperCase()
})()}
</span>
</div>
)}
</div>
<ShortcutsName keys={['ctrl', 'K']} textColor="secondary" />
</div>
<SearchInput
inputRef={inputRef}
value={searchQuery}
onChange={handleSearchChange}
onKeyDown={handleSearchKeyDown}
searchMode={searchMode}
placeholder={t('gotoAnything.searchPlaceholder', { ns: 'app' })}
/>
<Command.List className="h-[240px] overflow-y-auto">
{isLoading && (
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div className="flex items-center gap-2">
<div className="h-4 w-4 animate-spin rounded-full border-2 border-gray-300 border-t-gray-600"></div>
<span className="text-sm">{t('gotoAnything.searching', { ns: 'app' })}</span>
</div>
</div>
{emptyStateVariant === 'loading' && (
<EmptyState variant="loading" />
)}
{isError && (
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div>
<div className="text-sm font-medium text-red-500">{t('gotoAnything.searchFailed', { ns: 'app' })}</div>
<div className="mt-1 text-xs text-text-quaternary">
{error.message}
</div>
</div>
</div>
{emptyStateVariant === 'error' && (
<EmptyState variant="error" error={error} />
)}
{!isLoading && !isError && (
<>
{isCommandsMode
@@ -393,118 +188,46 @@ const GotoAnything: FC<Props> = ({
/>
)
: (
Object.entries(groupedResults).map(([type, results], groupIndex) => (
<Command.Group
key={groupIndex}
heading={(() => {
const typeMap = {
'app': 'gotoAnything.groups.apps',
'plugin': 'gotoAnything.groups.plugins',
'knowledge': 'gotoAnything.groups.knowledgeBases',
'workflow-node': 'gotoAnything.groups.workflowNodes',
'command': 'gotoAnything.groups.commands',
} as const
return t(typeMap[type as keyof typeof typeMap] || `${type}s`, { ns: 'app' })
})()}
className="p-2 capitalize text-text-secondary"
>
{results.map(result => (
<Command.Item
key={`${result.type}-${result.id}`}
value={`${result.type}-${result.id}`}
className="flex cursor-pointer items-center gap-3 rounded-md p-3 will-change-[background-color] hover:bg-state-base-hover aria-[selected=true]:bg-state-base-hover-alt data-[selected=true]:bg-state-base-hover-alt"
onSelect={() => handleNavigate(result)}
>
{result.icon}
<div className="min-w-0 flex-1">
<div className="truncate font-medium text-text-secondary">
{result.title}
</div>
{result.description && (
<div className="mt-0.5 truncate text-xs text-text-quaternary">
{result.description}
</div>
)}
</div>
<div className="text-xs capitalize text-text-quaternary">
{result.type}
</div>
</Command.Item>
))}
</Command.Group>
))
<ResultList
groupedResults={groupedResults}
onSelect={handleNavigate}
/>
)}
{!isCommandsMode && emptyResult}
{!isCommandsMode && defaultUI}
{!isCommandsMode && emptyStateVariant === 'no-results' && (
<EmptyState
variant="no-results"
searchMode={searchMode}
Actions={Actions}
/>
)}
{!isCommandsMode && emptyStateVariant === 'default' && (
<EmptyState variant="default" />
)}
</>
)}
</Command.List>
{/* Always show footer to prevent height jumping */}
<div className="border-t border-divider-subtle bg-components-panel-bg-blur px-4 py-2 text-xs text-text-tertiary">
<div className="flex min-h-[16px] items-center justify-between">
{(!!dedupedResults.length || isError)
? (
<>
<span>
{isError
? (
<span className="text-red-500">{t('gotoAnything.someServicesUnavailable', { ns: 'app' })}</span>
)
: (
<>
{t('gotoAnything.resultCount', { ns: 'app', count: dedupedResults.length })}
{searchMode !== 'general' && (
<span className="ml-2 opacity-60">
{t('gotoAnything.inScope', { ns: 'app', scope: searchMode.replace('@', '') })}
</span>
)}
</>
)}
</span>
<span className="opacity-60">
{searchMode !== 'general'
? t('gotoAnything.clearToSearchAll', { ns: 'app' })
: t('gotoAnything.useAtForSpecific', { ns: 'app' })}
</span>
</>
)
: (
<>
<span className="opacity-60">
{(() => {
if (isCommandsMode)
return t('gotoAnything.selectToNavigate', { ns: 'app' })
if (searchQuery.trim())
return t('gotoAnything.searching', { ns: 'app' })
return t('gotoAnything.startTyping', { ns: 'app' })
})()}
</span>
<span className="opacity-60">
{searchQuery.trim() || isCommandsMode
? t('gotoAnything.tips', { ns: 'app' })
: t('gotoAnything.pressEscToClose', { ns: 'app' })}
</span>
</>
)}
</div>
</div>
<Footer
resultCount={dedupedResults.length}
searchMode={searchMode}
isError={isError}
isCommandsMode={isCommandsMode}
hasQuery={!!searchQuery.trim()}
/>
</Command>
</div>
</Modal>
{
activePlugin && (
<InstallFromMarketplace
manifest={activePlugin}
uniqueIdentifier={activePlugin.latest_package_identifier}
onClose={() => setActivePlugin(undefined)}
onSuccess={() => setActivePlugin(undefined)}
/>
)
}
{activePlugin && (
<InstallFromMarketplace
manifest={activePlugin}
uniqueIdentifier={activePlugin.latest_package_identifier}
onClose={() => setActivePlugin(undefined)}
onSuccess={() => setActivePlugin(undefined)}
/>
)}
</>
)
}

View File

@@ -1,27 +1,19 @@
'use client'
import type { FileUpload } from '@/app/components/base/features/types'
import type { App } from '@/types/app'
import * as React from 'react'
import { useMemo, useRef } from 'react'
import { useRef } from 'react'
import { useTranslation } from 'react-i18next'
import Loading from '@/app/components/base/loading'
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
import AppInputsForm from '@/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-form'
import { BlockEnum, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types'
import { useAppDetail } from '@/service/use-apps'
import { useFileUploadConfig } from '@/service/use-common'
import { useAppWorkflow } from '@/service/use-workflow'
import { AppModeEnum, Resolution } from '@/types/app'
import { useAppInputsFormSchema } from '@/app/components/plugins/plugin-detail-panel/app-selector/hooks/use-app-inputs-form-schema'
import { cn } from '@/utils/classnames'
type Props = {
value?: {
app_id: string
inputs: Record<string, any>
inputs: Record<string, unknown>
}
appDetail: App
onFormChange: (value: Record<string, any>) => void
onFormChange: (value: Record<string, unknown>) => void
}
const AppInputsPanel = ({
@@ -30,155 +22,33 @@ const AppInputsPanel = ({
onFormChange,
}: Props) => {
const { t } = useTranslation()
const inputsRef = useRef<any>(value?.inputs || {})
const isBasicApp = appDetail.mode !== AppModeEnum.ADVANCED_CHAT && appDetail.mode !== AppModeEnum.WORKFLOW
const { data: fileUploadConfig } = useFileUploadConfig()
const { data: currentApp, isFetching: isAppLoading } = useAppDetail(appDetail.id)
const { data: currentWorkflow, isFetching: isWorkflowLoading } = useAppWorkflow(isBasicApp ? '' : appDetail.id)
const isLoading = isAppLoading || isWorkflowLoading
const inputsRef = useRef<Record<string, unknown>>(value?.inputs || {})
const basicAppFileConfig = useMemo(() => {
let fileConfig: FileUpload
if (isBasicApp)
fileConfig = currentApp?.model_config?.file_upload as FileUpload
else
fileConfig = currentWorkflow?.features?.file_upload as FileUpload
return {
image: {
detail: fileConfig?.image?.detail || Resolution.high,
enabled: !!fileConfig?.image?.enabled,
number_limits: fileConfig?.image?.number_limits || 3,
transfer_methods: fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
},
enabled: !!(fileConfig?.enabled || fileConfig?.image?.enabled),
allowed_file_types: fileConfig?.allowed_file_types || [SupportUploadFileTypes.image],
allowed_file_extensions: fileConfig?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image]].map(ext => `.${ext}`),
allowed_file_upload_methods: fileConfig?.allowed_file_upload_methods || fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
number_limits: fileConfig?.number_limits || fileConfig?.image?.number_limits || 3,
}
}, [currentApp?.model_config?.file_upload, currentWorkflow?.features?.file_upload, isBasicApp])
const { inputFormSchema, isLoading } = useAppInputsFormSchema({ appDetail })
const inputFormSchema = useMemo(() => {
if (!currentApp)
return []
let inputFormSchema = []
if (isBasicApp) {
inputFormSchema = currentApp.model_config?.user_input_form?.filter((item: any) => !item.external_data_tool).map((item: any) => {
if (item.paragraph) {
return {
...item.paragraph,
type: 'paragraph',
required: false,
}
}
if (item.number) {
return {
...item.number,
type: 'number',
required: false,
}
}
if (item.checkbox) {
return {
...item.checkbox,
type: 'checkbox',
required: false,
}
}
if (item.select) {
return {
...item.select,
type: 'select',
required: false,
}
}
if (item['file-list']) {
return {
...item['file-list'],
type: 'file-list',
required: false,
fileUploadConfig,
}
}
if (item.file) {
return {
...item.file,
type: 'file',
required: false,
fileUploadConfig,
}
}
if (item.json_object) {
return {
...item.json_object,
type: 'json_object',
}
}
return {
...item['text-input'],
type: 'text-input',
required: false,
}
}) || []
}
else {
const startNode = currentWorkflow?.graph?.nodes.find(node => node.data.type === BlockEnum.Start) as any
inputFormSchema = startNode?.data.variables.map((variable: any) => {
if (variable.type === InputVarType.multiFiles) {
return {
...variable,
required: false,
fileUploadConfig,
}
}
if (variable.type === InputVarType.singleFile) {
return {
...variable,
required: false,
fileUploadConfig,
}
}
return {
...variable,
required: false,
}
}) || []
}
if ((currentApp.mode === AppModeEnum.COMPLETION || currentApp.mode === AppModeEnum.WORKFLOW) && basicAppFileConfig.enabled) {
inputFormSchema.push({
label: 'Image Upload',
variable: '#image#',
type: InputVarType.singleFile,
required: false,
...basicAppFileConfig,
fileUploadConfig,
})
}
return inputFormSchema || []
}, [basicAppFileConfig, currentApp, currentWorkflow, fileUploadConfig, isBasicApp])
const handleFormChange = (value: Record<string, any>) => {
inputsRef.current = value
onFormChange(value)
const handleFormChange = (newValue: Record<string, unknown>) => {
inputsRef.current = newValue
onFormChange(newValue)
}
const hasInputs = inputFormSchema.length > 0
return (
<div className={cn('flex max-h-[240px] flex-col rounded-b-2xl border-t border-divider-subtle pb-4')}>
{isLoading && <div className="pt-3"><Loading type="app" /></div>}
{!isLoading && (
<div className="system-sm-semibold mb-2 mt-3 flex h-6 shrink-0 items-center px-4 text-text-secondary">{t('appSelector.params', { ns: 'app' })}</div>
)}
{!isLoading && !inputFormSchema.length && (
<div className="flex h-16 flex-col items-center justify-center">
<div className="system-sm-regular text-text-tertiary">{t('appSelector.noParams', { ns: 'app' })}</div>
<div className="system-sm-semibold mb-2 mt-3 flex h-6 shrink-0 items-center px-4 text-text-secondary">
{t('appSelector.params', { ns: 'app' })}
</div>
)}
{!isLoading && !!inputFormSchema.length && (
{!isLoading && !hasInputs && (
<div className="flex h-16 flex-col items-center justify-center">
<div className="system-sm-regular text-text-tertiary">
{t('appSelector.noParams', { ns: 'app' })}
</div>
</div>
)}
{!isLoading && hasInputs && (
<div className="grow overflow-y-auto">
<AppInputsForm
inputs={value?.inputs || {}}

View File

@@ -0,0 +1,211 @@
'use client'
import type { FileUpload } from '@/app/components/base/features/types'
import type { FileUploadConfigResponse } from '@/models/common'
import type { App } from '@/types/app'
import type { FetchWorkflowDraftResponse } from '@/types/workflow'
import { useMemo } from 'react'
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
import { BlockEnum, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types'
import { useAppDetail } from '@/service/use-apps'
import { useFileUploadConfig } from '@/service/use-common'
import { useAppWorkflow } from '@/service/use-workflow'
import { AppModeEnum, Resolution } from '@/types/app'
const BASIC_INPUT_TYPE_MAP: Record<string, string> = {
'paragraph': 'paragraph',
'number': 'number',
'checkbox': 'checkbox',
'select': 'select',
'file-list': 'file-list',
'file': 'file',
'json_object': 'json_object',
}
const FILE_INPUT_TYPES = new Set(['file-list', 'file'])
const WORKFLOW_FILE_VAR_TYPES = new Set([InputVarType.multiFiles, InputVarType.singleFile])
type InputSchemaItem = {
label?: string
variable?: string
type: string
required: boolean
fileUploadConfig?: FileUploadConfigResponse
[key: string]: unknown
}
function isBasicAppMode(mode: string): boolean {
return mode !== AppModeEnum.ADVANCED_CHAT && mode !== AppModeEnum.WORKFLOW
}
function supportsImageUpload(mode: string): boolean {
return mode === AppModeEnum.COMPLETION || mode === AppModeEnum.WORKFLOW
}
function buildFileConfig(fileConfig: FileUpload | undefined) {
return {
image: {
detail: fileConfig?.image?.detail || Resolution.high,
enabled: !!fileConfig?.image?.enabled,
number_limits: fileConfig?.image?.number_limits || 3,
transfer_methods: fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
},
enabled: !!(fileConfig?.enabled || fileConfig?.image?.enabled),
allowed_file_types: fileConfig?.allowed_file_types || [SupportUploadFileTypes.image],
allowed_file_extensions: fileConfig?.allowed_file_extensions
|| [...FILE_EXTS[SupportUploadFileTypes.image]].map(ext => `.${ext}`),
allowed_file_upload_methods: fileConfig?.allowed_file_upload_methods
|| fileConfig?.image?.transfer_methods
|| ['local_file', 'remote_url'],
number_limits: fileConfig?.number_limits || fileConfig?.image?.number_limits || 3,
}
}
function mapBasicAppInputItem(
item: Record<string, unknown>,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem | null {
for (const [key, type] of Object.entries(BASIC_INPUT_TYPE_MAP)) {
if (!item[key])
continue
const inputData = item[key] as Record<string, unknown>
const needsFileConfig = FILE_INPUT_TYPES.has(key)
return {
...inputData,
type,
required: false,
...(needsFileConfig && { fileUploadConfig }),
}
}
const textInput = item['text-input'] as Record<string, unknown> | undefined
if (!textInput)
return null
return {
...textInput,
type: 'text-input',
required: false,
}
}
function mapWorkflowVariable(
variable: Record<string, unknown>,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem {
const needsFileConfig = WORKFLOW_FILE_VAR_TYPES.has(variable.type as InputVarType)
return {
...variable,
type: variable.type as string,
required: false,
...(needsFileConfig && { fileUploadConfig }),
}
}
function createImageUploadSchema(
basicFileConfig: ReturnType<typeof buildFileConfig>,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem {
return {
label: 'Image Upload',
variable: '#image#',
type: InputVarType.singleFile,
required: false,
...basicFileConfig,
fileUploadConfig,
}
}
function buildBasicAppSchema(
currentApp: App,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem[] {
const userInputForm = currentApp.model_config?.user_input_form as Array<Record<string, unknown>> | undefined
if (!userInputForm)
return []
return userInputForm
.filter((item: Record<string, unknown>) => !item.external_data_tool)
.map((item: Record<string, unknown>) => mapBasicAppInputItem(item, fileUploadConfig))
.filter((item): item is InputSchemaItem => item !== null)
}
function buildWorkflowSchema(
workflow: FetchWorkflowDraftResponse,
fileUploadConfig?: FileUploadConfigResponse,
): InputSchemaItem[] {
const startNode = workflow.graph?.nodes.find(
node => node.data.type === BlockEnum.Start,
) as { data: { variables: Array<Record<string, unknown>> } } | undefined
if (!startNode?.data.variables)
return []
return startNode.data.variables.map(
variable => mapWorkflowVariable(variable, fileUploadConfig),
)
}
type UseAppInputsFormSchemaParams = {
appDetail: App
}
type UseAppInputsFormSchemaResult = {
inputFormSchema: InputSchemaItem[]
isLoading: boolean
fileUploadConfig?: FileUploadConfigResponse
}
export function useAppInputsFormSchema({
appDetail,
}: UseAppInputsFormSchemaParams): UseAppInputsFormSchemaResult {
const isBasicApp = isBasicAppMode(appDetail.mode)
const { data: fileUploadConfig } = useFileUploadConfig()
const { data: currentApp, isFetching: isAppLoading } = useAppDetail(appDetail.id)
const { data: currentWorkflow, isFetching: isWorkflowLoading } = useAppWorkflow(
isBasicApp ? '' : appDetail.id,
)
const isLoading = isAppLoading || isWorkflowLoading
const inputFormSchema = useMemo(() => {
if (!currentApp)
return []
if (!isBasicApp && !currentWorkflow)
return []
// Build base schema based on app type
// Note: currentWorkflow is guaranteed to be defined here due to the early return above
const baseSchema = isBasicApp
? buildBasicAppSchema(currentApp, fileUploadConfig)
: buildWorkflowSchema(currentWorkflow!, fileUploadConfig)
if (!supportsImageUpload(currentApp.mode))
return baseSchema
const rawFileConfig = isBasicApp
? currentApp.model_config?.file_upload as FileUpload
: currentWorkflow?.features?.file_upload as FileUpload
const basicFileConfig = buildFileConfig(rawFileConfig)
if (!basicFileConfig.enabled)
return baseSchema
return [
...baseSchema,
createImageUploadSchema(basicFileConfig, fileUploadConfig),
]
}, [currentApp, currentWorkflow, fileUploadConfig, isBasicApp])
return {
inputFormSchema,
isLoading,
fileUploadConfig,
}
}

View File

@@ -6,7 +6,6 @@ import Toast from '@/app/components/base/toast'
import { PluginSource } from '../types'
import DetailHeader from './detail-header'
// Use vi.hoisted for mock functions used in vi.mock factories
const {
mockSetShowUpdatePluginModal,
mockRefreshModelProviders,

View File

@@ -1,416 +1,2 @@
import type { PluginDetail } from '../types'
import {
RiArrowLeftRightLine,
RiBugLine,
RiCloseLine,
RiHardDrive3Line,
} from '@remixicon/react'
import { useBoolean } from 'ahooks'
import * as React from 'react'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import ActionButton from '@/app/components/base/action-button'
import { trackEvent } from '@/app/components/base/amplitude'
import Badge from '@/app/components/base/badge'
import Button from '@/app/components/base/button'
import Confirm from '@/app/components/base/confirm'
import { Github } from '@/app/components/base/icons/src/public/common'
import { BoxSparkleFill } from '@/app/components/base/icons/src/vender/plugin'
import Toast from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import { AuthCategory, PluginAuth } from '@/app/components/plugins/plugin-auth'
import OperationDropdown from '@/app/components/plugins/plugin-detail-panel/operation-dropdown'
import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info'
import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place'
import PluginVersionPicker from '@/app/components/plugins/update-plugin/plugin-version-picker'
import { API_PREFIX } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useGetLanguage, useLocale } from '@/context/i18n'
import { useModalContext } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'
import useTheme from '@/hooks/use-theme'
import { uninstallPlugin } from '@/service/plugins'
import { useAllToolProviders, useInvalidateAllToolProviders } from '@/service/use-tools'
import { cn } from '@/utils/classnames'
import { getMarketplaceUrl } from '@/utils/var'
import { AutoUpdateLine } from '../../base/icons/src/vender/system'
import Verified from '../base/badges/verified'
import DeprecationNotice from '../base/deprecation-notice'
import Icon from '../card/base/card-icon'
import Description from '../card/base/description'
import OrgInfo from '../card/base/org-info'
import Title from '../card/base/title'
import { useGitHubReleases } from '../install-plugin/hooks'
import useReferenceSetting from '../plugin-page/use-reference-setting'
import { AUTO_UPDATE_MODE } from '../reference-setting-modal/auto-update-setting/types'
import { convertUTCDaySecondsToLocalSeconds, timeOfDayToDayjs } from '../reference-setting-modal/auto-update-setting/utils'
import { PluginCategoryEnum, PluginSource } from '../types'
const i18nPrefix = 'action'
type Props = {
detail: PluginDetail
isReadmeView?: boolean
onHide?: () => void
onUpdate?: (isDelete?: boolean) => void
}
const DetailHeader = ({
detail,
isReadmeView = false,
onHide,
onUpdate,
}: Props) => {
const { t } = useTranslation()
const { userProfile: { timezone } } = useAppContext()
const { theme } = useTheme()
const locale = useGetLanguage()
const currentLocale = useLocale()
const { checkForUpdates, fetchReleases } = useGitHubReleases()
const { setShowUpdatePluginModal } = useModalContext()
const { refreshModelProviders } = useProviderContext()
const invalidateAllToolProviders = useInvalidateAllToolProviders()
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
const {
id,
source,
tenant_id,
version,
latest_unique_identifier,
latest_version,
meta,
plugin_id,
status,
deprecated_reason,
alternative_plugin_id,
} = detail
const { author, category, name, label, description, icon, icon_dark, verified, tool } = detail.declaration || detail
const isTool = category === PluginCategoryEnum.tool
const providerBriefInfo = tool?.identity
const providerKey = `${plugin_id}/${providerBriefInfo?.name}`
const { data: collectionList = [] } = useAllToolProviders(isTool)
const provider = useMemo(() => {
return collectionList.find(collection => collection.name === providerKey)
}, [collectionList, providerKey])
const isFromGitHub = source === PluginSource.github
const isFromMarketplace = source === PluginSource.marketplace
const [isShow, setIsShow] = useState(false)
const [targetVersion, setTargetVersion] = useState({
version: latest_version,
unique_identifier: latest_unique_identifier,
})
const hasNewVersion = useMemo(() => {
if (isFromMarketplace)
return !!latest_version && latest_version !== version
return false
}, [isFromMarketplace, latest_version, version])
const iconFileName = theme === 'dark' && icon_dark ? icon_dark : icon
const iconSrc = iconFileName
? (iconFileName.startsWith('http') ? iconFileName : `${API_PREFIX}/workspaces/current/plugin/icon?tenant_id=${tenant_id}&filename=${iconFileName}`)
: ''
const detailUrl = useMemo(() => {
if (isFromGitHub)
return `https://github.com/${meta!.repo}`
if (isFromMarketplace)
return getMarketplaceUrl(`/plugins/${author}/${name}`, { language: currentLocale, theme })
return ''
}, [author, isFromGitHub, isFromMarketplace, meta, name, theme])
const [isShowUpdateModal, {
setTrue: showUpdateModal,
setFalse: hideUpdateModal,
}] = useBoolean(false)
const { referenceSetting } = useReferenceSetting()
const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {}
const isAutoUpgradeEnabled = useMemo(() => {
if (!enable_marketplace)
return false
if (!autoUpgradeInfo || !isFromMarketplace)
return false
if (autoUpgradeInfo.strategy_setting === 'disabled')
return false
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.update_all)
return true
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.partial && autoUpgradeInfo.include_plugins.includes(plugin_id))
return true
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.exclude && !autoUpgradeInfo.exclude_plugins.includes(plugin_id))
return true
return false
}, [autoUpgradeInfo, plugin_id, isFromMarketplace])
const [isDowngrade, setIsDowngrade] = useState(false)
const handleUpdate = async (isDowngrade?: boolean) => {
if (isFromMarketplace) {
setIsDowngrade(!!isDowngrade)
showUpdateModal()
return
}
const owner = meta!.repo.split('/')[0] || author
const repo = meta!.repo.split('/')[1] || name
const fetchedReleases = await fetchReleases(owner, repo)
if (fetchedReleases.length === 0)
return
const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta!.version)
Toast.notify(toastProps)
if (needUpdate) {
setShowUpdatePluginModal({
onSaveCallback: () => {
onUpdate?.()
},
payload: {
type: PluginSource.github,
category: detail.declaration.category,
github: {
originalPackageInfo: {
id: detail.plugin_unique_identifier,
repo: meta!.repo,
version: meta!.version,
package: meta!.package,
releases: fetchedReleases,
},
},
},
})
}
}
const handleUpdatedFromMarketplace = () => {
onUpdate?.()
hideUpdateModal()
}
const [isShowPluginInfo, {
setTrue: showPluginInfo,
setFalse: hidePluginInfo,
}] = useBoolean(false)
const [isShowDeleteConfirm, {
setTrue: showDeleteConfirm,
setFalse: hideDeleteConfirm,
}] = useBoolean(false)
const [deleting, {
setTrue: showDeleting,
setFalse: hideDeleting,
}] = useBoolean(false)
const handleDelete = useCallback(async () => {
showDeleting()
const res = await uninstallPlugin(id)
hideDeleting()
if (res.success) {
hideDeleteConfirm()
onUpdate?.(true)
if (PluginCategoryEnum.model.includes(category))
refreshModelProviders()
if (PluginCategoryEnum.tool.includes(category))
invalidateAllToolProviders()
trackEvent('plugin_uninstalled', { plugin_id, plugin_name: name })
}
}, [showDeleting, id, hideDeleting, hideDeleteConfirm, onUpdate, category, refreshModelProviders, invalidateAllToolProviders, plugin_id, name])
return (
<div className={cn('shrink-0 border-b border-divider-subtle bg-components-panel-bg p-4 pb-3', isReadmeView && 'border-b-0 bg-transparent p-0')}>
<div className="flex">
<div className={cn('overflow-hidden rounded-xl border border-components-panel-border-subtle', isReadmeView && 'bg-components-panel-bg')}>
<Icon src={iconSrc} />
</div>
<div className="ml-3 w-0 grow">
<div className="flex h-5 items-center">
<Title title={label[locale]} />
{verified && !isReadmeView && <Verified className="ml-0.5 h-4 w-4" text={t('marketplace.verifiedTip', { ns: 'plugin' })} />}
{!!version && (
<PluginVersionPicker
disabled={!isFromMarketplace || isReadmeView}
isShow={isShow}
onShowChange={setIsShow}
pluginID={plugin_id}
currentVersion={version}
onSelect={(state) => {
setTargetVersion(state)
handleUpdate(state.isDowngrade)
}}
trigger={(
<Badge
className={cn(
'mx-1',
isShow && 'bg-state-base-hover',
(isShow || isFromMarketplace) && 'hover:bg-state-base-hover',
)}
uppercase={false}
text={(
<>
<div>{isFromGitHub ? meta!.version : version}</div>
{isFromMarketplace && !isReadmeView && <RiArrowLeftRightLine className="ml-1 h-3 w-3 text-text-tertiary" />}
</>
)}
hasRedCornerMark={hasNewVersion}
/>
)}
/>
)}
{/* Auto update info */}
{isAutoUpgradeEnabled && !isReadmeView && (
<Tooltip popupContent={t('autoUpdate.nextUpdateTime', { ns: 'plugin', time: timeOfDayToDayjs(convertUTCDaySecondsToLocalSeconds(autoUpgradeInfo?.upgrade_time_of_day || 0, timezone!)).format('hh:mm A') })}>
{/* add a a div to fix tooltip hover not show problem */}
<div>
<Badge className="mr-1 cursor-pointer px-1">
<AutoUpdateLine className="size-3" />
</Badge>
</div>
</Tooltip>
)}
{(hasNewVersion || isFromGitHub) && (
<Button
variant="secondary-accent"
size="small"
className="!h-5"
onClick={() => {
if (isFromMarketplace) {
setTargetVersion({
version: latest_version,
unique_identifier: latest_unique_identifier,
})
}
handleUpdate()
}}
>
{t('detailPanel.operation.update', { ns: 'plugin' })}
</Button>
)}
</div>
<div className="mb-1 flex h-4 items-center justify-between">
<div className="mt-0.5 flex items-center">
<OrgInfo
packageNameClassName="w-auto"
orgName={author}
packageName={name?.includes('/') ? (name.split('/').pop() || '') : name}
/>
{!!source && (
<>
<div className="system-xs-regular ml-1 mr-0.5 text-text-quaternary">·</div>
{source === PluginSource.marketplace && (
<Tooltip popupContent={t('detailPanel.categoryTip.marketplace', { ns: 'plugin' })}>
<div><BoxSparkleFill className="h-3.5 w-3.5 text-text-tertiary hover:text-text-accent" /></div>
</Tooltip>
)}
{source === PluginSource.github && (
<Tooltip popupContent={t('detailPanel.categoryTip.github', { ns: 'plugin' })}>
<div><Github className="h-3.5 w-3.5 text-text-secondary hover:text-text-primary" /></div>
</Tooltip>
)}
{source === PluginSource.local && (
<Tooltip popupContent={t('detailPanel.categoryTip.local', { ns: 'plugin' })}>
<div><RiHardDrive3Line className="h-3.5 w-3.5 text-text-tertiary" /></div>
</Tooltip>
)}
{source === PluginSource.debugging && (
<Tooltip popupContent={t('detailPanel.categoryTip.debugging', { ns: 'plugin' })}>
<div><RiBugLine className="h-3.5 w-3.5 text-text-tertiary hover:text-text-warning" /></div>
</Tooltip>
)}
</>
)}
</div>
</div>
</div>
{!isReadmeView && (
<div className="flex gap-1">
<OperationDropdown
source={source}
onInfo={showPluginInfo}
onCheckVersion={handleUpdate}
onRemove={showDeleteConfirm}
detailUrl={detailUrl}
/>
<ActionButton onClick={onHide}>
<RiCloseLine className="h-4 w-4" />
</ActionButton>
</div>
)}
</div>
{isFromMarketplace && (
<DeprecationNotice
status={status}
deprecatedReason={deprecated_reason}
alternativePluginId={alternative_plugin_id}
alternativePluginURL={getMarketplaceUrl(`/plugins/${alternative_plugin_id}`, { language: currentLocale, theme })}
className="mt-3"
/>
)}
{!isReadmeView && <Description className="mb-2 mt-3 h-auto" text={description[locale]} descriptionLineRows={2}></Description>}
{
category === PluginCategoryEnum.tool && !isReadmeView && (
<PluginAuth
pluginPayload={{
provider: provider?.name || '',
category: AuthCategory.tool,
providerType: provider?.type || '',
detail,
}}
/>
)
}
{isShowPluginInfo && (
<PluginInfo
repository={isFromGitHub ? meta?.repo : ''}
release={version}
packageName={meta?.package || ''}
onHide={hidePluginInfo}
/>
)}
{isShowDeleteConfirm && (
<Confirm
isShow
title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
content={(
<div>
{t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })}
<span className="system-md-semibold">{label[locale]}</span>
{t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })}
<br />
</div>
)}
onCancel={hideDeleteConfirm}
onConfirm={handleDelete}
isLoading={deleting}
isDisabled={deleting}
/>
)}
{
isShowUpdateModal && (
<UpdateFromMarketplace
pluginId={plugin_id}
payload={{
category: detail.declaration.category,
originalPackageInfo: {
id: detail.plugin_unique_identifier,
payload: detail.declaration,
},
targetPackageInfo: {
id: targetVersion.unique_identifier,
version: targetVersion.version,
},
}}
onCancel={hideUpdateModal}
onSave={handleUpdatedFromMarketplace}
isShowDowngradeWarningModal={isDowngrade && isAutoUpgradeEnabled}
/>
)
}
</div>
)
}
export default DetailHeader
// Re-export from refactored module for backward compatibility
export { default } from './detail-header/index'

View File

@@ -0,0 +1,539 @@
import type { PluginDetail } from '../../../types'
import type { ModalStates, VersionTarget } from '../hooks'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { PluginSource } from '../../../types'
import HeaderModals from './header-modals'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
vi.mock('@/context/i18n', () => ({
useGetLanguage: () => 'en_US',
}))
vi.mock('@/app/components/base/confirm', () => ({
default: ({ isShow, title, onCancel, onConfirm, isLoading }: {
isShow: boolean
title: string
onCancel: () => void
onConfirm: () => void
isLoading: boolean
}) => isShow
? (
<div data-testid="delete-confirm">
<div data-testid="delete-title">{title}</div>
<button data-testid="confirm-cancel" onClick={onCancel}>Cancel</button>
<button data-testid="confirm-ok" onClick={onConfirm} disabled={isLoading}>Confirm</button>
</div>
)
: null,
}))
vi.mock('@/app/components/plugins/plugin-page/plugin-info', () => ({
default: ({ repository, release, packageName, onHide }: {
repository: string
release: string
packageName: string
onHide: () => void
}) => (
<div data-testid="plugin-info">
<div data-testid="plugin-info-repo">{repository}</div>
<div data-testid="plugin-info-release">{release}</div>
<div data-testid="plugin-info-package">{packageName}</div>
<button data-testid="plugin-info-close" onClick={onHide}>Close</button>
</div>
),
}))
vi.mock('@/app/components/plugins/update-plugin/from-market-place', () => ({
default: ({ pluginId, onSave, onCancel, isShowDowngradeWarningModal }: {
pluginId: string
onSave: () => void
onCancel: () => void
isShowDowngradeWarningModal: boolean
}) => (
<div data-testid="update-modal">
<div data-testid="update-plugin-id">{pluginId}</div>
<div data-testid="update-downgrade-warning">{String(isShowDowngradeWarningModal)}</div>
<button data-testid="update-modal-save" onClick={onSave}>Save</button>
<button data-testid="update-modal-cancel" onClick={onCancel}>Cancel</button>
</div>
),
}))
const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
id: 'test-id',
created_at: '2024-01-01',
updated_at: '2024-01-02',
name: 'Test Plugin',
plugin_id: 'test-plugin',
plugin_unique_identifier: 'test-uid',
declaration: {
author: 'test-author',
name: 'test-plugin-name',
category: 'tool',
label: { en_US: 'Test Plugin Label' },
description: { en_US: 'Test description' },
icon: 'icon.png',
verified: true,
} as unknown as PluginDetail['declaration'],
installation_id: 'install-1',
tenant_id: 'tenant-1',
endpoints_setups: 0,
endpoints_active: 0,
version: '1.0.0',
latest_version: '2.0.0',
latest_unique_identifier: 'new-uid',
source: PluginSource.marketplace,
meta: undefined,
status: 'active',
deprecated_reason: '',
alternative_plugin_id: '',
...overrides,
})
const createModalStatesMock = (overrides: Partial<ModalStates> = {}): ModalStates => ({
isShowUpdateModal: false,
showUpdateModal: vi.fn<() => void>(),
hideUpdateModal: vi.fn<() => void>(),
isShowPluginInfo: false,
showPluginInfo: vi.fn<() => void>(),
hidePluginInfo: vi.fn<() => void>(),
isShowDeleteConfirm: false,
showDeleteConfirm: vi.fn<() => void>(),
hideDeleteConfirm: vi.fn<() => void>(),
deleting: false,
showDeleting: vi.fn<() => void>(),
hideDeleting: vi.fn<() => void>(),
...overrides,
})
const createTargetVersion = (overrides: Partial<VersionTarget> = {}): VersionTarget => ({
version: '2.0.0',
unique_identifier: 'new-uid',
...overrides,
})
describe('HeaderModals', () => {
let mockOnUpdatedFromMarketplace: () => void
let mockOnDelete: () => void
beforeEach(() => {
vi.clearAllMocks()
mockOnUpdatedFromMarketplace = vi.fn<() => void>()
mockOnDelete = vi.fn<() => void>()
})
describe('Plugin Info Modal', () => {
it('should not render plugin info modal when isShowPluginInfo is false', () => {
const modalStates = createModalStatesMock({ isShowPluginInfo: false })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.queryByTestId('plugin-info')).not.toBeInTheDocument()
})
it('should render plugin info modal when isShowPluginInfo is true', () => {
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('plugin-info')).toBeInTheDocument()
})
it('should pass GitHub repo to plugin info for GitHub source', () => {
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
const detail = createPluginDetail({
source: PluginSource.github,
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'test-pkg' },
})
render(
<HeaderModals
detail={detail}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('owner/repo')
})
it('should pass empty string for repo for non-GitHub source', () => {
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
render(
<HeaderModals
detail={createPluginDetail({ source: PluginSource.marketplace })}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('')
})
it('should call hidePluginInfo when close button is clicked', () => {
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
fireEvent.click(screen.getByTestId('plugin-info-close'))
expect(modalStates.hidePluginInfo).toHaveBeenCalled()
})
})
describe('Delete Confirm Modal', () => {
it('should not render delete confirm when isShowDeleteConfirm is false', () => {
const modalStates = createModalStatesMock({ isShowDeleteConfirm: false })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.queryByTestId('delete-confirm')).not.toBeInTheDocument()
})
it('should render delete confirm when isShowDeleteConfirm is true', () => {
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
})
it('should show correct delete title', () => {
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('delete-title')).toHaveTextContent('action.delete')
})
it('should call hideDeleteConfirm when cancel is clicked', () => {
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
fireEvent.click(screen.getByTestId('confirm-cancel'))
expect(modalStates.hideDeleteConfirm).toHaveBeenCalled()
})
it('should call onDelete when confirm is clicked', () => {
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
fireEvent.click(screen.getByTestId('confirm-ok'))
expect(mockOnDelete).toHaveBeenCalled()
})
it('should disable confirm button when deleting', () => {
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true, deleting: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('confirm-ok')).toBeDisabled()
})
})
describe('Update Modal', () => {
it('should not render update modal when isShowUpdateModal is false', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: false })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.queryByTestId('update-modal')).not.toBeInTheDocument()
})
it('should render update modal when isShowUpdateModal is true', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
})
it('should pass plugin id to update modal', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail({ plugin_id: 'my-plugin-id' })}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('update-plugin-id')).toHaveTextContent('my-plugin-id')
})
it('should call onUpdatedFromMarketplace when save is clicked', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
fireEvent.click(screen.getByTestId('update-modal-save'))
expect(mockOnUpdatedFromMarketplace).toHaveBeenCalled()
})
it('should call hideUpdateModal when cancel is clicked', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
fireEvent.click(screen.getByTestId('update-modal-cancel'))
expect(modalStates.hideUpdateModal).toHaveBeenCalled()
})
it('should show downgrade warning when isDowngrade and isAutoUpgradeEnabled are true', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={true}
isAutoUpgradeEnabled={true}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('true')
})
it('should not show downgrade warning when only isDowngrade is true', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={true}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('false')
})
it('should not show downgrade warning when only isAutoUpgradeEnabled is true', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={true}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('false')
})
})
describe('Multiple Modals', () => {
it('should render multiple modals when multiple are open', () => {
const modalStates = createModalStatesMock({
isShowPluginInfo: true,
isShowDeleteConfirm: true,
isShowUpdateModal: true,
})
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('plugin-info')).toBeInTheDocument()
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
})
})
describe('Edge Cases', () => {
it('should handle undefined target version values', () => {
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
render(
<HeaderModals
detail={createPluginDetail()}
modalStates={modalStates}
targetVersion={{ version: undefined, unique_identifier: undefined }}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
})
it('should handle empty meta for GitHub source', () => {
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
const detail = createPluginDetail({
source: PluginSource.github,
meta: undefined,
})
render(
<HeaderModals
detail={detail}
modalStates={modalStates}
targetVersion={createTargetVersion()}
isDowngrade={false}
isAutoUpgradeEnabled={false}
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
onDelete={mockOnDelete}
/>,
)
expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('')
expect(screen.getByTestId('plugin-info-package')).toHaveTextContent('')
})
})
})

View File

@@ -0,0 +1,107 @@
'use client'
import type { FC } from 'react'
import type { PluginDetail } from '../../../types'
import type { ModalStates, VersionTarget } from '../hooks'
import { useTranslation } from 'react-i18next'
import Confirm from '@/app/components/base/confirm'
import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info'
import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place'
import { useGetLanguage } from '@/context/i18n'
import { PluginSource } from '../../../types'
const i18nPrefix = 'action'
type HeaderModalsProps = {
detail: PluginDetail
modalStates: ModalStates
targetVersion: VersionTarget
isDowngrade: boolean
isAutoUpgradeEnabled: boolean
onUpdatedFromMarketplace: () => void
onDelete: () => void
}
const HeaderModals: FC<HeaderModalsProps> = ({
detail,
modalStates,
targetVersion,
isDowngrade,
isAutoUpgradeEnabled,
onUpdatedFromMarketplace,
onDelete,
}) => {
const { t } = useTranslation()
const locale = useGetLanguage()
const { source, version, meta } = detail
const { label } = detail.declaration || detail
const isFromGitHub = source === PluginSource.github
const {
isShowUpdateModal,
hideUpdateModal,
isShowPluginInfo,
hidePluginInfo,
isShowDeleteConfirm,
hideDeleteConfirm,
deleting,
} = modalStates
return (
<>
{/* Plugin Info Modal */}
{isShowPluginInfo && (
<PluginInfo
repository={isFromGitHub ? meta?.repo : ''}
release={version}
packageName={meta?.package || ''}
onHide={hidePluginInfo}
/>
)}
{/* Delete Confirm Modal */}
{isShowDeleteConfirm && (
<Confirm
isShow
title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
content={(
<div>
{t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })}
<span className="system-md-semibold">{label[locale]}</span>
{t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })}
<br />
</div>
)}
onCancel={hideDeleteConfirm}
onConfirm={onDelete}
isLoading={deleting}
isDisabled={deleting}
/>
)}
{/* Update from Marketplace Modal */}
{isShowUpdateModal && (
<UpdateFromMarketplace
pluginId={detail.plugin_id}
payload={{
category: detail.declaration?.category ?? '',
originalPackageInfo: {
id: detail.plugin_unique_identifier,
payload: detail.declaration ?? undefined,
},
targetPackageInfo: {
id: targetVersion.unique_identifier || '',
version: targetVersion.version || '',
},
}}
onCancel={hideUpdateModal}
onSave={onUpdatedFromMarketplace}
isShowDowngradeWarningModal={isDowngrade && isAutoUpgradeEnabled}
/>
)}
</>
)
}
export default HeaderModals

View File

@@ -0,0 +1,2 @@
export { default as HeaderModals } from './header-modals'
export { default as PluginSourceBadge } from './plugin-source-badge'

View File

@@ -0,0 +1,200 @@
import { render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { PluginSource } from '../../../types'
import PluginSourceBadge from './plugin-source-badge'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
vi.mock('@/app/components/base/tooltip', () => ({
default: ({ children, popupContent }: { children: React.ReactNode, popupContent: string }) => (
<div data-testid="tooltip" data-content={popupContent}>
{children}
</div>
),
}))
describe('PluginSourceBadge', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('Source Icon Rendering', () => {
it('should render marketplace source badge', () => {
render(<PluginSourceBadge source={PluginSource.marketplace} />)
const tooltip = screen.getByTestId('tooltip')
expect(tooltip).toBeInTheDocument()
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.marketplace')
})
it('should render github source badge', () => {
render(<PluginSourceBadge source={PluginSource.github} />)
const tooltip = screen.getByTestId('tooltip')
expect(tooltip).toBeInTheDocument()
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.github')
})
it('should render local source badge', () => {
render(<PluginSourceBadge source={PluginSource.local} />)
const tooltip = screen.getByTestId('tooltip')
expect(tooltip).toBeInTheDocument()
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.local')
})
it('should render debugging source badge', () => {
render(<PluginSourceBadge source={PluginSource.debugging} />)
const tooltip = screen.getByTestId('tooltip')
expect(tooltip).toBeInTheDocument()
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.debugging')
})
})
describe('Separator Rendering', () => {
it('should render separator dot before marketplace badge', () => {
const { container } = render(<PluginSourceBadge source={PluginSource.marketplace} />)
const separator = container.querySelector('.text-text-quaternary')
expect(separator).toBeInTheDocument()
expect(separator?.textContent).toBe('·')
})
it('should render separator dot before github badge', () => {
const { container } = render(<PluginSourceBadge source={PluginSource.github} />)
const separator = container.querySelector('.text-text-quaternary')
expect(separator).toBeInTheDocument()
expect(separator?.textContent).toBe('·')
})
it('should render separator dot before local badge', () => {
const { container } = render(<PluginSourceBadge source={PluginSource.local} />)
const separator = container.querySelector('.text-text-quaternary')
expect(separator).toBeInTheDocument()
})
it('should render separator dot before debugging badge', () => {
const { container } = render(<PluginSourceBadge source={PluginSource.debugging} />)
const separator = container.querySelector('.text-text-quaternary')
expect(separator).toBeInTheDocument()
})
})
describe('Tooltip Content', () => {
it('should show marketplace tooltip', () => {
render(<PluginSourceBadge source={PluginSource.marketplace} />)
expect(screen.getByTestId('tooltip')).toHaveAttribute(
'data-content',
'detailPanel.categoryTip.marketplace',
)
})
it('should show github tooltip', () => {
render(<PluginSourceBadge source={PluginSource.github} />)
expect(screen.getByTestId('tooltip')).toHaveAttribute(
'data-content',
'detailPanel.categoryTip.github',
)
})
it('should show local tooltip', () => {
render(<PluginSourceBadge source={PluginSource.local} />)
expect(screen.getByTestId('tooltip')).toHaveAttribute(
'data-content',
'detailPanel.categoryTip.local',
)
})
it('should show debugging tooltip', () => {
render(<PluginSourceBadge source={PluginSource.debugging} />)
expect(screen.getByTestId('tooltip')).toHaveAttribute(
'data-content',
'detailPanel.categoryTip.debugging',
)
})
})
describe('Icon Element Structure', () => {
it('should render icon inside tooltip for marketplace', () => {
render(<PluginSourceBadge source={PluginSource.marketplace} />)
const tooltip = screen.getByTestId('tooltip')
const iconWrapper = tooltip.querySelector('div')
expect(iconWrapper).toBeInTheDocument()
})
it('should render icon inside tooltip for github', () => {
render(<PluginSourceBadge source={PluginSource.github} />)
const tooltip = screen.getByTestId('tooltip')
const iconWrapper = tooltip.querySelector('div')
expect(iconWrapper).toBeInTheDocument()
})
it('should render icon inside tooltip for local', () => {
render(<PluginSourceBadge source={PluginSource.local} />)
const tooltip = screen.getByTestId('tooltip')
const iconWrapper = tooltip.querySelector('div')
expect(iconWrapper).toBeInTheDocument()
})
it('should render icon inside tooltip for debugging', () => {
render(<PluginSourceBadge source={PluginSource.debugging} />)
const tooltip = screen.getByTestId('tooltip')
const iconWrapper = tooltip.querySelector('div')
expect(iconWrapper).toBeInTheDocument()
})
})
describe('Lookup Table Coverage', () => {
it('should handle all PluginSource enum values', () => {
const allSources = Object.values(PluginSource)
allSources.forEach((source) => {
const { container } = render(<PluginSourceBadge source={source} />)
// Should render either tooltip or nothing
expect(container).toBeTruthy()
})
})
})
describe('Invalid Source Handling', () => {
it('should return null for unknown source type', () => {
// Use type assertion to test invalid source value
const invalidSource = 'unknown_source' as PluginSource
const { container } = render(<PluginSourceBadge source={invalidSource} />)
// Should render nothing (empty container)
expect(container.firstChild).toBeNull()
})
it('should not render separator for invalid source', () => {
const invalidSource = 'invalid' as PluginSource
const { container } = render(<PluginSourceBadge source={invalidSource} />)
const separator = container.querySelector('.text-text-quaternary')
expect(separator).not.toBeInTheDocument()
})
it('should not render tooltip for invalid source', () => {
const invalidSource = '' as PluginSource
render(<PluginSourceBadge source={invalidSource} />)
expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,59 @@
'use client'
import type { FC, ReactNode } from 'react'
import {
RiBugLine,
RiHardDrive3Line,
} from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { Github } from '@/app/components/base/icons/src/public/common'
import { BoxSparkleFill } from '@/app/components/base/icons/src/vender/plugin'
import Tooltip from '@/app/components/base/tooltip'
import { PluginSource } from '../../../types'
type SourceConfig = {
icon: ReactNode
tipKey: string
}
type PluginSourceBadgeProps = {
source: PluginSource
}
const SOURCE_CONFIG_MAP: Record<PluginSource, SourceConfig | null> = {
[PluginSource.marketplace]: {
icon: <BoxSparkleFill className="h-3.5 w-3.5 text-text-tertiary hover:text-text-accent" />,
tipKey: 'detailPanel.categoryTip.marketplace',
},
[PluginSource.github]: {
icon: <Github className="h-3.5 w-3.5 text-text-secondary hover:text-text-primary" />,
tipKey: 'detailPanel.categoryTip.github',
},
[PluginSource.local]: {
icon: <RiHardDrive3Line className="h-3.5 w-3.5 text-text-tertiary" />,
tipKey: 'detailPanel.categoryTip.local',
},
[PluginSource.debugging]: {
icon: <RiBugLine className="h-3.5 w-3.5 text-text-tertiary hover:text-text-warning" />,
tipKey: 'detailPanel.categoryTip.debugging',
},
}
const PluginSourceBadge: FC<PluginSourceBadgeProps> = ({ source }) => {
const { t } = useTranslation()
const config = SOURCE_CONFIG_MAP[source]
if (!config)
return null
return (
<>
<div className="system-xs-regular ml-1 mr-0.5 text-text-quaternary">·</div>
<Tooltip popupContent={t(config.tipKey as never, { ns: 'plugin' })}>
<div>{config.icon}</div>
</Tooltip>
</>
)
}
export default PluginSourceBadge

Some files were not shown because too many files have changed in this diff Show More