mirror of
https://github.com/langgenius/dify.git
synced 2026-04-10 19:02:02 +08:00
Compare commits
29 Commits
refactor/m
...
fix/consol
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce644e1549 | ||
|
|
468990cc39 | ||
|
|
64e769f96e | ||
|
|
778aabb485 | ||
|
|
d8402f686e | ||
|
|
8bd8dee767 | ||
|
|
05f2764d7c | ||
|
|
f5d6c250ed | ||
|
|
45daec7541 | ||
|
|
c14a8bb437 | ||
|
|
b76c8fa853 | ||
|
|
8c3e77cd0c | ||
|
|
476946f122 | ||
|
|
62a698a883 | ||
|
|
ebca36ffbb | ||
|
|
aa7fe42615 | ||
|
|
b55c0ec4de | ||
|
|
8b50c0d920 | ||
|
|
47f8de3f8e | ||
|
|
491fa9923b | ||
|
|
ce2c41bbf5 | ||
|
|
920db69ef2 | ||
|
|
ac222a4dd4 | ||
|
|
840a975fef | ||
|
|
9fb72c151c | ||
|
|
603a896c49 | ||
|
|
41177757e6 | ||
|
|
4f826b4641 | ||
|
|
3216b67bfa |
3
.github/CODEOWNERS
vendored
3
.github/CODEOWNERS
vendored
@@ -9,6 +9,9 @@
|
||||
# CODEOWNERS file
|
||||
/.github/CODEOWNERS @laipz8200 @crazywoola
|
||||
|
||||
# Agents
|
||||
/.agents/skills/ @hyoban
|
||||
|
||||
# Docs
|
||||
/docs/ @crazywoola
|
||||
|
||||
|
||||
203
api/commands.py
203
api/commands.py
@@ -1450,54 +1450,58 @@ def clear_orphaned_file_records(force: bool):
|
||||
all_ids_in_tables = []
|
||||
for ids_table in ids_tables:
|
||||
query = ""
|
||||
if ids_table["type"] == "uuid":
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white"
|
||||
match ids_table["type"]:
|
||||
case "uuid":
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
|
||||
elif ids_table["type"] == "text":
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}",
|
||||
fg="white",
|
||||
c = ids_table["column"]
|
||||
query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
|
||||
case "text":
|
||||
t = ids_table["table"]
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file-id-like strings in column {ids_table['column']} in table {t}",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
elif ids_table["type"] == "json":
|
||||
click.echo(
|
||||
click.style(
|
||||
(
|
||||
f"- Listing file-id-like JSON string in column {ids_table['column']} "
|
||||
f"in table {ids_table['table']}"
|
||||
),
|
||||
fg="white",
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
case "json":
|
||||
click.echo(
|
||||
click.style(
|
||||
(
|
||||
f"- Listing file-id-like JSON string in column {ids_table['column']} "
|
||||
f"in table {ids_table['table']}"
|
||||
),
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
case _:
|
||||
pass
|
||||
click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white"))
|
||||
|
||||
except Exception as e:
|
||||
@@ -1737,59 +1741,18 @@ def file_usage(
|
||||
if src_filter != src:
|
||||
continue
|
||||
|
||||
if ids_table["type"] == "uuid":
|
||||
# Direct UUID match
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
ref_file_id = str(row[1])
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
|
||||
elif ids_table["type"] in ("text", "json"):
|
||||
# Extract UUIDs from text/json content
|
||||
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {column_cast} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
content = str(row[1])
|
||||
|
||||
# Find all UUIDs in the content
|
||||
import re
|
||||
|
||||
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
|
||||
matches = uuid_pattern.findall(content)
|
||||
|
||||
for ref_file_id in matches:
|
||||
match ids_table["type"]:
|
||||
case "uuid":
|
||||
# Direct UUID match
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
ref_file_id = str(row[1])
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
@@ -1812,6 +1775,50 @@ def file_usage(
|
||||
)
|
||||
total_count += 1
|
||||
|
||||
case "text" | "json":
|
||||
# Extract UUIDs from text/json content
|
||||
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {column_cast} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
content = str(row[1])
|
||||
|
||||
# Find all UUIDs in the content
|
||||
import re
|
||||
|
||||
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
|
||||
matches = uuid_pattern.findall(content)
|
||||
|
||||
for ref_file_id in matches:
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
case _:
|
||||
pass
|
||||
|
||||
# Output results
|
||||
if output_json:
|
||||
result = {
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import abort, make_response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
|
||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@@ -16,9 +17,11 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
annotation_hit_history_fields,
|
||||
build_annotation_model,
|
||||
Annotation,
|
||||
AnnotationExportList,
|
||||
AnnotationHitHistory,
|
||||
AnnotationHitHistoryList,
|
||||
AnnotationList,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
@@ -89,6 +92,14 @@ reg(CreateAnnotationPayload)
|
||||
reg(UpdateAnnotationPayload)
|
||||
reg(AnnotationReplyStatusQuery)
|
||||
reg(AnnotationFilePayload)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
Annotation,
|
||||
AnnotationList,
|
||||
AnnotationExportList,
|
||||
AnnotationHitHistory,
|
||||
AnnotationHitHistoryList,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||
@@ -107,10 +118,11 @@ class AnnotationReplyActionApi(Resource):
|
||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||
app_id = str(app_id)
|
||||
args = AnnotationReplyPayload.model_validate(console_ns.payload)
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
|
||||
elif action == "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
match action:
|
||||
case "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
|
||||
case "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
return result, 200
|
||||
|
||||
|
||||
@@ -201,33 +213,33 @@ class AnnotationApi(Resource):
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||
response = {
|
||||
"data": marshal(annotation_list, annotation_fields),
|
||||
"has_more": len(annotation_list) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return response, 200
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response = AnnotationList(
|
||||
data=annotation_models,
|
||||
has_more=len(annotation_list) == limit,
|
||||
limit=limit,
|
||||
total=total,
|
||||
page=page,
|
||||
)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
@console_ns.doc("create_annotation")
|
||||
@console_ns.doc(description="Create a new annotation for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
|
||||
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@marshal_with(annotation_fields)
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
args = CreateAnnotationPayload.model_validate(console_ns.payload)
|
||||
data = args.model_dump(exclude_none=True)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
|
||||
return annotation
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -264,7 +276,7 @@ class AnnotationExportApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Annotations exported successfully",
|
||||
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
|
||||
console_ns.models[AnnotationExportList.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@@ -274,7 +286,8 @@ class AnnotationExportApi(Resource):
|
||||
def get(self, app_id):
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
response_data = {"data": marshal(annotation_list, annotation_fields)}
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json")
|
||||
|
||||
# Create response with secure headers for CSV export
|
||||
response = make_response(response_data, 200)
|
||||
@@ -289,7 +302,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@console_ns.doc("update_delete_annotation")
|
||||
@console_ns.doc(description="Update or delete an annotation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__])
|
||||
@console_ns.response(204, "Annotation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
|
||||
@@ -298,7 +311,6 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id, annotation_id):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
@@ -306,7 +318,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
args.model_dump(exclude_none=True), app_id, annotation_id
|
||||
)
|
||||
return annotation
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -414,14 +426,7 @@ class AnnotationHitHistoryListApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Hit histories retrieved successfully",
|
||||
console_ns.model(
|
||||
"AnnotationHitHistoryList",
|
||||
{
|
||||
"data": fields.List(
|
||||
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
|
||||
)
|
||||
},
|
||||
),
|
||||
console_ns.models[AnnotationHitHistoryList.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@@ -436,11 +441,14 @@ class AnnotationHitHistoryListApi(Resource):
|
||||
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
|
||||
app_id, annotation_id, page, limit
|
||||
)
|
||||
response = {
|
||||
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
|
||||
"has_more": len(annotation_hit_history_list) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return response
|
||||
history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python(
|
||||
annotation_hit_history_list, from_attributes=True
|
||||
)
|
||||
response = AnnotationHitHistoryList(
|
||||
data=history_models,
|
||||
has_more=len(annotation_hit_history_list) == limit,
|
||||
limit=limit,
|
||||
total=total,
|
||||
page=page,
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
@@ -33,7 +34,6 @@ from services.errors.audio import (
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class TextToSpeechPayload(BaseModel):
|
||||
@@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel):
|
||||
language: str = Field(..., description="Language code")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TextToSpeechVoiceQuery.__name__,
|
||||
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
class AudioTranscriptResponse(BaseModel):
|
||||
text: str = Field(description="Transcribed text from audio")
|
||||
|
||||
|
||||
register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
||||
@@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Audio transcription successful",
|
||||
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
|
||||
console_ns.models[AudioTranscriptResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
|
||||
@console_ns.response(413, "Audio file too large")
|
||||
|
||||
@@ -508,16 +508,19 @@ class ChatConversationApi(Resource):
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at <= end_datetime_utc)
|
||||
|
||||
if args.annotation_status == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
match args.annotation_status:
|
||||
case "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
case "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
case "all":
|
||||
pass
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||
|
||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
@@ -35,7 +36,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ChatMessagesQuery(BaseModel):
|
||||
@@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel):
|
||||
raise ValueError("has_comment must be a boolean value")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
class AnnotationCountResponse(BaseModel):
|
||||
count: int = Field(description="Number of annotations")
|
||||
|
||||
|
||||
reg(ChatMessagesQuery)
|
||||
reg(MessageFeedbackPayload)
|
||||
reg(FeedbackExportQuery)
|
||||
class SuggestedQuestionsResponse(BaseModel):
|
||||
data: list[str] = Field(description="Suggested question")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ChatMessagesQuery,
|
||||
MessageFeedbackPayload,
|
||||
FeedbackExportQuery,
|
||||
AnnotationCountResponse,
|
||||
SuggestedQuestionsResponse,
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
@@ -231,7 +240,7 @@ class ChatMessageListApi(Resource):
|
||||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
@@ -356,7 +365,7 @@ class MessageAnnotationCountApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Annotation count retrieved successfully",
|
||||
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
|
||||
console_ns.models[AnnotationCountResponse.__name__],
|
||||
)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@@ -376,9 +385,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Suggested questions retrieved successfully",
|
||||
console_ns.model(
|
||||
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
|
||||
),
|
||||
console_ns.models[SuggestedQuestionsResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Message or conversation not found")
|
||||
@setup_required
|
||||
@@ -428,7 +435,7 @@ class MessageFeedbackExportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Import the service function
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
@@ -2,9 +2,11 @@ import logging
|
||||
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
@@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required,
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthDataSourceResponse(BaseModel):
|
||||
data: str = Field(description="Authorization URL or 'internal' for internal setup")
|
||||
|
||||
|
||||
class OAuthDataSourceBindingResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
class OAuthDataSourceSyncResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
OAuthDataSourceResponse,
|
||||
OAuthDataSourceBindingResponse,
|
||||
OAuthDataSourceSyncResponse,
|
||||
)
|
||||
|
||||
|
||||
def get_oauth_providers():
|
||||
with current_app.app_context():
|
||||
notion_oauth = NotionOAuth(
|
||||
@@ -34,10 +56,7 @@ class OAuthDataSource(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Authorization URL or internal setup success",
|
||||
console_ns.model(
|
||||
"OAuthDataSourceResponse",
|
||||
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
|
||||
),
|
||||
console_ns.models[OAuthDataSourceResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Data source binding success",
|
||||
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.models[OAuthDataSourceBindingResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid provider or code")
|
||||
def get(self, provider: str):
|
||||
@@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Data source sync success",
|
||||
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.models[OAuthDataSourceSyncResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid provider or sync failed")
|
||||
@setup_required
|
||||
|
||||
@@ -2,10 +2,11 @@ import base64
|
||||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
@@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel):
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
class ForgotPasswordEmailResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
data: str | None = Field(default=None, description="Reset token")
|
||||
code: str | None = Field(default=None, description="Error code if account not found")
|
||||
|
||||
|
||||
class ForgotPasswordCheckResponse(BaseModel):
|
||||
is_valid: bool = Field(description="Whether code is valid")
|
||||
email: EmailStr = Field(description="Email address")
|
||||
token: str = Field(description="New reset token")
|
||||
|
||||
|
||||
class ForgotPasswordResetResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ForgotPasswordSendPayload,
|
||||
ForgotPasswordCheckPayload,
|
||||
ForgotPasswordResetPayload,
|
||||
ForgotPasswordEmailResponse,
|
||||
ForgotPasswordCheckResponse,
|
||||
ForgotPasswordResetResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password")
|
||||
@@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Email sent successfully",
|
||||
console_ns.model(
|
||||
"ForgotPasswordEmailResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
"data": fields.String(description="Reset token"),
|
||||
"code": fields.String(description="Error code if account not found"),
|
||||
},
|
||||
),
|
||||
console_ns.models[ForgotPasswordEmailResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid email or rate limit exceeded")
|
||||
@setup_required
|
||||
@@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Code verified successfully",
|
||||
console_ns.model(
|
||||
"ForgotPasswordCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether code is valid"),
|
||||
"email": fields.String(description="Email address"),
|
||||
"token": fields.String(description="New reset token"),
|
||||
},
|
||||
),
|
||||
console_ns.models[ForgotPasswordCheckResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid code or token")
|
||||
@setup_required
|
||||
@@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Password reset successfully",
|
||||
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.models[ForgotPasswordResetResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid token or password mismatch")
|
||||
@setup_required
|
||||
|
||||
@@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource):
|
||||
grant_type = OAuthGrantType(payload.grant_type)
|
||||
except ValueError:
|
||||
raise BadRequest("invalid grant_type")
|
||||
match grant_type:
|
||||
case OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not payload.code:
|
||||
raise BadRequest("code is required")
|
||||
|
||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not payload.code:
|
||||
raise BadRequest("code is required")
|
||||
if payload.client_secret != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
|
||||
if payload.client_secret != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
case OAuthGrantType.REFRESH_TOKEN:
|
||||
if not payload.refresh_token:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||
if not payload.refresh_token:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider/account")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
@@ -157,9 +157,8 @@ class DataSourceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, binding_id, action):
|
||||
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
||||
binding_id = str(binding_id)
|
||||
action = str(action)
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id)
|
||||
@@ -167,23 +166,24 @@ class DataSourceApi(Resource):
|
||||
if data_source_binding is None:
|
||||
raise NotFound("Data source binding not found.")
|
||||
# enable binding
|
||||
if action == "enable":
|
||||
if data_source_binding.disabled:
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is not disabled.")
|
||||
# disable binding
|
||||
if action == "disable":
|
||||
if not data_source_binding.disabled:
|
||||
data_source_binding.disabled = True
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is disabled.")
|
||||
match action:
|
||||
case "enable":
|
||||
if data_source_binding.disabled:
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is not disabled.")
|
||||
# disable binding
|
||||
case "disable":
|
||||
if not data_source_binding.disabled:
|
||||
data_source_binding.disabled = True
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is disabled.")
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
|
||||
@@ -576,63 +576,62 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
data_source_info = document.data_source_info_dict
|
||||
match document.data_source_type:
|
||||
case "upload_file":
|
||||
if not data_source_info:
|
||||
continue
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if document.data_source_type == "upload_file":
|
||||
if not data_source_info:
|
||||
continue
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
)
|
||||
if file_detail is None:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
if file_detail is None:
|
||||
raise NotFound("File not found.")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
case "notion_import":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
case "website_crawl":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
elif document.data_source_type == "notion_import":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif document.data_source_type == "website_crawl":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
else:
|
||||
raise ValueError("Data source type not support")
|
||||
case _:
|
||||
raise ValueError("Data source type not support")
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.indexing_estimate(
|
||||
@@ -954,23 +953,24 @@ class DocumentProcessingApi(DocumentResource):
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if action == "pause":
|
||||
if document.indexing_status != "indexing":
|
||||
raise InvalidActionError("Document not in indexing state.")
|
||||
match action:
|
||||
case "pause":
|
||||
if document.indexing_status != "indexing":
|
||||
raise InvalidActionError("Document not in indexing state.")
|
||||
|
||||
document.paused_by = current_user.id
|
||||
document.paused_at = naive_utc_now()
|
||||
document.is_paused = True
|
||||
db.session.commit()
|
||||
document.paused_by = current_user.id
|
||||
document.paused_at = naive_utc_now()
|
||||
document.is_paused = True
|
||||
db.session.commit()
|
||||
|
||||
elif action == "resume":
|
||||
if document.indexing_status not in {"paused", "error"}:
|
||||
raise InvalidActionError("Document not in paused or error state.")
|
||||
case "resume":
|
||||
if document.indexing_status not in {"paused", "error"}:
|
||||
raise InvalidActionError("Document not in paused or error state.")
|
||||
|
||||
document.paused_by = None
|
||||
document.paused_at = None
|
||||
document.is_paused = False
|
||||
db.session.commit()
|
||||
document.paused_by = None
|
||||
document.paused_at = None
|
||||
document.is_paused = False
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@@ -1339,6 +1339,18 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
missing_ids = set(document_list) - found_ids
|
||||
raise NotFound(f"Some documents not found: {list(missing_ids)}")
|
||||
|
||||
# Update need_summary to True for documents that don't have it set
|
||||
# This handles the case where documents were created when summary_index_setting was disabled
|
||||
documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"]
|
||||
|
||||
if documents_to_update:
|
||||
document_ids_to_update = [str(doc.id) for doc in documents_to_update]
|
||||
DocumentService.update_documents_need_summary(
|
||||
dataset_id=dataset_id,
|
||||
document_ids=document_ids_to_update,
|
||||
need_summary=True,
|
||||
)
|
||||
|
||||
# Dispatch async tasks for each document
|
||||
for document in documents:
|
||||
# Skip qa_model documents as they don't generate summaries
|
||||
|
||||
@@ -126,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
if action == "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
elif action == "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
match action:
|
||||
case "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
@@ -51,7 +52,7 @@ from fields.app_fields import (
|
||||
tag_fields,
|
||||
)
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from fields.member_fields import build_simple_account_model
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_fields import (
|
||||
conversation_variable_fields,
|
||||
pipeline_variable_fields,
|
||||
@@ -103,7 +104,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
|
||||
app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
|
||||
app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
|
||||
|
||||
simple_account_model = build_simple_account_model(console_ns)
|
||||
simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields)
|
||||
conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
|
||||
pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
|
||||
|
||||
@@ -117,7 +118,56 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel
|
||||
workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
|
||||
|
||||
|
||||
# Pydantic models for request validation
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowRunRequest(BaseModel):
|
||||
inputs: dict
|
||||
files: list | None = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
inputs: dict
|
||||
query: str
|
||||
files: list | None = None
|
||||
conversation_id: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
retriever_from: str = "explore_app"
|
||||
|
||||
|
||||
class TextToSpeechRequest(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
inputs: dict
|
||||
query: str = ""
|
||||
files: list | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
retriever_from: str = "explore_app"
|
||||
|
||||
|
||||
# Register schemas for Swagger documentation
|
||||
console_ns.schema_model(
|
||||
WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
|
||||
def post(self, trial_app):
|
||||
"""
|
||||
Run workflow
|
||||
@@ -129,10 +179,8 @@ class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = WorkflowRunRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
assert current_user is not None
|
||||
try:
|
||||
app_id = app_model.id
|
||||
@@ -183,6 +231,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[ChatRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
@@ -190,14 +239,14 @@ class TrialChatApi(TrialAppResource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = ChatRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
|
||||
# Validate UUID values if provided
|
||||
if args.get("conversation_id"):
|
||||
args["conversation_id"] = uuid_value(args["conversation_id"])
|
||||
if args.get("parent_message_id"):
|
||||
args["parent_message_id"] = uuid_value(args["parent_message_id"])
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@@ -320,20 +369,16 @@ class TrialChatAudioApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = TextToSpeechRequest.model_validate(console_ns.payload)
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
message_id = request_data.message_id
|
||||
text = request_data.text
|
||||
voice = request_data.voice
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@@ -371,19 +416,15 @@ class TrialChatTextApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = CompletionRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@@ -1,58 +1,60 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.fastopenapi import console_router
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
|
||||
class FeatureResponse(BaseModel):
|
||||
features: FeatureModel = Field(description="Feature configuration object")
|
||||
@console_ns.route("/features")
|
||||
class FeatureApi(Resource):
|
||||
@console_ns.doc("get_tenant_features")
|
||||
@console_ns.doc(description="Get feature configuration for current tenant")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
def get(self):
|
||||
"""Get feature configuration for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
return FeatureService.get_features(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
class SystemFeatureResponse(BaseModel):
|
||||
features: SystemFeatureModel = Field(description="System feature configuration object")
|
||||
@console_ns.route("/system-features")
|
||||
class SystemFeatureApi(Resource):
|
||||
@console_ns.doc("get_system_features")
|
||||
@console_ns.doc(description="Get system-wide feature configuration")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
"""Get system-wide feature configuration
|
||||
|
||||
NOTE: This endpoint is unauthenticated by design, as it provides system features
|
||||
data required for dashboard initialization.
|
||||
|
||||
@console_router.get(
|
||||
"/features",
|
||||
response_model=FeatureResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
def get_tenant_features() -> FeatureResponse:
|
||||
"""Get feature configuration for current tenant."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
Authentication would create circular dependency (can't login without dashboard loading).
|
||||
|
||||
return FeatureResponse(features=FeatureService.get_features(current_tenant_id))
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/system-features",
|
||||
response_model=SystemFeatureResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
def get_system_features() -> SystemFeatureResponse:
|
||||
"""Get system-wide feature configuration
|
||||
|
||||
NOTE: This endpoint is unauthenticated by design, as it provides system features
|
||||
data required for dashboard initialization.
|
||||
|
||||
Authentication would create circular dependency (can't login without dashboard loading).
|
||||
|
||||
Only non-sensitive configuration data should be returned by this endpoint.
|
||||
"""
|
||||
# NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
|
||||
# without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
|
||||
# in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
|
||||
# raise `Unauthorized` exception if authentication token is not provided.
|
||||
try:
|
||||
is_authenticated = current_user.is_authenticated
|
||||
except Unauthorized:
|
||||
is_authenticated = False
|
||||
return SystemFeatureResponse(features=FeatureService.get_system_features(is_authenticated=is_authenticated))
|
||||
Only non-sensitive configuration data should be returned by this endpoint.
|
||||
"""
|
||||
# NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
|
||||
# without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
|
||||
# in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
|
||||
# raise `Unauthorized` exception if authentication token is not provided.
|
||||
try:
|
||||
is_authenticated = current_user.is_authenticated
|
||||
except Unauthorized:
|
||||
is_authenticated = False
|
||||
return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump()
|
||||
|
||||
@@ -1,14 +1,27 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.fastopenapi import console_router
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.tag_service import TagService
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"type": fields.String,
|
||||
"binding_count": fields.String,
|
||||
}
|
||||
|
||||
|
||||
def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
||||
|
||||
|
||||
class TagBasePayload(BaseModel):
|
||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||
@@ -32,129 +45,115 @@ class TagListQueryParam(BaseModel):
|
||||
keyword: str | None = Field(None, description="Search keyword")
|
||||
|
||||
|
||||
class TagResponse(BaseModel):
|
||||
id: str = Field(description="Tag ID")
|
||||
name: str = Field(description="Tag name")
|
||||
type: str = Field(description="Tag type")
|
||||
binding_count: int = Field(description="Number of bindings")
|
||||
|
||||
|
||||
class TagBindingResult(BaseModel):
|
||||
result: Literal["success"] = Field(description="Operation result", examples=["success"])
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/tags",
|
||||
response_model=list[TagResponse],
|
||||
tags=["console"],
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
TagBasePayload,
|
||||
TagBindingPayload,
|
||||
TagBindingRemovePayload,
|
||||
TagListQueryParam,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def list_tags(query: TagListQueryParam) -> list[TagResponse]:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tags = TagService.get_tags(query.type, current_tenant_id, query.keyword)
|
||||
|
||||
return [
|
||||
TagResponse(
|
||||
id=tag.id,
|
||||
name=tag.name,
|
||||
type=tag.type,
|
||||
binding_count=int(tag.binding_count),
|
||||
)
|
||||
for tag in tags
|
||||
]
|
||||
|
||||
|
||||
@console_router.post(
|
||||
"/tags",
|
||||
response_model=TagResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def create_tag(payload: TagBasePayload) -> TagResponse:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the tag table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
@console_ns.route("/tags")
|
||||
class TagListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.doc(
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
)
|
||||
@marshal_with(dataset_tag_fields)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
raw_args = request.args.to_dict()
|
||||
param = TagListQueryParam.model_validate(raw_args)
|
||||
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
||||
|
||||
tag = TagService.save_tags(payload.model_dump())
|
||||
return tags, 200
|
||||
|
||||
return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=0)
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.save_tags(payload.model_dump())
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
|
||||
return response, 200
|
||||
|
||||
|
||||
@console_router.patch(
|
||||
"/tags/<uuid:tag_id>",
|
||||
response_model=TagResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def update_tag(tag_id: UUID, payload: TagBasePayload) -> TagResponse:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tag_id_str = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
@console_ns.route("/tags/<uuid:tag_id>")
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
tag = TagService.update_tags(payload.model_dump(), tag_id_str)
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(payload.model_dump(), tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id_str)
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=binding_count)
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
return response, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, tag_id):
|
||||
tag_id = str(tag_id)
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
@console_router.delete(
|
||||
"/tags/<uuid:tag_id>",
|
||||
tags=["console"],
|
||||
status_code=204,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete_tag(tag_id: UUID) -> None:
|
||||
tag_id_str = str(tag_id)
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag(tag_id_str)
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(payload.model_dump())
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_router.post(
|
||||
"/tag-bindings/create",
|
||||
response_model=TagBindingResult,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def create_tag_binding(payload: TagBindingPayload) -> TagBindingResult:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
TagService.save_tag_binding(payload.model_dump())
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(payload.model_dump())
|
||||
|
||||
return TagBindingResult(result="success")
|
||||
|
||||
|
||||
@console_router.post(
|
||||
"/tag-bindings/remove",
|
||||
response_model=TagBindingResult,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete_tag_binding(payload: TagBindingRemovePayload) -> TagBindingResult:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag_binding(payload.model_dump())
|
||||
|
||||
return TagBindingResult(result="success")
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
@@ -37,7 +38,7 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@@ -170,6 +171,12 @@ reg(ChangeEmailSendPayload)
|
||||
reg(ChangeEmailValidityPayload)
|
||||
reg(ChangeEmailResetPayload)
|
||||
reg(CheckEmailUniquePayload)
|
||||
register_schema_models(console_ns, AccountResponse)
|
||||
|
||||
|
||||
def _serialize_account(account) -> dict:
|
||||
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
integrate_fields = {
|
||||
"provider": fields.String,
|
||||
@@ -236,11 +243,11 @@ class AccountProfileApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return current_user
|
||||
return _serialize_account(current_user)
|
||||
|
||||
|
||||
@console_ns.route("/account/name")
|
||||
@@ -249,14 +256,14 @@ class AccountNameApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountNamePayload.model_validate(payload)
|
||||
updated_account = AccountService.update_account(current_user, name=args.name)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
@@ -265,7 +272,7 @@ class AccountAvatarApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@@ -273,7 +280,7 @@ class AccountAvatarApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-language")
|
||||
@@ -282,7 +289,7 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@@ -290,7 +297,7 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-theme")
|
||||
@@ -299,7 +306,7 @@ class AccountInterfaceThemeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@@ -307,7 +314,7 @@ class AccountInterfaceThemeApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/timezone")
|
||||
@@ -316,7 +323,7 @@ class AccountTimezoneApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@@ -324,7 +331,7 @@ class AccountTimezoneApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/password")
|
||||
@@ -333,7 +340,7 @@ class AccountPasswordApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@@ -344,7 +351,7 @@ class AccountPasswordApi(Resource):
|
||||
except ServiceCurrentPasswordIncorrectError:
|
||||
raise CurrentPasswordIncorrectError()
|
||||
|
||||
return {"result": "success"}
|
||||
return _serialize_account(current_user)
|
||||
|
||||
|
||||
@console_ns.route("/account/integrates")
|
||||
@@ -620,7 +627,7 @@ class ChangeEmailResetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailResetPayload.model_validate(payload)
|
||||
@@ -649,7 +656,7 @@ class ChangeEmailResetApi(Resource):
|
||||
email=normalized_new_email,
|
||||
)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/check-email-unique")
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery):
|
||||
plugin_id: str
|
||||
|
||||
|
||||
class EndpointCreateResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
class EndpointListResponse(BaseModel):
|
||||
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
|
||||
|
||||
|
||||
class PluginEndpointListResponse(BaseModel):
|
||||
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
|
||||
|
||||
|
||||
class EndpointDeleteResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
class EndpointUpdateResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
class EndpointEnableResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
class EndpointDisableResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(EndpointCreatePayload)
|
||||
reg(EndpointIdPayload)
|
||||
reg(EndpointUpdatePayload)
|
||||
reg(EndpointListQuery)
|
||||
reg(EndpointListForPluginQuery)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
EndpointCreatePayload,
|
||||
EndpointIdPayload,
|
||||
EndpointUpdatePayload,
|
||||
EndpointListQuery,
|
||||
EndpointListForPluginQuery,
|
||||
EndpointCreateResponse,
|
||||
EndpointListResponse,
|
||||
PluginEndpointListResponse,
|
||||
EndpointDeleteResponse,
|
||||
EndpointUpdateResponse,
|
||||
EndpointEnableResponse,
|
||||
EndpointDisableResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
@@ -57,7 +96,7 @@ class EndpointCreateApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint created successfully",
|
||||
console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointCreateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@@ -91,9 +130,7 @@ class EndpointListApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
|
||||
),
|
||||
console_ns.models[EndpointListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
|
||||
),
|
||||
console_ns.models[PluginEndpointListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint deleted successfully",
|
||||
console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointDeleteResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointUpdateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@@ -221,7 +256,7 @@ class EndpointEnableApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint enabled successfully",
|
||||
console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointEnableResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@@ -248,7 +283,7 @@ class EndpointDisableApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint disabled successfully",
|
||||
console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointDisableResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from urllib import parse
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import get_or_create_model, register_enum_models
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
CannotTransferOwnerToSelfError,
|
||||
@@ -25,7 +25,7 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_with_role_fields, account_with_role_list_fields
|
||||
from fields.member_fields import AccountWithRole, AccountWithRoleList
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Account, TenantAccountRole
|
||||
@@ -69,12 +69,7 @@ reg(OwnerTransferEmailPayload)
|
||||
reg(OwnerTransferCheckPayload)
|
||||
reg(OwnerTransferPayload)
|
||||
register_enum_models(console_ns, TenantAccountRole)
|
||||
|
||||
account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields)
|
||||
|
||||
account_with_role_list_fields_copy = account_with_role_list_fields.copy()
|
||||
account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model))
|
||||
account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy)
|
||||
register_schema_models(console_ns, AccountWithRole, AccountWithRoleList)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members")
|
||||
@@ -84,13 +79,15 @@ class MemberListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
return {"result": "success", "accounts": members}, 200
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = AccountWithRoleList(accounts=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/invite-email")
|
||||
@@ -235,13 +232,15 @@ class DatasetOperatorMemberListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||
return {"result": "success", "accounts": members}, 200
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = AccountWithRoleList(accounts=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from flask_restx import Resource
|
||||
from flask_restx.api import HTTPStatus
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import annotation_fields, build_annotation_model
|
||||
from fields.annotation_fields import Annotation, AnnotationList
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
@@ -26,7 +26,9 @@ class AnnotationReplyActionPayload(BaseModel):
|
||||
embedding_model_name: str = Field(description="Embedding model name")
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload)
|
||||
register_schema_models(
|
||||
service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotation-reply/<string:action>")
|
||||
@@ -45,10 +47,11 @@ class AnnotationReplyActionApi(Resource):
|
||||
def post(self, app_model: App, action: Literal["enable", "disable"]):
|
||||
"""Enable or disable annotation reply feature."""
|
||||
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
|
||||
elif action == "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_model.id)
|
||||
match action:
|
||||
case "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
|
||||
case "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_model.id)
|
||||
return result, 200
|
||||
|
||||
|
||||
@@ -82,23 +85,6 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
|
||||
|
||||
# Define annotation list response model
|
||||
annotation_list_fields = {
|
||||
"data": fields.List(fields.Nested(annotation_fields)),
|
||||
"has_more": fields.Boolean,
|
||||
"limit": fields.Integer,
|
||||
"total": fields.Integer,
|
||||
"page": fields.Integer,
|
||||
}
|
||||
|
||||
|
||||
def build_annotation_list_model(api_or_ns: Namespace):
|
||||
"""Build the annotation list model for the API or Namespace."""
|
||||
copied_annotation_list_fields = annotation_list_fields.copy()
|
||||
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
|
||||
return api_or_ns.model("AnnotationList", copied_annotation_list_fields)
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotations")
|
||||
class AnnotationListApi(Resource):
|
||||
@service_api_ns.doc("list_annotations")
|
||||
@@ -109,8 +95,12 @@ class AnnotationListApi(Resource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Annotations retrieved successfully",
|
||||
service_api_ns.models[AnnotationList.__name__],
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_annotation_list_model(service_api_ns))
|
||||
def get(self, app_model: App):
|
||||
"""List annotations for the application."""
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
@@ -118,13 +108,15 @@ class AnnotationListApi(Resource):
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
|
||||
return {
|
||||
"data": annotation_list,
|
||||
"has_more": len(annotation_list) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response = AnnotationList(
|
||||
data=annotation_models,
|
||||
has_more=len(annotation_list) == limit,
|
||||
limit=limit,
|
||||
total=total,
|
||||
page=page,
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_annotation")
|
||||
@@ -135,13 +127,18 @@ class AnnotationListApi(Resource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
HTTPStatus.CREATED,
|
||||
"Annotation created successfully",
|
||||
service_api_ns.models[Annotation.__name__],
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new annotation."""
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
|
||||
return annotation, 201
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json"), HTTPStatus.CREATED
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
|
||||
@@ -158,14 +155,19 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
404: "Annotation not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Annotation updated successfully",
|
||||
service_api_ns.models[Annotation.__name__],
|
||||
)
|
||||
@validate_app_token
|
||||
@edit_permission_required
|
||||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
|
||||
def put(self, app_model: App, annotation_id: str):
|
||||
"""Update an existing annotation."""
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
|
||||
return annotation
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@service_api_ns.doc("delete_annotation")
|
||||
@service_api_ns.doc(description="Delete an annotation")
|
||||
|
||||
@@ -17,7 +17,7 @@ from controllers.service_api.wraps import (
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from fields.tag_fields import DataSetTag
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
@@ -114,6 +114,7 @@ register_schema_models(
|
||||
TagBindingPayload,
|
||||
TagUnbindingPayload,
|
||||
DatasetListQuery,
|
||||
DataSetTag,
|
||||
)
|
||||
|
||||
|
||||
@@ -480,15 +481,14 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def get(self, _):
|
||||
"""Get all knowledge type tags."""
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
tags = TagService.get_tags("knowledge", cid)
|
||||
|
||||
return tags, 200
|
||||
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
|
||||
return [tag.model_dump(mode="json") for tag in tag_models], 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset_tag")
|
||||
@@ -500,7 +500,6 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def post(self, _):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
@@ -510,7 +509,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
).model_dump(mode="json")
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
|
||||
@@ -523,7 +524,6 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def patch(self, _):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@@ -536,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
).model_dump(mode="json")
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
|
||||
|
||||
@@ -168,10 +168,11 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
if action == "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
elif action == "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
match action:
|
||||
case "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
|
||||
@@ -73,14 +73,14 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
|
||||
|
||||
# If caller needs end-user context, attach EndUser to current_user
|
||||
if fetch_user_arg:
|
||||
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
|
||||
user_id = request.args.get("user")
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
|
||||
user_id = request.get_json().get("user")
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
|
||||
user_id = request.form.get("user")
|
||||
else:
|
||||
user_id = None
|
||||
user_id = None
|
||||
match fetch_user_arg.fetch_from:
|
||||
case WhereisUserArg.QUERY:
|
||||
user_id = request.args.get("user")
|
||||
case WhereisUserArg.JSON:
|
||||
user_id = request.get_json().get("user")
|
||||
case WhereisUserArg.FORM:
|
||||
user_id = request.form.get("user")
|
||||
|
||||
if not user_id and fetch_user_arg.required:
|
||||
raise ValueError("Arg user must be provided.")
|
||||
|
||||
@@ -14,16 +14,17 @@ class AgentConfigManager:
|
||||
agent_dict = config.get("agent_mode", {})
|
||||
agent_strategy = agent_dict.get("strategy", "cot")
|
||||
|
||||
if agent_strategy == "function_call":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
elif agent_strategy in {"cot", "react"}:
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
else:
|
||||
# old configs, try to detect default strategy
|
||||
if config["model"]["provider"] == "openai":
|
||||
match agent_strategy:
|
||||
case "function_call":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
else:
|
||||
case "cot" | "react":
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
case _:
|
||||
# old configs, try to detect default strategy
|
||||
if config["model"]["provider"] == "openai":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
else:
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
|
||||
agent_tools = []
|
||||
for tool in agent_dict.get("tools", []):
|
||||
|
||||
@@ -250,7 +250,7 @@ class WorkflowResponseConverter:
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=run_id,
|
||||
workflow_id=workflow_id,
|
||||
status=status.value,
|
||||
status=status,
|
||||
outputs=encoded_outputs,
|
||||
error=error,
|
||||
elapsed_time=elapsed_time,
|
||||
@@ -340,13 +340,13 @@ class WorkflowResponseConverter:
|
||||
metadata = self._merge_metadata(event.execution_metadata, snapshot)
|
||||
|
||||
if isinstance(event, QueueNodeSucceededEvent):
|
||||
status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
error_message = event.error
|
||||
elif isinstance(event, QueueNodeFailedEvent):
|
||||
status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
status = WorkflowNodeExecutionStatus.FAILED
|
||||
error_message = event.error
|
||||
else:
|
||||
status = WorkflowNodeExecutionStatus.EXCEPTION.value
|
||||
status = WorkflowNodeExecutionStatus.EXCEPTION
|
||||
error_message = event.error
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
@@ -413,7 +413,7 @@ class WorkflowResponseConverter:
|
||||
process_data_truncated=process_data_truncated,
|
||||
outputs=outputs,
|
||||
outputs_truncated=outputs_truncated,
|
||||
status=WorkflowNodeExecutionStatus.RETRY.value,
|
||||
status=WorkflowNodeExecutionStatus.RETRY,
|
||||
error=event.error,
|
||||
elapsed_time=elapsed_time,
|
||||
execution_metadata=metadata,
|
||||
|
||||
@@ -120,7 +120,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
raise ValueError("Pipeline dataset is required")
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
start_node_id: str = args["start_node_id"]
|
||||
datasource_type: str = args["datasource_type"]
|
||||
datasource_type = DatasourceProviderType(args["datasource_type"])
|
||||
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
|
||||
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
|
||||
)
|
||||
@@ -660,7 +660,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
built_in_field_enabled: bool,
|
||||
datasource_type: str,
|
||||
datasource_type: DatasourceProviderType,
|
||||
datasource_info: Mapping[str, Any],
|
||||
created_from: str,
|
||||
position: int,
|
||||
@@ -668,17 +668,17 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
batch: str,
|
||||
document_form: str,
|
||||
):
|
||||
if datasource_type == "local_file":
|
||||
name = datasource_info.get("name", "untitled")
|
||||
elif datasource_type == "online_document":
|
||||
name = datasource_info.get("page", {}).get("page_name", "untitled")
|
||||
elif datasource_type == "website_crawl":
|
||||
name = datasource_info.get("title", "untitled")
|
||||
elif datasource_type == "online_drive":
|
||||
name = datasource_info.get("name", "untitled")
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.LOCAL_FILE:
|
||||
name = datasource_info.get("name", "untitled")
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
name = datasource_info.get("page", {}).get("page_name", "untitled")
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
name = datasource_info.get("title", "untitled")
|
||||
case DatasourceProviderType.ONLINE_DRIVE:
|
||||
name = datasource_info.get("name", "untitled")
|
||||
case _:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
document = Document(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
@@ -706,7 +706,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
|
||||
def _format_datasource_info_list(
|
||||
self,
|
||||
datasource_type: str,
|
||||
datasource_type: DatasourceProviderType,
|
||||
datasource_info_list: list[Mapping[str, Any]],
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
@@ -716,7 +716,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
"""
|
||||
Format datasource info list.
|
||||
"""
|
||||
if datasource_type == "online_drive":
|
||||
if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
|
||||
all_files: list[Mapping[str, Any]] = []
|
||||
datasource_node_data = None
|
||||
datasource_nodes = workflow.graph_dict.get("nodes", [])
|
||||
|
||||
@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
@@ -223,7 +223,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
status: WorkflowExecutionStatus
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
@@ -311,7 +311,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
process_data_truncated: bool = False
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
outputs_truncated: bool = True
|
||||
status: str
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
@@ -375,7 +375,7 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
process_data_truncated: bool = False
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
outputs_truncated: bool = False
|
||||
status: str
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
@@ -719,7 +719,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
status: WorkflowExecutionStatus
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
|
||||
@@ -369,77 +369,78 @@ class IndexingRunner:
|
||||
# Generate summary preview
|
||||
summary_index_setting = tmp_processing_rule.get("summary_index_setting")
|
||||
if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
|
||||
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
|
||||
preview_texts = index_processor.generate_summary_preview(
|
||||
tenant_id, preview_texts, summary_index_setting, doc_language
|
||||
)
|
||||
|
||||
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
|
||||
|
||||
def _extract(
|
||||
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
|
||||
) -> list[Document]:
|
||||
# load file
|
||||
if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
|
||||
return []
|
||||
|
||||
data_source_info = dataset_document.data_source_info_dict
|
||||
text_docs = []
|
||||
if dataset_document.data_source_type == "upload_file":
|
||||
if not data_source_info or "upload_file_id" not in data_source_info:
|
||||
raise ValueError("no upload file found")
|
||||
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
|
||||
file_detail = db.session.scalars(stmt).one_or_none()
|
||||
match dataset_document.data_source_type:
|
||||
case "upload_file":
|
||||
if not data_source_info or "upload_file_id" not in data_source_info:
|
||||
raise ValueError("no upload file found")
|
||||
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
|
||||
file_detail = db.session.scalars(stmt).one_or_none()
|
||||
|
||||
if file_detail:
|
||||
if file_detail:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=file_detail,
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case "notion_import":
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
or "notion_page_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no notion import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=file_detail,
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"document": dataset_document,
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
elif dataset_document.data_source_type == "notion_import":
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
or "notion_page_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no notion import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"document": dataset_document,
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
elif dataset_document.data_source_type == "website_crawl":
|
||||
if (
|
||||
not data_source_info
|
||||
or "provider" not in data_source_info
|
||||
or "url" not in data_source_info
|
||||
or "job_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no website import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
"url": data_source_info["url"],
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case "website_crawl":
|
||||
if (
|
||||
not data_source_info
|
||||
or "provider" not in data_source_info
|
||||
or "url" not in data_source_info
|
||||
or "job_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no website import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
"url": data_source_info["url"],
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case _:
|
||||
return []
|
||||
# update document status to splitting
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
|
||||
@@ -441,11 +441,13 @@ DEFAULT_GENERATOR_SUMMARY_PROMPT = (
|
||||
|
||||
Requirements:
|
||||
1. Write a concise summary in plain text
|
||||
2. Use the same language as the input content
|
||||
2. You must write in {language}. No language other than {language} should be used.
|
||||
3. Focus on important facts, concepts, and details
|
||||
4. If images are included, describe their key information
|
||||
5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
|
||||
6. Write directly without extra words
|
||||
7. If there is not enough content to generate a meaningful summary,
|
||||
return an empty string without any explanation or prompt
|
||||
|
||||
Output only the summary text. Start summarizing now:
|
||||
|
||||
|
||||
@@ -48,12 +48,22 @@ class BaseIndexProcessor(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def generate_summary_preview(
|
||||
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
|
||||
The summary can be stored in a new attribute, e.g., summary.
|
||||
This method should be implemented by subclasses.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
preview_texts: List of preview details to generate summaries for
|
||||
summary_index_setting: Summary index configuration
|
||||
doc_language: Optional document language to ensure summary is generated in the correct language
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -275,7 +275,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
raise ValueError("Chunks is not a list")
|
||||
|
||||
def generate_summary_preview(
|
||||
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each segment, concurrently call generate_summary to generate a summary
|
||||
@@ -298,11 +302,15 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
if flask_app:
|
||||
# Ensure Flask app context in worker thread
|
||||
with flask_app.app_context():
|
||||
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
|
||||
summary, _ = self.generate_summary(
|
||||
tenant_id, preview.content, summary_index_setting, document_language=doc_language
|
||||
)
|
||||
preview.summary = summary
|
||||
else:
|
||||
# Fallback: try without app context (may fail)
|
||||
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
|
||||
summary, _ = self.generate_summary(
|
||||
tenant_id, preview.content, summary_index_setting, document_language=doc_language
|
||||
)
|
||||
preview.summary = summary
|
||||
|
||||
# Generate summaries concurrently using ThreadPoolExecutor
|
||||
@@ -356,6 +364,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
text: str,
|
||||
summary_index_setting: dict | None = None,
|
||||
segment_id: str | None = None,
|
||||
document_language: str | None = None,
|
||||
) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt,
|
||||
@@ -366,6 +375,8 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
text: Text content to summarize
|
||||
summary_index_setting: Summary index configuration
|
||||
segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table
|
||||
document_language: Optional document language (e.g., "Chinese", "English")
|
||||
to ensure summary is generated in the correct language
|
||||
|
||||
Returns:
|
||||
Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
|
||||
@@ -381,8 +392,22 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
raise ValueError("model_name and model_provider_name are required in summary_index_setting")
|
||||
|
||||
# Import default summary prompt
|
||||
is_default_prompt = False
|
||||
if not summary_prompt:
|
||||
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
is_default_prompt = True
|
||||
|
||||
# Format prompt with document language only for default prompt
|
||||
# Custom prompts are used as-is to avoid interfering with user-defined templates
|
||||
# If document_language is provided, use it; otherwise, use "the same language as the input content"
|
||||
# This is especially important for image-only chunks where text is empty or minimal
|
||||
if is_default_prompt:
|
||||
language_for_prompt = document_language or "the same language as the input content"
|
||||
try:
|
||||
summary_prompt = summary_prompt.format(language=language_for_prompt)
|
||||
except KeyError:
|
||||
# If default prompt doesn't have {language} placeholder, use it as-is
|
||||
pass
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
|
||||
@@ -358,7 +358,11 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
}
|
||||
|
||||
def generate_summary_preview(
|
||||
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
|
||||
@@ -389,6 +393,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
tenant_id=tenant_id,
|
||||
text=preview.content,
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
preview.summary = summary
|
||||
else:
|
||||
@@ -397,6 +402,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
tenant_id=tenant_id,
|
||||
text=preview.content,
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
preview.summary = summary
|
||||
|
||||
|
||||
@@ -241,7 +241,11 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
}
|
||||
|
||||
def generate_summary_preview(
|
||||
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
QA model doesn't generate summaries, so this method returns preview_texts unchanged.
|
||||
|
||||
@@ -192,32 +192,33 @@ class AgentNode(Node[AgentNodeData]):
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
if agent_input.type == "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
elif agent_input.type in {"mixed", "constant"}:
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
else:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
case _:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
@@ -374,12 +375,13 @@ class AgentNode(Node[AgentNodeData]):
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
if input.type in ["mixed", "constant"]:
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
match input.type:
|
||||
case "mixed" | "constant":
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
|
||||
@@ -270,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
if typed_node_data.datasource_parameters:
|
||||
for parameter_name in typed_node_data.datasource_parameters:
|
||||
input = typed_node_data.datasource_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
match input.type:
|
||||
case "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
case "constant":
|
||||
pass
|
||||
case None:
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
@@ -308,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceMessage.MessageType.BINARY_LINK,
|
||||
DatasourceMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
match message.type:
|
||||
case (
|
||||
DatasourceMessage.MessageType.IMAGE_LINK
|
||||
| DatasourceMessage.MessageType.BINARY_LINK
|
||||
| DatasourceMessage.MessageType.IMAGE
|
||||
):
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
files.append(file)
|
||||
case DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
case DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
case DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
json.append(message.message.json_object)
|
||||
case DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
case DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
case DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
case (
|
||||
DatasourceMessage.MessageType.BLOB_CHUNK
|
||||
| DatasourceMessage.MessageType.LOG
|
||||
| DatasourceMessage.MessageType.RETRIEVER_RESOURCES
|
||||
):
|
||||
pass
|
||||
|
||||
# mark the end of the stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
|
||||
@@ -78,12 +78,21 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
|
||||
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
|
||||
|
||||
# Try to get document language if document_id is available
|
||||
doc_language = None
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if document_id:
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if document and document.doc_language:
|
||||
doc_language = document.doc_language
|
||||
|
||||
outputs = self._get_preview_output_with_summaries(
|
||||
node_data.chunk_structure,
|
||||
chunks,
|
||||
dataset=dataset,
|
||||
indexing_technique=indexing_technique,
|
||||
summary_index_setting=summary_index_setting,
|
||||
doc_language=doc_language,
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@@ -315,6 +324,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
dataset: Dataset,
|
||||
indexing_technique: str | None = None,
|
||||
summary_index_setting: dict | None = None,
|
||||
doc_language: str | None = None,
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Generate preview output with summaries for chunks in preview mode.
|
||||
@@ -326,6 +336,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
dataset: Dataset object (for tenant_id)
|
||||
indexing_technique: Indexing technique from node config or dataset
|
||||
summary_index_setting: Summary index setting from node config or dataset
|
||||
doc_language: Optional document language to ensure summary is generated in the correct language
|
||||
"""
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
preview_output = index_processor.format_preview(chunks)
|
||||
@@ -365,6 +376,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
@@ -374,6 +386,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
|
||||
@@ -303,33 +303,34 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
reranking_model = {
|
||||
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
|
||||
}
|
||||
else:
|
||||
match node_data.multiple_retrieval_config.reranking_mode:
|
||||
case "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
reranking_model = {
|
||||
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
|
||||
}
|
||||
else:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
case "weighted_score":
|
||||
if node_data.multiple_retrieval_config.weights is None:
|
||||
raise ValueError("weights is required")
|
||||
reranking_model = None
|
||||
weights = None
|
||||
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
|
||||
if node_data.multiple_retrieval_config.weights is None:
|
||||
raise ValueError("weights is required")
|
||||
reranking_model = None
|
||||
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
|
||||
weights = {
|
||||
"vector_setting": {
|
||||
"vector_weight": vector_setting.vector_weight,
|
||||
"embedding_provider_name": vector_setting.embedding_provider_name,
|
||||
"embedding_model_name": vector_setting.embedding_model_name,
|
||||
},
|
||||
"keyword_setting": {
|
||||
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
|
||||
},
|
||||
}
|
||||
else:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
|
||||
weights = {
|
||||
"vector_setting": {
|
||||
"vector_weight": vector_setting.vector_weight,
|
||||
"embedding_provider_name": vector_setting.embedding_provider_name,
|
||||
"embedding_model_name": vector_setting.embedding_model_name,
|
||||
},
|
||||
"keyword_setting": {
|
||||
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
|
||||
},
|
||||
}
|
||||
case _:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
all_documents = dataset_retrieval.multiple_retrieve(
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
@@ -453,73 +454,74 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
)
|
||||
filters: list[Any] = []
|
||||
metadata_condition = None
|
||||
if node_data.metadata_filtering_mode == "disabled":
|
||||
return None, None, usage
|
||||
elif node_data.metadata_filtering_mode == "automatic":
|
||||
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
|
||||
dataset_ids, query, node_data
|
||||
)
|
||||
usage = self._merge_usage(usage, automatic_usage)
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
filter.get("value"),
|
||||
filters,
|
||||
)
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=filter.get("metadata_name"), # type: ignore
|
||||
comparison_operator=filter.get("condition"), # type: ignore
|
||||
value=filter.get("value"),
|
||||
)
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
||||
if node_data.metadata_filtering_conditions
|
||||
else "or",
|
||||
conditions=conditions,
|
||||
match node_data.metadata_filtering_mode:
|
||||
case "disabled":
|
||||
return None, None, usage
|
||||
case "automatic":
|
||||
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
|
||||
dataset_ids, query, node_data
|
||||
)
|
||||
elif node_data.metadata_filtering_mode == "manual":
|
||||
if node_data.metadata_filtering_conditions:
|
||||
conditions = []
|
||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
).value[0]
|
||||
if expected_value.value_type in {"number", "integer", "float"}:
|
||||
expected_value = expected_value.value
|
||||
elif expected_value.value_type == "string":
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
usage = self._merge_usage(usage, automatic_usage)
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
filter.get("value"),
|
||||
filters,
|
||||
)
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=filter.get("metadata_name"), # type: ignore
|
||||
comparison_operator=filter.get("condition"), # type: ignore
|
||||
value=filter.get("value"),
|
||||
)
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
||||
if node_data.metadata_filtering_conditions
|
||||
else "or",
|
||||
conditions=conditions,
|
||||
)
|
||||
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
case "manual":
|
||||
if node_data.metadata_filtering_conditions:
|
||||
conditions = []
|
||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
).value[0]
|
||||
if expected_value.value_type in {"number", "integer", "float"}:
|
||||
expected_value = expected_value.value
|
||||
elif expected_value.value_type == "string":
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
case _:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
if filters:
|
||||
if (
|
||||
node_data.metadata_filtering_conditions
|
||||
|
||||
@@ -482,16 +482,17 @@ class ToolNode(Node[ToolNodeData]):
|
||||
result = {}
|
||||
for parameter_name in typed_node_data.tool_parameters:
|
||||
input = typed_node_data.tool_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
match input.type:
|
||||
case "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
case "constant":
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
|
||||
@@ -390,8 +390,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||
"""
|
||||
content = self.load_once(filename)
|
||||
|
||||
with Path(target_filepath).open("wb") as f:
|
||||
f.write(content)
|
||||
Path(target_filepath).write_bytes(content)
|
||||
|
||||
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
|
||||
|
||||
|
||||
@@ -1,36 +1,69 @@
|
||||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from libs.helper import TimestampField
|
||||
from datetime import datetime
|
||||
|
||||
annotation_fields = {
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"answer": fields.Raw(attribute="content"),
|
||||
"hit_count": fields.Integer,
|
||||
"created_at": TimestampField,
|
||||
# 'account': fields.Nested(simple_account_fields, allow_null=True)
|
||||
}
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
def build_annotation_model(api_or_ns: Namespace):
|
||||
"""Build the annotation model for the API or Namespace."""
|
||||
return api_or_ns.model("Annotation", annotation_fields)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
annotation_list_fields = {
|
||||
"data": fields.List(fields.Nested(annotation_fields)),
|
||||
}
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
annotation_hit_history_fields = {
|
||||
"id": fields.String,
|
||||
"source": fields.String,
|
||||
"score": fields.Float,
|
||||
"question": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"match": fields.String(attribute="annotation_question"),
|
||||
"response": fields.String(attribute="annotation_content"),
|
||||
}
|
||||
|
||||
annotation_hit_history_list_fields = {
|
||||
"data": fields.List(fields.Nested(annotation_hit_history_fields)),
|
||||
}
|
||||
class Annotation(ResponseModel):
|
||||
id: str
|
||||
question: str | None = None
|
||||
answer: str | None = Field(default=None, validation_alias="content")
|
||||
hit_count: int | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AnnotationList(ResponseModel):
|
||||
data: list[Annotation]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class AnnotationExportList(ResponseModel):
|
||||
data: list[Annotation]
|
||||
|
||||
|
||||
class AnnotationHitHistory(ResponseModel):
|
||||
id: str
|
||||
source: str | None = None
|
||||
score: float | None = None
|
||||
question: str | None = None
|
||||
created_at: int | None = None
|
||||
match: str | None = Field(default=None, validation_alias="annotation_question")
|
||||
response: str | None = Field(default=None, validation_alias="annotation_content")
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AnnotationHitHistoryList(ResponseModel):
|
||||
data: list[AnnotationHitHistory]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from flask_restx import fields
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
simple_end_user_fields = {
|
||||
"id": fields.String,
|
||||
@@ -8,5 +11,18 @@ simple_end_user_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_simple_end_user_model(api_or_ns: Namespace):
|
||||
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
class SimpleEndUser(ResponseModel):
|
||||
id: str
|
||||
type: str
|
||||
is_anonymous: bool
|
||||
session_id: str | None = None
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from libs.helper import AvatarUrlField, TimestampField
|
||||
from datetime import datetime
|
||||
|
||||
from flask_restx import fields
|
||||
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
|
||||
|
||||
from core.file import helpers as file_helpers
|
||||
|
||||
simple_account_fields = {
|
||||
"id": fields.String,
|
||||
@@ -9,36 +14,78 @@ simple_account_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_simple_account_model(api_or_ns: Namespace):
|
||||
return api_or_ns.model("SimpleAccount", simple_account_fields)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
account_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"avatar": fields.String,
|
||||
"avatar_url": AvatarUrlField,
|
||||
"email": fields.String,
|
||||
"is_password_set": fields.Boolean,
|
||||
"interface_language": fields.String,
|
||||
"interface_theme": fields.String,
|
||||
"timezone": fields.String,
|
||||
"last_login_at": TimestampField,
|
||||
"last_login_ip": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
def _build_avatar_url(avatar: str | None) -> str | None:
|
||||
if avatar is None:
|
||||
return None
|
||||
if avatar.startswith(("http://", "https://")):
|
||||
return avatar
|
||||
return file_helpers.get_signed_file_url(avatar)
|
||||
|
||||
account_with_role_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"avatar": fields.String,
|
||||
"avatar_url": AvatarUrlField,
|
||||
"email": fields.String,
|
||||
"last_login_at": TimestampField,
|
||||
"last_active_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"role": fields.String,
|
||||
"status": fields.String,
|
||||
}
|
||||
|
||||
account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))}
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
class SimpleAccount(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
|
||||
|
||||
class _AccountAvatar(ResponseModel):
|
||||
avatar: str | None = None
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
|
||||
@property
|
||||
def avatar_url(self) -> str | None:
|
||||
return _build_avatar_url(self.avatar)
|
||||
|
||||
|
||||
class Account(_AccountAvatar):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
is_password_set: bool
|
||||
interface_language: str | None = None
|
||||
interface_theme: str | None = None
|
||||
timezone: str | None = None
|
||||
last_login_at: int | None = None
|
||||
last_login_ip: str | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("last_login_at", "created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AccountWithRole(_AccountAvatar):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
last_login_at: int | None = None
|
||||
last_active_at: int | None = None
|
||||
created_at: int | None = None
|
||||
role: str
|
||||
status: str
|
||||
|
||||
@field_validator("last_login_at", "last_active_at", "created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AccountWithRoleList(ResponseModel):
|
||||
accounts: list[AccountWithRole]
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"type": fields.String,
|
||||
"binding_count": fields.String,
|
||||
}
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
class DataSetTag(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
binding_count: str | None = None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
|
||||
from fields.member_fields import build_simple_account_model, simple_account_fields
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_run_fields import (
|
||||
build_workflow_run_for_archived_log_model,
|
||||
build_workflow_run_for_log_model,
|
||||
@@ -25,17 +25,9 @@ workflow_app_log_partial_fields = {
|
||||
def build_workflow_app_log_partial_model(api_or_ns: Namespace):
|
||||
"""Build the workflow app log partial model for the API or Namespace."""
|
||||
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
|
||||
simple_account_model = build_simple_account_model(api_or_ns)
|
||||
simple_end_user_model = build_simple_end_user_model(api_or_ns)
|
||||
|
||||
copied_fields = workflow_app_log_partial_fields.copy()
|
||||
copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True)
|
||||
copied_fields["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
copied_fields["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
return api_or_ns.model("WorkflowAppLogPartial", copied_fields)
|
||||
|
||||
|
||||
@@ -52,17 +44,9 @@ workflow_archived_log_partial_fields = {
|
||||
def build_workflow_archived_log_partial_model(api_or_ns: Namespace):
|
||||
"""Build the workflow archived log partial model for the API or Namespace."""
|
||||
workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns)
|
||||
simple_account_model = build_simple_account_model(api_or_ns)
|
||||
simple_end_user_model = build_simple_end_user_model(api_or_ns)
|
||||
|
||||
copied_fields = workflow_archived_log_partial_fields.copy()
|
||||
copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True)
|
||||
copied_fields["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
copied_fields["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields)
|
||||
|
||||
|
||||
|
||||
@@ -145,7 +145,7 @@ dev = [
|
||||
"types-openpyxl~=3.1.5",
|
||||
"types-pexpect~=4.9.0",
|
||||
"types-protobuf~=5.29.1",
|
||||
"types-psutil~=7.0.0",
|
||||
"types-psutil~=7.2.2",
|
||||
"types-psycopg2~=2.9.21",
|
||||
"types-pygments~=2.19.0",
|
||||
"types-pymysql~=1.1.0",
|
||||
|
||||
@@ -158,7 +158,7 @@ class AppAnnotationService:
|
||||
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
|
||||
)
|
||||
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
return annotations.items, annotations.total
|
||||
return annotations.items, annotations.total or 0
|
||||
|
||||
@classmethod
|
||||
def export_annotation_list_by_app_id(cls, app_id: str):
|
||||
@@ -524,7 +524,7 @@ class AppAnnotationService:
|
||||
annotation_hit_histories = db.paginate(
|
||||
select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False
|
||||
)
|
||||
return annotation_hit_histories.items, annotation_hit_histories.total
|
||||
return annotation_hit_histories.items, annotation_hit_histories.total or 0
|
||||
|
||||
@classmethod
|
||||
def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
|
||||
|
||||
@@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper.name_generator import generate_incremental_name
|
||||
@@ -1388,6 +1389,46 @@ class DocumentService:
|
||||
).all()
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def update_documents_need_summary(dataset_id: str, document_ids: Sequence[str], need_summary: bool = True) -> int:
|
||||
"""
|
||||
Update need_summary field for multiple documents.
|
||||
|
||||
This method handles the case where documents were created when summary_index_setting was disabled,
|
||||
and need to be updated when summary_index_setting is later enabled.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
document_ids: List of document IDs to update
|
||||
need_summary: Value to set for need_summary field (default: True)
|
||||
|
||||
Returns:
|
||||
Number of documents updated
|
||||
"""
|
||||
if not document_ids:
|
||||
return 0
|
||||
|
||||
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
updated_count = (
|
||||
session.query(Document)
|
||||
.filter(
|
||||
Document.id.in_(document_id_list),
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.doc_form != "qa_model", # Skip qa_model documents
|
||||
)
|
||||
.update({Document.need_summary: need_summary}, synchronize_session=False)
|
||||
)
|
||||
session.commit()
|
||||
logger.info(
|
||||
"Updated need_summary to %s for %d documents in dataset %s",
|
||||
need_summary,
|
||||
updated_count,
|
||||
dataset_id,
|
||||
)
|
||||
return updated_count
|
||||
|
||||
@staticmethod
|
||||
def get_document_download_url(document: Document) -> str:
|
||||
"""
|
||||
@@ -2937,14 +2978,15 @@ class DocumentService:
|
||||
"""
|
||||
now = naive_utc_now()
|
||||
|
||||
if action == "enable":
|
||||
return DocumentService._prepare_enable_update(document, now)
|
||||
elif action == "disable":
|
||||
return DocumentService._prepare_disable_update(document, user, now)
|
||||
elif action == "archive":
|
||||
return DocumentService._prepare_archive_update(document, user, now)
|
||||
elif action == "un_archive":
|
||||
return DocumentService._prepare_unarchive_update(document, now)
|
||||
match action:
|
||||
case "enable":
|
||||
return DocumentService._prepare_enable_update(document, now)
|
||||
case "disable":
|
||||
return DocumentService._prepare_disable_update(document, user, now)
|
||||
case "archive":
|
||||
return DocumentService._prepare_archive_update(document, user, now)
|
||||
case "un_archive":
|
||||
return DocumentService._prepare_unarchive_update(document, now)
|
||||
|
||||
return None
|
||||
|
||||
@@ -3581,56 +3623,57 @@ class SegmentService:
|
||||
# Check if segment_ids is not empty to avoid WHERE false condition
|
||||
if not segment_ids or len(segment_ids) == 0:
|
||||
return
|
||||
if action == "enable":
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == False,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = True
|
||||
segment.disabled_at = None
|
||||
segment.disabled_by = None
|
||||
db.session.add(segment)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
match action:
|
||||
case "enable":
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == False,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = True
|
||||
segment.disabled_at = None
|
||||
segment.disabled_by = None
|
||||
db.session.add(segment)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
|
||||
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
elif action == "disable":
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.disabled_by = current_user.id
|
||||
db.session.add(segment)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
case "disable":
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.disabled_by = current_user.id
|
||||
db.session.add(segment)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
|
||||
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
|
||||
@classmethod
|
||||
def create_child_chunk(
|
||||
|
||||
@@ -174,6 +174,10 @@ class RagPipelineTransformService:
|
||||
else:
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
|
||||
# Copy summary_index_setting from dataset to knowledge_index node configuration
|
||||
if dataset.summary_index_setting:
|
||||
knowledge_configuration.summary_index_setting = dataset.summary_index_setting
|
||||
|
||||
knowledge_configuration_dict.update(knowledge_configuration.model_dump())
|
||||
node["data"] = knowledge_configuration_dict
|
||||
return node
|
||||
|
||||
@@ -49,11 +49,18 @@ class SummaryIndexService:
|
||||
# Use lazy import to avoid circular import
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
|
||||
# Get document language to ensure summary is generated in the correct language
|
||||
# This is especially important for image-only chunks where text is empty or minimal
|
||||
document_language = None
|
||||
if segment.document and segment.document.doc_language:
|
||||
document_language = segment.document.doc_language
|
||||
|
||||
summary_content, usage = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=segment.content,
|
||||
summary_index_setting=summary_index_setting,
|
||||
segment_id=segment.id,
|
||||
document_language=document_language,
|
||||
)
|
||||
|
||||
if not summary_content:
|
||||
@@ -558,6 +565,9 @@ class SummaryIndexService:
|
||||
)
|
||||
session.add(summary_record)
|
||||
|
||||
# Commit the batch created records
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def update_summary_record_error(
|
||||
segment: DocumentSegment,
|
||||
@@ -762,7 +772,6 @@ class SummaryIndexService:
|
||||
dataset=dataset,
|
||||
status="not_started",
|
||||
)
|
||||
session.commit() # Commit initial records
|
||||
|
||||
summary_records = []
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class TagService:
|
||||
escaped_keyword = escape_like_pattern(keyword)
|
||||
query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
|
||||
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
|
||||
results = query.order_by(Tag.created_at.desc()).all()
|
||||
results: list = query.order_by(Tag.created_at.desc()).all()
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,291 +0,0 @@
|
||||
import builtins
|
||||
import contextlib
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from extensions import ext_fastopenapi
|
||||
from extensions.ext_database import db
|
||||
from services.feature_service import FeatureModel, SystemFeatureModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""
|
||||
Creates a Flask application instance configured for testing.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SECRET_KEY"] = "test-secret"
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
|
||||
# Initialize the database with the app
|
||||
db.init_app(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fix_method_view_issue(monkeypatch):
|
||||
"""
|
||||
Automatic fixture to patch 'builtins.MethodView'.
|
||||
|
||||
Why this is needed:
|
||||
The official legacy codebase contains a global patch in its initialization logic:
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView
|
||||
|
||||
Some dependencies (like ext_fastopenapi or older Flask extensions) might implicitly
|
||||
rely on 'MethodView' being available in the global builtins namespace.
|
||||
|
||||
Refactoring Note:
|
||||
While patching builtins is generally discouraged due to global side effects,
|
||||
this fixture reproduces the production environment's state to ensure tests are realistic.
|
||||
We use 'monkeypatch' to ensure that this change is undone after the test finishes,
|
||||
keeping other tests isolated.
|
||||
"""
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
# 'raising=False' allows us to set an attribute that doesn't exist yet
|
||||
monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Helper Functions for Fixture Complexity Reduction
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_isolated_router():
|
||||
"""
|
||||
Creates a fresh, isolated router instance to prevent route pollution.
|
||||
"""
|
||||
import controllers.fastopenapi
|
||||
|
||||
# Dynamically get the class type (e.g., FlaskRouter) to avoid hardcoding dependencies
|
||||
RouterClass = type(controllers.fastopenapi.console_router)
|
||||
return RouterClass()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _patch_auth_and_router(temp_router):
|
||||
"""
|
||||
Context manager that applies all necessary patches for:
|
||||
1. The console_router (redirecting to our isolated temp_router)
|
||||
2. Authentication decorators (disabling them with no-ops)
|
||||
3. User/Account loaders (mocking authenticated state)
|
||||
"""
|
||||
|
||||
def noop(f):
|
||||
return f
|
||||
|
||||
# We patch the SOURCE of the decorators/functions, not the destination module.
|
||||
# This ensures that when 'controllers.console.feature' imports them, it gets the mocks.
|
||||
with (
|
||||
patch("controllers.fastopenapi.console_router", temp_router),
|
||||
patch("extensions.ext_fastopenapi.console_router", temp_router),
|
||||
patch("controllers.console.wraps.setup_required", side_effect=noop),
|
||||
patch("libs.login.login_required", side_effect=noop),
|
||||
patch("controllers.console.wraps.account_initialization_required", side_effect=noop),
|
||||
patch("controllers.console.wraps.cloud_utm_record", side_effect=noop),
|
||||
patch("libs.login.current_account_with_tenant", return_value=(MagicMock(), "tenant-id")),
|
||||
patch("libs.login.current_user", MagicMock(is_authenticated=True)),
|
||||
):
|
||||
# Explicitly reload ext_fastopenapi to ensure it uses the patched console_router
|
||||
import extensions.ext_fastopenapi
|
||||
|
||||
importlib.reload(extensions.ext_fastopenapi)
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def _force_reload_module(target_module: str, alias_module: str):
|
||||
"""
|
||||
Forces a reload of the specified module and handles sys.modules aliasing.
|
||||
|
||||
Why reload?
|
||||
Python decorators (like @route, @login_required) run at IMPORT time.
|
||||
To apply our patches (mocks/no-ops) to these decorators, we must re-import
|
||||
the module while the patches are active.
|
||||
|
||||
Why alias?
|
||||
If 'ext_fastopenapi' imports the controller as 'api.controllers...', but we import
|
||||
it as 'controllers...', Python treats them as two separate modules. This causes:
|
||||
1. Double execution of decorators (registering routes twice -> AssertionError).
|
||||
2. Type mismatch errors (Class A from module X is not Class A from module Y).
|
||||
|
||||
This function ensures both names point to the SAME loaded module instance.
|
||||
"""
|
||||
# 1. Clean existing entries to force re-import
|
||||
if target_module in sys.modules:
|
||||
del sys.modules[target_module]
|
||||
if alias_module in sys.modules:
|
||||
del sys.modules[alias_module]
|
||||
|
||||
# 2. Import the module (triggering decorators with active patches)
|
||||
module = importlib.import_module(target_module)
|
||||
|
||||
# 3. Alias the module in sys.modules to prevent double loading
|
||||
sys.modules[alias_module] = sys.modules[target_module]
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def _cleanup_modules(target_module: str, alias_module: str):
|
||||
"""
|
||||
Removes the module and its alias from sys.modules to prevent side effects
|
||||
on other tests.
|
||||
"""
|
||||
if target_module in sys.modules:
|
||||
del sys.modules[target_module]
|
||||
if alias_module in sys.modules:
|
||||
del sys.modules[alias_module]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_feature_module_env():
|
||||
"""
|
||||
Sets up a mocked environment for the feature module.
|
||||
|
||||
This fixture orchestrates:
|
||||
1. Creating an isolated router.
|
||||
2. Patching authentication and global dependencies.
|
||||
3. Reloading the controller module to apply patches to decorators.
|
||||
4. cleaning up sys.modules afterwards.
|
||||
"""
|
||||
target_module = "controllers.console.feature"
|
||||
alias_module = "api.controllers.console.feature"
|
||||
|
||||
# 1. Prepare isolated router
|
||||
temp_router = _create_isolated_router()
|
||||
|
||||
# 2. Apply patches
|
||||
try:
|
||||
with _patch_auth_and_router(temp_router):
|
||||
# 3. Reload module to register routes on the temp_router
|
||||
feature_module = _force_reload_module(target_module, alias_module)
|
||||
|
||||
yield feature_module
|
||||
|
||||
finally:
|
||||
# 4. Teardown: Clean up sys.modules
|
||||
_cleanup_modules(target_module, alias_module)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Test Cases
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("url", "service_mock_path", "mock_model_instance", "json_key"),
|
||||
[
|
||||
(
|
||||
"/console/api/features",
|
||||
"controllers.console.feature.FeatureService.get_features",
|
||||
FeatureModel(can_replace_logo=True),
|
||||
"features",
|
||||
),
|
||||
(
|
||||
"/console/api/system-features",
|
||||
"controllers.console.feature.FeatureService.get_system_features",
|
||||
SystemFeatureModel(enable_marketplace=True),
|
||||
"features",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_console_features_success(app, mock_feature_module_env, url, service_mock_path, mock_model_instance, json_key):
|
||||
"""
|
||||
Tests that the feature APIs return a 200 OK status and correct JSON structure.
|
||||
"""
|
||||
# Patch the service layer to return our mock model instance
|
||||
with patch(service_mock_path, return_value=mock_model_instance):
|
||||
# Initialize the API extension
|
||||
ext_fastopenapi.init_app(app)
|
||||
|
||||
client = app.test_client()
|
||||
response = client.get(url)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 200, f"Request failed with status {response.status_code}: {response.text}"
|
||||
|
||||
# Verify the JSON response matches the Pydantic model dump
|
||||
expected_data = mock_model_instance.model_dump(mode="json")
|
||||
assert response.get_json() == {json_key: expected_data}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("url", "service_mock_path"),
|
||||
[
|
||||
("/console/api/features", "controllers.console.feature.FeatureService.get_features"),
|
||||
("/console/api/system-features", "controllers.console.feature.FeatureService.get_system_features"),
|
||||
],
|
||||
)
|
||||
def test_console_features_service_error(app, mock_feature_module_env, url, service_mock_path):
|
||||
"""
|
||||
Tests how the application handles Service layer errors.
|
||||
|
||||
Note: When an exception occurs in the view, it is typically caught by the framework
|
||||
(Flask or the OpenAPI wrapper) and converted to a 500 error response.
|
||||
This test verifies that the application returns a 500 status code.
|
||||
"""
|
||||
# Simulate a service failure
|
||||
with patch(service_mock_path, side_effect=ValueError("Service Failure")):
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# When an exception occurs in the view, it is typically caught by the framework
|
||||
# (Flask or the OpenAPI wrapper) and converted to a 500 error response.
|
||||
response = client.get(url)
|
||||
|
||||
assert response.status_code == 500
|
||||
# Check if the error details are exposed in the response (depends on error handler config)
|
||||
# We accept either generic 500 or the specific error message
|
||||
assert "Service Failure" in response.text or "Internal Server Error" in response.text
|
||||
|
||||
|
||||
def test_system_features_unauthenticated(app, mock_feature_module_env):
|
||||
"""
|
||||
Tests that /console/api/system-features endpoint works without authentication.
|
||||
|
||||
This test verifies the try-except block in get_system_features that handles
|
||||
unauthenticated requests by passing is_authenticated=False to the service layer.
|
||||
"""
|
||||
feature_module = mock_feature_module_env
|
||||
|
||||
# Override the behavior of the current_user mock
|
||||
# The fixture patched 'libs.login.current_user', so 'controllers.console.feature.current_user'
|
||||
# refers to that same Mock object.
|
||||
mock_user = feature_module.current_user
|
||||
|
||||
# Simulate property access raising Unauthorized
|
||||
# Note: We must reset side_effect if it was set, or set it here.
|
||||
# The fixture initialized it as MagicMock(is_authenticated=True).
|
||||
# We want type(mock_user).is_authenticated to raise Unauthorized.
|
||||
type(mock_user).is_authenticated = PropertyMock(side_effect=Unauthorized)
|
||||
|
||||
# Patch the service layer for this specific test
|
||||
with patch("controllers.console.feature.FeatureService.get_system_features") as mock_service:
|
||||
# Setup mock service return value
|
||||
mock_model = SystemFeatureModel(enable_marketplace=True)
|
||||
mock_service.return_value = mock_model
|
||||
|
||||
# Initialize app
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.get("/console/api/system-features")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200, f"Request failed: {response.text}"
|
||||
|
||||
# Verify service was called with is_authenticated=False
|
||||
mock_service.assert_called_once_with(is_authenticated=False)
|
||||
|
||||
# Verify response body
|
||||
expected_data = mock_model.model_dump(mode="json")
|
||||
assert response.get_json() == {"features": expected_data}
|
||||
@@ -1,222 +0,0 @@
|
||||
import builtins
|
||||
import contextlib
|
||||
import importlib
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from extensions import ext_fastopenapi
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SECRET_KEY"] = "test-secret"
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
|
||||
db.init_app(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fix_method_view_issue(monkeypatch):
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False)
|
||||
|
||||
|
||||
def _create_isolated_router():
|
||||
import controllers.fastopenapi
|
||||
|
||||
router_class = type(controllers.fastopenapi.console_router)
|
||||
return router_class()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _patch_auth_and_router(temp_router):
|
||||
def noop(func):
|
||||
return func
|
||||
|
||||
default_user = MagicMock(has_edit_permission=True, is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
patch("controllers.fastopenapi.console_router", temp_router),
|
||||
patch("extensions.ext_fastopenapi.console_router", temp_router),
|
||||
patch("controllers.console.wraps.setup_required", side_effect=noop),
|
||||
patch("libs.login.login_required", side_effect=noop),
|
||||
patch("controllers.console.wraps.account_initialization_required", side_effect=noop),
|
||||
patch("controllers.console.wraps.edit_permission_required", side_effect=noop),
|
||||
patch("libs.login.current_account_with_tenant", return_value=(default_user, "tenant-id")),
|
||||
patch("configs.dify_config.EDITION", "CLOUD"),
|
||||
):
|
||||
import extensions.ext_fastopenapi
|
||||
|
||||
importlib.reload(extensions.ext_fastopenapi)
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def _force_reload_module(target_module: str, alias_module: str):
|
||||
if target_module in sys.modules:
|
||||
del sys.modules[target_module]
|
||||
if alias_module in sys.modules:
|
||||
del sys.modules[alias_module]
|
||||
|
||||
module = importlib.import_module(target_module)
|
||||
sys.modules[alias_module] = sys.modules[target_module]
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def _dedupe_routes(router):
|
||||
seen = set()
|
||||
unique_routes = []
|
||||
for path, method, endpoint in reversed(router.get_routes()):
|
||||
key = (path, method, endpoint.__name__)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
unique_routes.append((path, method, endpoint))
|
||||
router._routes = list(reversed(unique_routes))
|
||||
|
||||
|
||||
def _cleanup_modules(target_module: str, alias_module: str):
|
||||
if target_module in sys.modules:
|
||||
del sys.modules[target_module]
|
||||
if alias_module in sys.modules:
|
||||
del sys.modules[alias_module]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tags_module_env():
|
||||
target_module = "controllers.console.tag.tags"
|
||||
alias_module = "api.controllers.console.tag.tags"
|
||||
temp_router = _create_isolated_router()
|
||||
|
||||
try:
|
||||
with _patch_auth_and_router(temp_router):
|
||||
tags_module = _force_reload_module(target_module, alias_module)
|
||||
_dedupe_routes(temp_router)
|
||||
yield tags_module
|
||||
finally:
|
||||
_cleanup_modules(target_module, alias_module)
|
||||
|
||||
|
||||
def test_list_tags_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
tag = SimpleNamespace(id="tag-1", name="Alpha", type="app", binding_count=2)
|
||||
with patch("controllers.console.tag.tags.TagService.get_tags", return_value=[tag]):
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.get("/console/api/tags?type=app&keyword=Alpha")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == [
|
||||
{"id": "tag-1", "name": "Alpha", "type": "app", "binding_count": 2},
|
||||
]
|
||||
|
||||
|
||||
def test_create_tag_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
tag = SimpleNamespace(id="tag-2", name="Beta", type="app")
|
||||
with patch("controllers.console.tag.tags.TagService.save_tags", return_value=tag) as mock_save:
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.post("/console/api/tags", json={"name": "Beta", "type": "app"})
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {
|
||||
"id": "tag-2",
|
||||
"name": "Beta",
|
||||
"type": "app",
|
||||
"binding_count": 0,
|
||||
}
|
||||
mock_save.assert_called_once_with({"name": "Beta", "type": "app"})
|
||||
|
||||
|
||||
def test_update_tag_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
tag = SimpleNamespace(id="tag-3", name="Gamma", type="app")
|
||||
with (
|
||||
patch("controllers.console.tag.tags.TagService.update_tags", return_value=tag) as mock_update,
|
||||
patch("controllers.console.tag.tags.TagService.get_tag_binding_count", return_value=4),
|
||||
):
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.patch(
|
||||
"/console/api/tags/11111111-1111-1111-1111-111111111111",
|
||||
json={"name": "Gamma", "type": "app"},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {
|
||||
"id": "tag-3",
|
||||
"name": "Gamma",
|
||||
"type": "app",
|
||||
"binding_count": 4,
|
||||
}
|
||||
mock_update.assert_called_once_with(
|
||||
{"name": "Gamma", "type": "app"},
|
||||
"11111111-1111-1111-1111-111111111111",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_tag_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
with patch("controllers.console.tag.tags.TagService.delete_tag") as mock_delete:
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.delete("/console/api/tags/11111111-1111-1111-1111-111111111111")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 204
|
||||
mock_delete.assert_called_once_with("11111111-1111-1111-1111-111111111111")
|
||||
|
||||
|
||||
def test_create_tag_binding_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
payload = {"tag_ids": ["tag-1", "tag-2"], "target_id": "target-1", "type": "app"}
|
||||
with patch("controllers.console.tag.tags.TagService.save_tag_binding") as mock_bind:
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.post("/console/api/tag-bindings/create", json=payload)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
mock_bind.assert_called_once_with(payload)
|
||||
|
||||
|
||||
def test_delete_tag_binding_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
payload = {"tag_id": "tag-1", "target_id": "target-1", "type": "app"}
|
||||
with patch("controllers.console.tag.tags.TagService.delete_tag_binding") as mock_unbind:
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.post("/console/api/tag-bindings/remove", json=payload)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
mock_unbind.assert_called_once_with(payload)
|
||||
@@ -0,0 +1,364 @@
|
||||
"""Endpoint tests for controllers.console.workspace.tool_providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import importlib
|
||||
from contextlib import contextmanager
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
_CONTROLLER_MODULE: ModuleType | None = None
|
||||
_WRAPS_MODULE: ModuleType | None = None
|
||||
_CONTROLLER_PATCHERS: list[patch] = []
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _mock_db():
|
||||
mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
|
||||
with patch("extensions.ext_database.db.session", mock_session):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def controller_module(monkeypatch: pytest.MonkeyPatch):
|
||||
module_name = "controllers.console.workspace.tool_providers"
|
||||
global _CONTROLLER_MODULE
|
||||
if _CONTROLLER_MODULE is None:
|
||||
|
||||
def _noop(func):
|
||||
return func
|
||||
|
||||
patch_targets = [
|
||||
("libs.login.login_required", _noop),
|
||||
("controllers.console.wraps.setup_required", _noop),
|
||||
("controllers.console.wraps.account_initialization_required", _noop),
|
||||
("controllers.console.wraps.is_admin_or_owner_required", _noop),
|
||||
("controllers.console.wraps.enterprise_license_required", _noop),
|
||||
]
|
||||
for target, value in patch_targets:
|
||||
patcher = patch(target, value)
|
||||
patcher.start()
|
||||
_CONTROLLER_PATCHERS.append(patcher)
|
||||
monkeypatch.setenv("DIFY_SETUP_READY", "true")
|
||||
with _mock_db():
|
||||
_CONTROLLER_MODULE = importlib.import_module(module_name)
|
||||
|
||||
module = _CONTROLLER_MODULE
|
||||
monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload)
|
||||
|
||||
# Ensure decorators that consult deployment edition do not reach the database.
|
||||
global _WRAPS_MODULE
|
||||
wraps_module = importlib.import_module("controllers.console.wraps")
|
||||
_WRAPS_MODULE = wraps_module
|
||||
monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD")
|
||||
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
|
||||
|
||||
login_module = importlib.import_module("libs.login")
|
||||
monkeypatch.setattr(login_module, "check_csrf_token", lambda *args, **kwargs: None)
|
||||
return module
|
||||
|
||||
|
||||
def _mock_account(user_id: str = "user-123") -> SimpleNamespace:
|
||||
return SimpleNamespace(id=user_id, status="active", is_authenticated=True, current_tenant_id=None)
|
||||
|
||||
|
||||
def _set_current_account(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
controller_module: ModuleType,
|
||||
user: SimpleNamespace,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
def _getter():
|
||||
return user, tenant_id
|
||||
|
||||
user.current_tenant_id = tenant_id
|
||||
|
||||
monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter)
|
||||
if _WRAPS_MODULE is not None:
|
||||
monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter)
|
||||
|
||||
login_module = importlib.import_module("libs.login")
|
||||
monkeypatch.setattr(login_module, "_get_user", lambda: user)
|
||||
|
||||
|
||||
def test_tool_provider_list_calls_service_with_query(
|
||||
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
|
||||
|
||||
service_mock = MagicMock(return_value=[{"provider": "builtin"}])
|
||||
monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock)
|
||||
|
||||
with app.test_request_context("/workspaces/current/tool-providers?type=builtin"):
|
||||
response = controller_module.ToolProviderListApi().get()
|
||||
|
||||
assert response == [{"provider": "builtin"}]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-456", "builtin")
|
||||
|
||||
|
||||
def test_builtin_provider_add_passes_payload(
|
||||
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
|
||||
|
||||
service_mock = MagicMock(return_value={"status": "ok"})
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock)
|
||||
|
||||
payload = {
|
||||
"credentials": {"api_key": "sk-test"},
|
||||
"name": "MyTool",
|
||||
"type": controller_module.CredentialType.API_KEY,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/tool-provider/builtin/openai/add",
|
||||
method="POST",
|
||||
json=payload,
|
||||
):
|
||||
response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai")
|
||||
|
||||
assert response == {"status": "ok"}
|
||||
service_mock.assert_called_once_with(
|
||||
user_id="user-123",
|
||||
tenant_id="tenant-456",
|
||||
provider="openai",
|
||||
credentials={"api_key": "sk-test"},
|
||||
name="MyTool",
|
||||
api_type=controller_module.CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
|
||||
def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-789")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-789")
|
||||
|
||||
service_mock = MagicMock(return_value=[{"name": "tool-a"}])
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock)
|
||||
monkeypatch.setattr(controller_module, "jsonable_encoder", lambda payload: payload)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/tool-provider/builtin/my-provider/tools",
|
||||
method="GET",
|
||||
):
|
||||
response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider")
|
||||
|
||||
assert response == [{"name": "tool-a"}]
|
||||
service_mock.assert_called_once_with("tenant-789", "my-provider")
|
||||
|
||||
|
||||
def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-9")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-9")
|
||||
service_mock = MagicMock(return_value={"info": True})
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock)
|
||||
|
||||
with app.test_request_context("/info", method="GET"):
|
||||
resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo")
|
||||
|
||||
assert resp == {"info": True}
|
||||
service_mock.assert_called_once_with("tenant-9", "demo")
|
||||
|
||||
|
||||
def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-cred")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-cred")
|
||||
service_mock = MagicMock(return_value=[{"cred": 1}])
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
"get_builtin_tool_provider_credentials",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
with app.test_request_context("/creds", method="GET"):
|
||||
resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo")
|
||||
|
||||
assert resp == [{"cred": 1}]
|
||||
service_mock.assert_called_once_with(tenant_id="tenant-cred", provider_name="demo")
|
||||
|
||||
|
||||
def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-10")
|
||||
service_mock = MagicMock(return_value={"schema": "ok"})
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock)
|
||||
|
||||
with app.test_request_context("/remote?url=https://example.com/"):
|
||||
resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get()
|
||||
|
||||
assert resp == {"schema": "ok"}
|
||||
service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/")
|
||||
|
||||
|
||||
def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-11")
|
||||
service_mock = MagicMock(return_value=[{"tool": "t"}])
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock)
|
||||
|
||||
with app.test_request_context("/tools?provider=foo"):
|
||||
resp = controller_module.ToolApiProviderListToolsApi().get()
|
||||
|
||||
assert resp == [{"tool": "t"}]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-11", "foo")
|
||||
|
||||
|
||||
def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-12")
|
||||
service_mock = MagicMock(return_value={"provider": "foo"})
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock)
|
||||
|
||||
with app.test_request_context("/get?provider=foo"):
|
||||
resp = controller_module.ToolApiProviderGetApi().get()
|
||||
|
||||
assert resp == {"provider": "foo"}
|
||||
service_mock.assert_called_once_with(user.id, "tenant-12", "foo")
|
||||
|
||||
|
||||
def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-13")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-13")
|
||||
service_mock = MagicMock(return_value={"schema": True})
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
"list_builtin_provider_credentials_schema",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
with app.test_request_context("/schema", method="GET"):
|
||||
resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get(
|
||||
provider="demo", credential_type="api-key"
|
||||
)
|
||||
|
||||
assert resp == {"schema": True}
|
||||
service_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf")
|
||||
tool_service = MagicMock(return_value={"wf": 1})
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
"get_workflow_tool_by_tool_id",
|
||||
tool_service,
|
||||
)
|
||||
|
||||
tool_id = "00000000-0000-0000-0000-000000000001"
|
||||
with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"):
|
||||
resp = controller_module.ToolWorkflowProviderGetApi().get()
|
||||
|
||||
assert resp == {"wf": 1}
|
||||
tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id)
|
||||
|
||||
|
||||
def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf2")
|
||||
service_mock = MagicMock(return_value={"app": 1})
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
"get_workflow_tool_by_app_id",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
app_id = "00000000-0000-0000-0000-000000000002"
|
||||
with app.test_request_context(f"/workflow?workflow_app_id={app_id}"):
|
||||
resp = controller_module.ToolWorkflowProviderGetApi().get()
|
||||
|
||||
assert resp == {"app": 1}
|
||||
service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id)
|
||||
|
||||
|
||||
def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf3")
|
||||
service_mock = MagicMock(return_value=[{"id": 1}])
|
||||
monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock)
|
||||
|
||||
tool_id = "00000000-0000-0000-0000-000000000003"
|
||||
with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"):
|
||||
resp = controller_module.ToolWorkflowProviderListToolApi().get()
|
||||
|
||||
assert resp == [{"id": 1}]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id)
|
||||
|
||||
|
||||
def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-bt")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"})
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
"list_builtin_tools",
|
||||
MagicMock(return_value=[provider]),
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/builtin"):
|
||||
resp = controller_module.ToolBuiltinListApi().get()
|
||||
|
||||
assert resp == [{"name": "builtin"}]
|
||||
|
||||
|
||||
def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-api")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-api")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "api"})
|
||||
monkeypatch.setattr(
|
||||
controller_module.ApiToolManageService,
|
||||
"list_api_tools",
|
||||
MagicMock(return_value=[provider]),
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/api"):
|
||||
resp = controller_module.ToolApiListApi().get()
|
||||
|
||||
assert resp == [{"name": "api"}]
|
||||
|
||||
|
||||
def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf4")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "wf"})
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
"list_tenant_workflow_tools",
|
||||
MagicMock(return_value=[provider]),
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/workflow"):
|
||||
resp = controller_module.ToolWorkflowListApi().get()
|
||||
|
||||
assert resp == [{"name": "wf"}]
|
||||
|
||||
|
||||
def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-label")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-labels")
|
||||
monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"])
|
||||
|
||||
with app.test_request_context("/tool-labels"):
|
||||
resp = controller_module.ToolLabelsApi().get()
|
||||
|
||||
assert resp == ["a", "b"]
|
||||
8
api/uv.lock
generated
8
api/uv.lock
generated
@@ -1707,7 +1707,7 @@ dev = [
|
||||
{ name = "types-openpyxl", specifier = "~=3.1.5" },
|
||||
{ name = "types-pexpect", specifier = "~=4.9.0" },
|
||||
{ name = "types-protobuf", specifier = "~=5.29.1" },
|
||||
{ name = "types-psutil", specifier = "~=7.0.0" },
|
||||
{ name = "types-psutil", specifier = "~=7.2.2" },
|
||||
{ name = "types-psycopg2", specifier = "~=2.9.21" },
|
||||
{ name = "types-pygments", specifier = "~=2.19.0" },
|
||||
{ name = "types-pymysql", specifier = "~=1.1.0" },
|
||||
@@ -6508,11 +6508,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "types-psutil"
|
||||
version = "7.0.0.20251116"
|
||||
version = "7.2.2.20260130"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/47/ec/c1e9308b91582cad1d7e7d3007fd003ef45a62c2500f8219313df5fc3bba/types_psutil-7.0.0.20251116.tar.gz", hash = "sha256:92b5c78962e55ce1ed7b0189901a4409ece36ab9fd50c3029cca7e681c606c8a", size = 22192, upload-time = "2025-11-16T03:10:32.859Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/69/14/fc5fb0a6ddfadf68c27e254a02ececd4d5c7fdb0efcb7e7e917a183497fb/types_psutil-7.2.2.20260130.tar.gz", hash = "sha256:15b0ab69c52841cf9ce3c383e8480c620a4d13d6a8e22b16978ebddac5590950", size = 26535, upload-time = "2026-01-30T03:58:14.116Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/0e/11ba08a5375c21039ed5f8e6bba41e9452fb69f0e2f7ee05ed5cca2a2cdf/types_psutil-7.0.0.20251116-py3-none-any.whl", hash = "sha256:74c052de077c2024b85cd435e2cba971165fe92a5eace79cbeb821e776dbc047", size = 25376, upload-time = "2025-11-16T03:10:31.813Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/17/d7/60974b7e31545d3768d1770c5fe6e093182c3bfd819429b33133ba6b3e89/types_psutil-7.2.2.20260130-py3-none-any.whl", hash = "sha256:15523a3caa7b3ff03ac7f9b78a6470a59f88f48df1d74a39e70e06d2a99107da", size = 32876, upload-time = "2026-01-30T03:58:13.172Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/bin/bash
|
||||
set -x
|
||||
set -euxo pipefail
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
cd "$SCRIPT_DIR/../.."
|
||||
|
||||
@@ -662,13 +662,14 @@ services:
|
||||
- "${IRIS_SUPER_SERVER_PORT:-1972}:1972"
|
||||
- "${IRIS_WEB_SERVER_PORT:-52773}:52773"
|
||||
volumes:
|
||||
- ./volumes/iris:/opt/iris
|
||||
- ./volumes/iris:/durable
|
||||
- ./iris/iris-init.script:/iris-init.script
|
||||
- ./iris/docker-entrypoint.sh:/custom-entrypoint.sh
|
||||
entrypoint: ["/custom-entrypoint.sh"]
|
||||
tty: true
|
||||
environment:
|
||||
TZ: ${IRIS_TIMEZONE:-UTC}
|
||||
ISC_DATA_DIRECTORY: /durable/iris
|
||||
|
||||
# Oracle vector database
|
||||
oracle:
|
||||
|
||||
@@ -1348,13 +1348,14 @@ services:
|
||||
- "${IRIS_SUPER_SERVER_PORT:-1972}:1972"
|
||||
- "${IRIS_WEB_SERVER_PORT:-52773}:52773"
|
||||
volumes:
|
||||
- ./volumes/iris:/opt/iris
|
||||
- ./volumes/iris:/durable
|
||||
- ./iris/iris-init.script:/iris-init.script
|
||||
- ./iris/docker-entrypoint.sh:/custom-entrypoint.sh
|
||||
entrypoint: ["/custom-entrypoint.sh"]
|
||||
tty: true
|
||||
environment:
|
||||
TZ: ${IRIS_TIMEZONE:-UTC}
|
||||
ISC_DATA_DIRECTORY: /durable/iris
|
||||
|
||||
# Oracle vector database
|
||||
oracle:
|
||||
|
||||
@@ -1,15 +1,33 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# IRIS configuration flag file
|
||||
IRIS_CONFIG_DONE="/opt/iris/.iris-configured"
|
||||
# IRIS configuration flag file (stored in durable directory to persist with data)
|
||||
IRIS_CONFIG_DONE="/durable/.iris-configured"
|
||||
|
||||
# Function to wait for IRIS to be ready
|
||||
wait_for_iris() {
|
||||
echo "Waiting for IRIS to be ready..."
|
||||
local max_attempts=30
|
||||
local attempt=1
|
||||
while [ "$attempt" -le "$max_attempts" ]; do
|
||||
if iris qlist IRIS 2>/dev/null | grep -q "running"; then
|
||||
echo "IRIS is ready."
|
||||
return 0
|
||||
fi
|
||||
echo "Attempt $attempt/$max_attempts: IRIS not ready yet, waiting..."
|
||||
sleep 2
|
||||
attempt=$((attempt + 1))
|
||||
done
|
||||
echo "ERROR: IRIS failed to start within expected time." >&2
|
||||
return 1
|
||||
}
|
||||
|
||||
# Function to configure IRIS
|
||||
configure_iris() {
|
||||
echo "Configuring IRIS for first-time setup..."
|
||||
|
||||
# Wait for IRIS to be fully started
|
||||
sleep 5
|
||||
wait_for_iris
|
||||
|
||||
# Execute the initialization script
|
||||
iris session IRIS < /iris-init.script
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import Cookies from 'js-cookie'
|
||||
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
|
||||
import { parseAsString, useQueryState } from 'nuqs'
|
||||
import { parseAsBoolean, useQueryState } from 'nuqs'
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import {
|
||||
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
|
||||
@@ -28,7 +28,7 @@ export const AppInitializer = ({
|
||||
const [init, setInit] = useState(false)
|
||||
const [oauthNewUser, setOauthNewUser] = useQueryState(
|
||||
'oauth_new_user',
|
||||
parseAsString.withOptions({ history: 'replace' }),
|
||||
parseAsBoolean.withOptions({ history: 'replace' }),
|
||||
)
|
||||
|
||||
const isSetupFinished = useCallback(async () => {
|
||||
@@ -46,7 +46,7 @@ export const AppInitializer = ({
|
||||
(async () => {
|
||||
const action = searchParams.get('action')
|
||||
|
||||
if (oauthNewUser === 'true') {
|
||||
if (oauthNewUser) {
|
||||
let utmInfo = null
|
||||
const utmInfoStr = Cookies.get('utm_info')
|
||||
if (utmInfoStr) {
|
||||
|
||||
@@ -62,19 +62,19 @@ const AppCard = ({
|
||||
{app.description}
|
||||
</div>
|
||||
</div>
|
||||
{canCreate && (
|
||||
{(canCreate || isTrialApp) && (
|
||||
<div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
|
||||
<div className={cn('grid h-8 w-full grid-cols-1 items-center space-x-2', isTrialApp && 'grid-cols-2')}>
|
||||
<Button variant="primary" onClick={() => onCreate()}>
|
||||
<PlusIcon className="mr-1 h-4 w-4" />
|
||||
<span className="text-xs">{t('newApp.useTemplate', { ns: 'app' })}</span>
|
||||
</Button>
|
||||
{isTrialApp && (
|
||||
<Button onClick={showTryAPPPanel(app.app_id)}>
|
||||
<RiInformation2Line className="mr-1 size-4" />
|
||||
<span>{t('appCard.try', { ns: 'explore' })}</span>
|
||||
<div className={cn('grid h-8 w-full grid-cols-1 items-center space-x-2', canCreate && 'grid-cols-2')}>
|
||||
{canCreate && (
|
||||
<Button variant="primary" onClick={() => onCreate()}>
|
||||
<PlusIcon className="mr-1 h-4 w-4" />
|
||||
<span className="text-xs">{t('newApp.useTemplate', { ns: 'app' })}</span>
|
||||
</Button>
|
||||
)}
|
||||
<Button onClick={showTryAPPPanel(app.app_id)}>
|
||||
<RiInformation2Line className="mr-1 size-4" />
|
||||
<span>{t('appCard.try', { ns: 'explore' })}</span>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -124,7 +124,7 @@ describe('CreateAppModal', () => {
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ }))
|
||||
|
||||
await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({
|
||||
name: 'My App',
|
||||
@@ -152,7 +152,7 @@ describe('CreateAppModal', () => {
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ }))
|
||||
|
||||
await waitFor(() => expect(mockCreateApp).toHaveBeenCalled())
|
||||
expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' })
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import type { ListChildComponentProps } from 'react-window'
|
||||
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
|
||||
import { RiArrowDownSLine, RiArrowRightSLine } from '@remixicon/react'
|
||||
import { useVirtualizer } from '@tanstack/react-virtual'
|
||||
import { memo, useCallback, useMemo, useRef, useState } from 'react'
|
||||
import { memo, useEffect, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { areEqual, FixedSizeList as List } from 'react-window'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Checkbox from '../../checkbox'
|
||||
import NotionIcon from '../../notion-icon'
|
||||
@@ -31,22 +32,6 @@ type NotionPageItem = {
|
||||
depth: number
|
||||
} & DataSourceNotionPage
|
||||
|
||||
type ItemProps = {
|
||||
virtualStart: number
|
||||
virtualSize: number
|
||||
current: NotionPageItem
|
||||
onToggle: (pageId: string) => void
|
||||
checkedIds: Set<string>
|
||||
disabledCheckedIds: Set<string>
|
||||
onCheck: (pageId: string) => void
|
||||
canPreview?: boolean
|
||||
onPreview: (pageId: string) => void
|
||||
listMapWithChildrenAndDescendants: NotionPageTreeMap
|
||||
searchValue: string
|
||||
previewPageId: string
|
||||
pagesMap: DataSourceNotionPageMap
|
||||
}
|
||||
|
||||
const recursivePushInParentDescendants = (
|
||||
pagesMap: DataSourceNotionPageMap,
|
||||
listTreeMap: NotionPageTreeMap,
|
||||
@@ -84,22 +69,34 @@ const recursivePushInParentDescendants = (
|
||||
}
|
||||
}
|
||||
|
||||
const ItemComponent = ({
|
||||
virtualStart,
|
||||
virtualSize,
|
||||
current,
|
||||
onToggle,
|
||||
checkedIds,
|
||||
disabledCheckedIds,
|
||||
onCheck,
|
||||
canPreview,
|
||||
onPreview,
|
||||
listMapWithChildrenAndDescendants,
|
||||
searchValue,
|
||||
previewPageId,
|
||||
pagesMap,
|
||||
}: ItemProps) => {
|
||||
const ItemComponent = ({ index, style, data }: ListChildComponentProps<{
|
||||
dataList: NotionPageItem[]
|
||||
handleToggle: (index: number) => void
|
||||
checkedIds: Set<string>
|
||||
disabledCheckedIds: Set<string>
|
||||
handleCheck: (index: number) => void
|
||||
canPreview?: boolean
|
||||
handlePreview: (index: number) => void
|
||||
listMapWithChildrenAndDescendants: NotionPageTreeMap
|
||||
searchValue: string
|
||||
previewPageId: string
|
||||
pagesMap: DataSourceNotionPageMap
|
||||
}>) => {
|
||||
const { t } = useTranslation()
|
||||
const {
|
||||
dataList,
|
||||
handleToggle,
|
||||
checkedIds,
|
||||
disabledCheckedIds,
|
||||
handleCheck,
|
||||
canPreview,
|
||||
handlePreview,
|
||||
listMapWithChildrenAndDescendants,
|
||||
searchValue,
|
||||
previewPageId,
|
||||
pagesMap,
|
||||
} = data
|
||||
const current = dataList[index]
|
||||
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[current.page_id]
|
||||
const hasChild = currentWithChildrenAndDescendants.descendants.size > 0
|
||||
const ancestors = currentWithChildrenAndDescendants.ancestors
|
||||
@@ -112,7 +109,7 @@ const ItemComponent = ({
|
||||
<div
|
||||
className="mr-1 flex h-5 w-5 shrink-0 items-center justify-center rounded-md hover:bg-components-button-ghost-bg-hover"
|
||||
style={{ marginLeft: current.depth * 8 }}
|
||||
onClick={() => onToggle(current.page_id)}
|
||||
onClick={() => handleToggle(index)}
|
||||
>
|
||||
{
|
||||
current.expand
|
||||
@@ -135,21 +132,15 @@ const ItemComponent = ({
|
||||
return (
|
||||
<div
|
||||
className={cn('group flex cursor-pointer items-center rounded-md pl-2 pr-[2px] hover:bg-state-base-hover', previewPageId === current.page_id && 'bg-state-base-hover')}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 8,
|
||||
right: 8,
|
||||
width: 'calc(100% - 16px)',
|
||||
height: virtualSize,
|
||||
transform: `translateY(${virtualStart + 8}px)`,
|
||||
}}
|
||||
style={{ ...style, top: style.top as number + 8, left: 8, right: 8, width: 'calc(100% - 16px)' }}
|
||||
>
|
||||
<Checkbox
|
||||
className="mr-2 shrink-0"
|
||||
checked={checkedIds.has(current.page_id)}
|
||||
disabled={disabled}
|
||||
onCheck={() => onCheck(current.page_id)}
|
||||
onCheck={() => {
|
||||
handleCheck(index)
|
||||
}}
|
||||
/>
|
||||
{!searchValue && renderArrow()}
|
||||
<NotionIcon
|
||||
@@ -169,7 +160,7 @@ const ItemComponent = ({
|
||||
className="ml-1 hidden h-6 shrink-0 cursor-pointer items-center rounded-md border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-2 text-xs
|
||||
font-medium leading-4 text-components-button-secondary-text shadow-xs shadow-shadow-shadow-3 backdrop-blur-[10px]
|
||||
hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover group-hover:flex"
|
||||
onClick={() => onPreview(current.page_id)}
|
||||
onClick={() => handlePreview(index)}
|
||||
>
|
||||
{t('dataSource.notion.selector.preview', { ns: 'common' })}
|
||||
</div>
|
||||
@@ -188,7 +179,7 @@ const ItemComponent = ({
|
||||
</div>
|
||||
)
|
||||
}
|
||||
const Item = memo(ItemComponent)
|
||||
const Item = memo(ItemComponent, areEqual)
|
||||
|
||||
const PageSelector = ({
|
||||
value,
|
||||
@@ -202,10 +193,31 @@ const PageSelector = ({
|
||||
onPreview,
|
||||
}: PageSelectorProps) => {
|
||||
const { t } = useTranslation()
|
||||
const parentRef = useRef<HTMLDivElement>(null)
|
||||
const [expandedIds, setExpandedIds] = useState<Set<string>>(() => new Set())
|
||||
const [dataList, setDataList] = useState<NotionPageItem[]>([])
|
||||
const [localPreviewPageId, setLocalPreviewPageId] = useState('')
|
||||
|
||||
useEffect(() => {
|
||||
setDataList(list.filter(item => item.parent_id === 'root' || !pagesMap[item.parent_id]).map((item) => {
|
||||
return {
|
||||
...item,
|
||||
expand: false,
|
||||
depth: 0,
|
||||
}
|
||||
}))
|
||||
}, [list])
|
||||
|
||||
const searchDataList = list.filter((item) => {
|
||||
return item.page_name.includes(searchValue)
|
||||
}).map((item) => {
|
||||
return {
|
||||
...item,
|
||||
expand: false,
|
||||
depth: 0,
|
||||
}
|
||||
})
|
||||
const currentDataList = searchValue ? searchDataList : dataList
|
||||
const currentPreviewPageId = previewPageId === undefined ? localPreviewPageId : previewPageId
|
||||
|
||||
const listMapWithChildrenAndDescendants = useMemo(() => {
|
||||
return list.reduce((prev: NotionPageTreeMap, next: DataSourceNotionPage) => {
|
||||
const pageId = next.page_id
|
||||
@@ -217,89 +229,47 @@ const PageSelector = ({
|
||||
}, {})
|
||||
}, [list, pagesMap])
|
||||
|
||||
const childrenByParent = useMemo(() => {
|
||||
const map = new Map<string | null, DataSourceNotionPage[]>()
|
||||
for (const item of list) {
|
||||
const isRoot = item.parent_id === 'root' || !pagesMap[item.parent_id]
|
||||
const parentKey = isRoot ? null : item.parent_id
|
||||
const children = map.get(parentKey) || []
|
||||
children.push(item)
|
||||
map.set(parentKey, children)
|
||||
}
|
||||
return map
|
||||
}, [list, pagesMap])
|
||||
|
||||
const dataList = useMemo(() => {
|
||||
const result: NotionPageItem[] = []
|
||||
|
||||
const buildVisibleList = (parentId: string | null, depth: number) => {
|
||||
const items = childrenByParent.get(parentId) || []
|
||||
|
||||
for (const item of items) {
|
||||
const isExpanded = expandedIds.has(item.page_id)
|
||||
result.push({
|
||||
...item,
|
||||
expand: isExpanded,
|
||||
depth,
|
||||
})
|
||||
if (isExpanded) {
|
||||
buildVisibleList(item.page_id, depth + 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
buildVisibleList(null, 0)
|
||||
return result
|
||||
}, [childrenByParent, expandedIds])
|
||||
|
||||
const searchDataList = useMemo(() => list.filter((item) => {
|
||||
return item.page_name.includes(searchValue)
|
||||
}).map((item) => {
|
||||
return {
|
||||
...item,
|
||||
expand: false,
|
||||
depth: 0,
|
||||
}
|
||||
}), [list, searchValue])
|
||||
|
||||
const currentDataList = searchValue ? searchDataList : dataList
|
||||
const currentPreviewPageId = previewPageId === undefined ? localPreviewPageId : previewPageId
|
||||
|
||||
const virtualizer = useVirtualizer({
|
||||
count: currentDataList.length,
|
||||
getScrollElement: () => parentRef.current,
|
||||
estimateSize: () => 28,
|
||||
overscan: 5,
|
||||
getItemKey: index => currentDataList[index].page_id,
|
||||
})
|
||||
|
||||
const handleToggle = useCallback((pageId: string) => {
|
||||
setExpandedIds((prev) => {
|
||||
const next = new Set(prev)
|
||||
if (prev.has(pageId)) {
|
||||
next.delete(pageId)
|
||||
const descendants = listMapWithChildrenAndDescendants[pageId]?.descendants
|
||||
if (descendants) {
|
||||
for (const descendantId of descendants)
|
||||
next.delete(descendantId)
|
||||
}
|
||||
}
|
||||
else {
|
||||
next.add(pageId)
|
||||
}
|
||||
return next
|
||||
})
|
||||
}, [listMapWithChildrenAndDescendants])
|
||||
|
||||
const handleCheck = useCallback((pageId: string) => {
|
||||
const handleToggle = (index: number) => {
|
||||
const current = dataList[index]
|
||||
const pageId = current.page_id
|
||||
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
|
||||
const descendantsIds = Array.from(currentWithChildrenAndDescendants.descendants)
|
||||
const childrenIds = Array.from(currentWithChildrenAndDescendants.children)
|
||||
let newDataList = []
|
||||
|
||||
if (current.expand) {
|
||||
current.expand = false
|
||||
|
||||
newDataList = dataList.filter(item => !descendantsIds.includes(item.page_id))
|
||||
}
|
||||
else {
|
||||
current.expand = true
|
||||
|
||||
newDataList = [
|
||||
...dataList.slice(0, index + 1),
|
||||
...childrenIds.map(item => ({
|
||||
...pagesMap[item],
|
||||
expand: false,
|
||||
depth: listMapWithChildrenAndDescendants[item].depth,
|
||||
})),
|
||||
...dataList.slice(index + 1),
|
||||
]
|
||||
}
|
||||
setDataList(newDataList)
|
||||
}
|
||||
|
||||
const copyValue = new Set(value)
|
||||
const handleCheck = (index: number) => {
|
||||
const current = currentDataList[index]
|
||||
const pageId = current.page_id
|
||||
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
|
||||
const copyValue = new Set(value)
|
||||
|
||||
if (copyValue.has(pageId)) {
|
||||
if (!searchValue) {
|
||||
for (const item of currentWithChildrenAndDescendants.descendants)
|
||||
copyValue.delete(item)
|
||||
}
|
||||
|
||||
copyValue.delete(pageId)
|
||||
}
|
||||
else {
|
||||
@@ -307,17 +277,22 @@ const PageSelector = ({
|
||||
for (const item of currentWithChildrenAndDescendants.descendants)
|
||||
copyValue.add(item)
|
||||
}
|
||||
|
||||
copyValue.add(pageId)
|
||||
}
|
||||
|
||||
onSelect(copyValue)
|
||||
}, [listMapWithChildrenAndDescendants, onSelect, searchValue, value])
|
||||
onSelect(new Set(copyValue))
|
||||
}
|
||||
|
||||
const handlePreview = (index: number) => {
|
||||
const current = currentDataList[index]
|
||||
const pageId = current.page_id
|
||||
|
||||
const handlePreview = useCallback((pageId: string) => {
|
||||
setLocalPreviewPageId(pageId)
|
||||
|
||||
if (onPreview)
|
||||
onPreview(pageId)
|
||||
}, [onPreview])
|
||||
}
|
||||
|
||||
if (!currentDataList.length) {
|
||||
return (
|
||||
@@ -328,41 +303,29 @@ const PageSelector = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={parentRef}
|
||||
<List
|
||||
className="py-2"
|
||||
style={{ height: 296, width: '100%', overflow: 'auto' }}
|
||||
height={296}
|
||||
itemCount={currentDataList.length}
|
||||
itemSize={28}
|
||||
width="100%"
|
||||
itemKey={(index, data) => data.dataList[index].page_id}
|
||||
itemData={{
|
||||
dataList: currentDataList,
|
||||
handleToggle,
|
||||
checkedIds: value,
|
||||
disabledCheckedIds: disabledValue,
|
||||
handleCheck,
|
||||
canPreview,
|
||||
handlePreview,
|
||||
listMapWithChildrenAndDescendants,
|
||||
searchValue,
|
||||
previewPageId: currentPreviewPageId,
|
||||
pagesMap,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
height: virtualizer.getTotalSize(),
|
||||
width: '100%',
|
||||
position: 'relative',
|
||||
}}
|
||||
>
|
||||
{virtualizer.getVirtualItems().map((virtualRow) => {
|
||||
const current = currentDataList[virtualRow.index]
|
||||
return (
|
||||
<Item
|
||||
key={virtualRow.key}
|
||||
virtualStart={virtualRow.start}
|
||||
virtualSize={virtualRow.size}
|
||||
current={current}
|
||||
onToggle={handleToggle}
|
||||
checkedIds={value}
|
||||
disabledCheckedIds={disabledValue}
|
||||
onCheck={handleCheck}
|
||||
canPreview={canPreview}
|
||||
onPreview={handlePreview}
|
||||
listMapWithChildrenAndDescendants={listMapWithChildrenAndDescendants}
|
||||
searchValue={searchValue}
|
||||
previewPageId={currentPreviewPageId}
|
||||
pagesMap={pagesMap}
|
||||
/>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
{Item}
|
||||
</List>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,18 +11,21 @@ import { recursivePushInParentDescendants } from './utils'
|
||||
|
||||
// Note: react-i18next uses global mock from web/vitest.setup.ts
|
||||
|
||||
// Mock @tanstack/react-virtual useVirtualizer hook - renders items directly for testing
|
||||
vi.mock('@tanstack/react-virtual', () => ({
|
||||
useVirtualizer: ({ count, getItemKey }: { count: number, getItemKey?: (index: number) => string }) => ({
|
||||
getVirtualItems: () =>
|
||||
Array.from({ length: count }).map((_, index) => ({
|
||||
index,
|
||||
key: getItemKey ? getItemKey(index) : index,
|
||||
start: index * 28,
|
||||
size: 28,
|
||||
})),
|
||||
getTotalSize: () => count * 28,
|
||||
}),
|
||||
// Mock react-window FixedSizeList - renders items directly for testing
|
||||
vi.mock('react-window', () => ({
|
||||
FixedSizeList: ({ children: ItemComponent, itemCount, itemData, itemKey }: any) => (
|
||||
<div data-testid="virtual-list">
|
||||
{Array.from({ length: itemCount }).map((_, index) => (
|
||||
<ItemComponent
|
||||
key={itemKey?.(index, itemData) || index}
|
||||
index={index}
|
||||
style={{ top: index * 28, left: 0, right: 0, width: '100%', position: 'absolute' }}
|
||||
data={itemData}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
),
|
||||
areEqual: (prevProps: any, nextProps: any) => prevProps === nextProps,
|
||||
}))
|
||||
|
||||
// Note: NotionIcon from @/app/components/base/ is NOT mocked - using real component per testing guidelines
|
||||
@@ -116,7 +119,7 @@ describe('PageSelector', () => {
|
||||
render(<PageSelector {...props} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('Test Page')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('virtual-list')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render empty state when list is empty', () => {
|
||||
@@ -131,7 +134,7 @@ describe('PageSelector', () => {
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument()
|
||||
expect(screen.queryByText('Test Page')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('virtual-list')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render items using FixedSizeList', () => {
|
||||
@@ -1163,7 +1166,7 @@ describe('PageSelector', () => {
|
||||
render(<PageSelector {...props} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('Test Page')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('virtual-list')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle special characters in page name', () => {
|
||||
@@ -1337,7 +1340,7 @@ describe('PageSelector', () => {
|
||||
render(<PageSelector {...props} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('Test Page')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('virtual-list')).toBeInTheDocument()
|
||||
if (propVariation.canPreview)
|
||||
expect(screen.getByText('common.dataSource.notion.selector.preview')).toBeInTheDocument()
|
||||
else
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
|
||||
import { useVirtualizer } from '@tanstack/react-virtual'
|
||||
import { useCallback, useMemo, useRef, useState } from 'react'
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { FixedSizeList as List } from 'react-window'
|
||||
import Item from './item'
|
||||
import { recursivePushInParentDescendants } from './utils'
|
||||
|
||||
@@ -45,16 +45,29 @@ const PageSelector = ({
|
||||
currentCredentialId,
|
||||
}: PageSelectorProps) => {
|
||||
const { t } = useTranslation()
|
||||
const parentRef = useRef<HTMLDivElement>(null)
|
||||
const [expandedIds, setExpandedIds] = useState<Set<string>>(() => new Set())
|
||||
const [dataList, setDataList] = useState<NotionPageItem[]>([])
|
||||
const [currentPreviewPageId, setCurrentPreviewPageId] = useState('')
|
||||
const prevCredentialIdRef = useRef(currentCredentialId)
|
||||
|
||||
// Reset expanded state when credential changes (render-time detection)
|
||||
if (prevCredentialIdRef.current !== currentCredentialId) {
|
||||
prevCredentialIdRef.current = currentCredentialId
|
||||
setExpandedIds(new Set())
|
||||
}
|
||||
useEffect(() => {
|
||||
setDataList(list.filter(item => item.parent_id === 'root' || !pagesMap[item.parent_id]).map((item) => {
|
||||
return {
|
||||
...item,
|
||||
expand: false,
|
||||
depth: 0,
|
||||
}
|
||||
}))
|
||||
}, [currentCredentialId])
|
||||
|
||||
const searchDataList = list.filter((item) => {
|
||||
return item.page_name.includes(searchValue)
|
||||
}).map((item) => {
|
||||
return {
|
||||
...item,
|
||||
expand: false,
|
||||
depth: 0,
|
||||
}
|
||||
})
|
||||
const currentDataList = searchValue ? searchDataList : dataList
|
||||
|
||||
const listMapWithChildrenAndDescendants = useMemo(() => {
|
||||
return list.reduce((prev: NotionPageTreeMap, next: DataSourceNotionPage) => {
|
||||
@@ -67,86 +80,39 @@ const PageSelector = ({
|
||||
}, {})
|
||||
}, [list, pagesMap])
|
||||
|
||||
// Pre-build children index for O(1) lookup instead of O(n) filter
|
||||
const childrenByParent = useMemo(() => {
|
||||
const map = new Map<string | null, DataSourceNotionPage[]>()
|
||||
for (const item of list) {
|
||||
const isRoot = item.parent_id === 'root' || !pagesMap[item.parent_id]
|
||||
const parentKey = isRoot ? null : item.parent_id
|
||||
const children = map.get(parentKey) || []
|
||||
children.push(item)
|
||||
map.set(parentKey, children)
|
||||
const handleToggle = useCallback((index: number) => {
|
||||
const current = dataList[index]
|
||||
const pageId = current.page_id
|
||||
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
|
||||
const descendantsIds = Array.from(currentWithChildrenAndDescendants.descendants)
|
||||
const childrenIds = Array.from(currentWithChildrenAndDescendants.children)
|
||||
let newDataList = []
|
||||
|
||||
if (current.expand) {
|
||||
current.expand = false
|
||||
|
||||
newDataList = dataList.filter(item => !descendantsIds.includes(item.page_id))
|
||||
}
|
||||
return map
|
||||
}, [list, pagesMap])
|
||||
else {
|
||||
current.expand = true
|
||||
|
||||
// Compute visible data list based on expanded state
|
||||
const dataList = useMemo(() => {
|
||||
const result: NotionPageItem[] = []
|
||||
|
||||
const buildVisibleList = (parentId: string | null, depth: number) => {
|
||||
const items = childrenByParent.get(parentId) || []
|
||||
|
||||
for (const item of items) {
|
||||
const isExpanded = expandedIds.has(item.page_id)
|
||||
result.push({
|
||||
...item,
|
||||
expand: isExpanded,
|
||||
depth,
|
||||
})
|
||||
if (isExpanded) {
|
||||
buildVisibleList(item.page_id, depth + 1)
|
||||
}
|
||||
}
|
||||
newDataList = [
|
||||
...dataList.slice(0, index + 1),
|
||||
...childrenIds.map(item => ({
|
||||
...pagesMap[item],
|
||||
expand: false,
|
||||
depth: listMapWithChildrenAndDescendants[item].depth,
|
||||
})),
|
||||
...dataList.slice(index + 1),
|
||||
]
|
||||
}
|
||||
setDataList(newDataList)
|
||||
}, [dataList, listMapWithChildrenAndDescendants, pagesMap])
|
||||
|
||||
buildVisibleList(null, 0)
|
||||
return result
|
||||
}, [childrenByParent, expandedIds])
|
||||
|
||||
const searchDataList = useMemo(() => list.filter((item) => {
|
||||
return item.page_name.includes(searchValue)
|
||||
}).map((item) => {
|
||||
return {
|
||||
...item,
|
||||
expand: false,
|
||||
depth: 0,
|
||||
}
|
||||
}), [list, searchValue])
|
||||
|
||||
const currentDataList = searchValue ? searchDataList : dataList
|
||||
|
||||
const virtualizer = useVirtualizer({
|
||||
count: currentDataList.length,
|
||||
getScrollElement: () => parentRef.current,
|
||||
estimateSize: () => 28,
|
||||
overscan: 5,
|
||||
getItemKey: index => currentDataList[index].page_id,
|
||||
})
|
||||
|
||||
// Stable callback - no dependencies on dataList
|
||||
const handleToggle = useCallback((pageId: string) => {
|
||||
setExpandedIds((prev) => {
|
||||
const next = new Set(prev)
|
||||
if (prev.has(pageId)) {
|
||||
// Collapse: remove current and all descendants
|
||||
next.delete(pageId)
|
||||
const descendants = listMapWithChildrenAndDescendants[pageId]?.descendants
|
||||
if (descendants) {
|
||||
for (const descendantId of descendants)
|
||||
next.delete(descendantId)
|
||||
}
|
||||
}
|
||||
else {
|
||||
next.add(pageId)
|
||||
}
|
||||
return next
|
||||
})
|
||||
}, [listMapWithChildrenAndDescendants])
|
||||
|
||||
// Stable callback - uses pageId parameter instead of index
|
||||
const handleCheck = useCallback((pageId: string) => {
|
||||
const handleCheck = useCallback((index: number) => {
|
||||
const copyValue = new Set(checkedIds)
|
||||
const current = currentDataList[index]
|
||||
const pageId = current.page_id
|
||||
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
|
||||
|
||||
if (copyValue.has(pageId)) {
|
||||
@@ -154,6 +120,7 @@ const PageSelector = ({
|
||||
for (const item of currentWithChildrenAndDescendants.descendants)
|
||||
copyValue.delete(item)
|
||||
}
|
||||
|
||||
copyValue.delete(pageId)
|
||||
}
|
||||
else {
|
||||
@@ -171,15 +138,18 @@ const PageSelector = ({
|
||||
}
|
||||
}
|
||||
|
||||
onSelect(copyValue)
|
||||
}, [checkedIds, isMultipleChoice, listMapWithChildrenAndDescendants, onSelect, searchValue])
|
||||
onSelect(new Set(copyValue))
|
||||
}, [currentDataList, isMultipleChoice, listMapWithChildrenAndDescendants, onSelect, searchValue, checkedIds])
|
||||
|
||||
const handlePreview = useCallback((index: number) => {
|
||||
const current = currentDataList[index]
|
||||
const pageId = current.page_id
|
||||
|
||||
// Stable callback
|
||||
const handlePreview = useCallback((pageId: string) => {
|
||||
setCurrentPreviewPageId(pageId)
|
||||
|
||||
if (onPreview)
|
||||
onPreview(pageId)
|
||||
}, [onPreview])
|
||||
}, [currentDataList, onPreview])
|
||||
|
||||
if (!currentDataList.length) {
|
||||
return (
|
||||
@@ -190,42 +160,30 @@ const PageSelector = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={parentRef}
|
||||
<List
|
||||
className="py-2"
|
||||
style={{ height: 296, width: '100%', overflow: 'auto' }}
|
||||
height={296}
|
||||
itemCount={currentDataList.length}
|
||||
itemSize={28}
|
||||
width="100%"
|
||||
itemKey={(index, data) => data.dataList[index].page_id}
|
||||
itemData={{
|
||||
dataList: currentDataList,
|
||||
handleToggle,
|
||||
checkedIds,
|
||||
disabledCheckedIds: disabledValue,
|
||||
handleCheck,
|
||||
canPreview,
|
||||
handlePreview,
|
||||
listMapWithChildrenAndDescendants,
|
||||
searchValue,
|
||||
previewPageId: currentPreviewPageId,
|
||||
pagesMap,
|
||||
isMultipleChoice,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
height: virtualizer.getTotalSize(),
|
||||
width: '100%',
|
||||
position: 'relative',
|
||||
}}
|
||||
>
|
||||
{virtualizer.getVirtualItems().map((virtualRow) => {
|
||||
const current = currentDataList[virtualRow.index]
|
||||
return (
|
||||
<Item
|
||||
key={virtualRow.key}
|
||||
virtualStart={virtualRow.start}
|
||||
virtualSize={virtualRow.size}
|
||||
current={current}
|
||||
onToggle={handleToggle}
|
||||
checkedIds={checkedIds}
|
||||
disabledCheckedIds={disabledValue}
|
||||
onCheck={handleCheck}
|
||||
canPreview={canPreview}
|
||||
onPreview={handlePreview}
|
||||
listMapWithChildrenAndDescendants={listMapWithChildrenAndDescendants}
|
||||
searchValue={searchValue}
|
||||
previewPageId={currentPreviewPageId}
|
||||
pagesMap={pagesMap}
|
||||
isMultipleChoice={isMultipleChoice}
|
||||
/>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
{Item}
|
||||
</List>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import type { ListChildComponentProps } from 'react-window'
|
||||
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
|
||||
import { RiArrowDownSLine, RiArrowRightSLine } from '@remixicon/react'
|
||||
import { memo } from 'react'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { areEqual } from 'react-window'
|
||||
import Checkbox from '@/app/components/base/checkbox'
|
||||
import NotionIcon from '@/app/components/base/notion-icon'
|
||||
import Radio from '@/app/components/base/radio/ui'
|
||||
@@ -21,40 +23,36 @@ type NotionPageItem = {
|
||||
depth: number
|
||||
} & DataSourceNotionPage
|
||||
|
||||
type ItemProps = {
|
||||
virtualStart: number
|
||||
virtualSize: number
|
||||
current: NotionPageItem
|
||||
onToggle: (pageId: string) => void
|
||||
const Item = ({ index, style, data }: ListChildComponentProps<{
|
||||
dataList: NotionPageItem[]
|
||||
handleToggle: (index: number) => void
|
||||
checkedIds: Set<string>
|
||||
disabledCheckedIds: Set<string>
|
||||
onCheck: (pageId: string) => void
|
||||
handleCheck: (index: number) => void
|
||||
canPreview?: boolean
|
||||
onPreview: (pageId: string) => void
|
||||
handlePreview: (index: number) => void
|
||||
listMapWithChildrenAndDescendants: NotionPageTreeMap
|
||||
searchValue: string
|
||||
previewPageId: string
|
||||
pagesMap: DataSourceNotionPageMap
|
||||
isMultipleChoice?: boolean
|
||||
}
|
||||
|
||||
const Item = ({
|
||||
virtualStart,
|
||||
virtualSize,
|
||||
current,
|
||||
onToggle,
|
||||
checkedIds,
|
||||
disabledCheckedIds,
|
||||
onCheck,
|
||||
canPreview,
|
||||
onPreview,
|
||||
listMapWithChildrenAndDescendants,
|
||||
searchValue,
|
||||
previewPageId,
|
||||
pagesMap,
|
||||
isMultipleChoice,
|
||||
}: ItemProps) => {
|
||||
}>) => {
|
||||
const { t } = useTranslation()
|
||||
const {
|
||||
dataList,
|
||||
handleToggle,
|
||||
checkedIds,
|
||||
disabledCheckedIds,
|
||||
handleCheck,
|
||||
canPreview,
|
||||
handlePreview,
|
||||
listMapWithChildrenAndDescendants,
|
||||
searchValue,
|
||||
previewPageId,
|
||||
pagesMap,
|
||||
isMultipleChoice,
|
||||
} = data
|
||||
const current = dataList[index]
|
||||
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[current.page_id]
|
||||
const hasChild = currentWithChildrenAndDescendants.descendants.size > 0
|
||||
const ancestors = currentWithChildrenAndDescendants.ancestors
|
||||
@@ -67,7 +65,7 @@ const Item = ({
|
||||
<div
|
||||
className="mr-1 flex h-5 w-5 shrink-0 items-center justify-center rounded-md hover:bg-components-button-ghost-bg-hover"
|
||||
style={{ marginLeft: current.depth * 8 }}
|
||||
onClick={() => onToggle(current.page_id)}
|
||||
onClick={() => handleToggle(index)}
|
||||
>
|
||||
{
|
||||
current.expand
|
||||
@@ -90,15 +88,7 @@ const Item = ({
|
||||
return (
|
||||
<div
|
||||
className={cn('group flex cursor-pointer items-center rounded-md pl-2 pr-[2px] hover:bg-state-base-hover', previewPageId === current.page_id && 'bg-state-base-hover')}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 8,
|
||||
right: 8,
|
||||
width: 'calc(100% - 16px)',
|
||||
height: virtualSize,
|
||||
transform: `translateY(${virtualStart + 8}px)`,
|
||||
}}
|
||||
style={{ ...style, top: style.top as number + 8, left: 8, right: 8, width: 'calc(100% - 16px)' }}
|
||||
>
|
||||
{isMultipleChoice
|
||||
? (
|
||||
@@ -106,7 +96,9 @@ const Item = ({
|
||||
className="mr-2 shrink-0"
|
||||
checked={checkedIds.has(current.page_id)}
|
||||
disabled={disabled}
|
||||
onCheck={() => onCheck(current.page_id)}
|
||||
onCheck={() => {
|
||||
handleCheck(index)
|
||||
}}
|
||||
/>
|
||||
)
|
||||
: (
|
||||
@@ -114,7 +106,9 @@ const Item = ({
|
||||
className="mr-2 shrink-0"
|
||||
isChecked={checkedIds.has(current.page_id)}
|
||||
disabled={disabled}
|
||||
onCheck={() => onCheck(current.page_id)}
|
||||
onCheck={() => {
|
||||
handleCheck(index)
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{!searchValue && renderArrow()}
|
||||
@@ -135,7 +129,7 @@ const Item = ({
|
||||
className="ml-1 hidden h-6 shrink-0 cursor-pointer items-center rounded-md border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-2 text-xs
|
||||
font-medium leading-4 text-components-button-secondary-text shadow-xs shadow-shadow-shadow-3 backdrop-blur-[10px]
|
||||
hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover group-hover:flex"
|
||||
onClick={() => onPreview(current.page_id)}
|
||||
onClick={() => handlePreview(index)}
|
||||
>
|
||||
{t('dataSource.notion.selector.preview', { ns: 'common' })}
|
||||
</div>
|
||||
@@ -155,4 +149,4 @@ const Item = ({
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(Item)
|
||||
export default React.memo(Item, areEqual)
|
||||
|
||||
@@ -3,8 +3,6 @@ import type { FC, ReactNode } from 'react'
|
||||
import type { SliceProps } from './type'
|
||||
import { autoUpdate, flip, FloatingFocusManager, offset, shift, useDismiss, useFloating, useHover, useInteractions, useRole } from '@floating-ui/react'
|
||||
import { RiDeleteBinLine } from '@remixicon/react'
|
||||
// @ts-expect-error no types available
|
||||
import lineClamp from 'line-clamp'
|
||||
import { useState } from 'react'
|
||||
import ActionButton, { ActionButtonState } from '@/app/components/base/action-button'
|
||||
import { cn } from '@/utils/classnames'
|
||||
@@ -58,12 +56,8 @@ export const EditSlice: FC<EditSliceProps> = (props) => {
|
||||
<>
|
||||
<SliceContainer
|
||||
{...rest}
|
||||
className={cn('mr-0 block', className)}
|
||||
ref={(ref) => {
|
||||
refs.setReference(ref)
|
||||
if (ref)
|
||||
lineClamp(ref, 4)
|
||||
}}
|
||||
className={cn('mr-0 line-clamp-4 block', className)}
|
||||
ref={refs.setReference}
|
||||
{...getReferenceProps()}
|
||||
>
|
||||
<SliceLabel
|
||||
|
||||
@@ -74,11 +74,15 @@ const AppCard = ({
|
||||
</div>
|
||||
{isExplore && (canCreate || isTrialApp) && (
|
||||
<div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
|
||||
<div className={cn('grid h-8 w-full grid-cols-2 space-x-2')}>
|
||||
<Button variant="primary" className="h-7" onClick={() => onCreate()}>
|
||||
<PlusIcon className="mr-1 h-4 w-4" />
|
||||
<span className="text-xs">{t('appCard.addToWorkspace', { ns: 'explore' })}</span>
|
||||
</Button>
|
||||
<div className={cn('grid h-8 w-full grid-cols-1 space-x-2', canCreate && 'grid-cols-2')}>
|
||||
{
|
||||
canCreate && (
|
||||
<Button variant="primary" className="h-7" onClick={() => onCreate()}>
|
||||
<PlusIcon className="mr-1 h-4 w-4" />
|
||||
<span className="text-xs">{t('appCard.addToWorkspace', { ns: 'explore' })}</span>
|
||||
</Button>
|
||||
)
|
||||
}
|
||||
<Button className="h-7" onClick={showTryAPPPanel(app.app_id)}>
|
||||
<RiInformation2Line className="mr-1 size-4" />
|
||||
<span>{t('appCard.try', { ns: 'explore' })}</span>
|
||||
|
||||
@@ -138,7 +138,7 @@ describe('CreateAppModal', () => {
|
||||
setup({ appName: 'My App', isEditModal: false })
|
||||
|
||||
expect(screen.getByText('explore.appCustomize.title:{"name":"My App"}')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
@@ -146,7 +146,7 @@ describe('CreateAppModal', () => {
|
||||
setup({ isEditModal: true, appMode: AppModeEnum.CHAT, max_active_requests: 5 })
|
||||
|
||||
expect(screen.getByText('app.editAppTitle')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: /common\.operation\.save/ })).toBeInTheDocument()
|
||||
expect(screen.getByRole('switch')).toBeInTheDocument()
|
||||
expect((screen.getByRole('spinbutton') as HTMLInputElement).value).toBe('5')
|
||||
})
|
||||
@@ -166,7 +166,7 @@ describe('CreateAppModal', () => {
|
||||
it('should not render modal content when hidden', () => {
|
||||
setup({ show: false })
|
||||
|
||||
expect(screen.queryByRole('button', { name: 'common.operation.create' })).not.toBeInTheDocument()
|
||||
expect(screen.queryByRole('button', { name: /common\.operation\.create/ })).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -175,13 +175,13 @@ describe('CreateAppModal', () => {
|
||||
it('should disable confirm action when confirmDisabled is true', () => {
|
||||
setup({ confirmDisabled: true })
|
||||
|
||||
expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled()
|
||||
expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should disable confirm action when appName is empty', () => {
|
||||
setup({ appName: ' ' })
|
||||
|
||||
expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled()
|
||||
expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -245,7 +245,7 @@ describe('CreateAppModal', () => {
|
||||
setup({ isEditModal: false })
|
||||
|
||||
expect(screen.getByText('billing.apps.fullTip2')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled()
|
||||
expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should allow saving when apps quota is reached in edit mode', () => {
|
||||
@@ -257,7 +257,7 @@ describe('CreateAppModal', () => {
|
||||
setup({ isEditModal: true })
|
||||
|
||||
expect(screen.queryByText('billing.apps.fullTip2')).not.toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeEnabled()
|
||||
expect(screen.getByRole('button', { name: /common\.operation\.save/ })).toBeEnabled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -384,7 +384,7 @@ describe('CreateAppModal', () => {
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.iconPicker.ok' }))
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@@ -433,7 +433,7 @@ describe('CreateAppModal', () => {
|
||||
expect(screen.queryByRole('button', { name: 'app.iconPicker.cancel' })).not.toBeInTheDocument()
|
||||
|
||||
// Submit and verify the payload uses the original icon (cancel reverts to props)
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@@ -471,7 +471,7 @@ describe('CreateAppModal', () => {
|
||||
appIconBackground: '#000000',
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@@ -495,7 +495,7 @@ describe('CreateAppModal', () => {
|
||||
const { onConfirm } = setup({ appDescription: 'Old description' })
|
||||
|
||||
fireEvent.change(screen.getByPlaceholderText('app.newApp.appDescriptionPlaceholder'), { target: { value: 'Updated description' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@@ -512,7 +512,7 @@ describe('CreateAppModal', () => {
|
||||
appIconBackground: null,
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@@ -536,7 +536,7 @@ describe('CreateAppModal', () => {
|
||||
fireEvent.click(screen.getByRole('switch'))
|
||||
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '12' } })
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@@ -551,7 +551,7 @@ describe('CreateAppModal', () => {
|
||||
it('should omit max_active_requests when input is empty', () => {
|
||||
const { onConfirm } = setup({ isEditModal: true, max_active_requests: null })
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@@ -564,7 +564,7 @@ describe('CreateAppModal', () => {
|
||||
const { onConfirm } = setup({ isEditModal: true, max_active_requests: null })
|
||||
|
||||
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: 'abc' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ }))
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(300)
|
||||
})
|
||||
@@ -576,7 +576,7 @@ describe('CreateAppModal', () => {
|
||||
it('should show toast error and not submit when name becomes empty before debounced submit runs', () => {
|
||||
const { onConfirm, onHide } = setup({ appName: 'My App' })
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
|
||||
fireEvent.change(screen.getByPlaceholderText('app.newApp.appNamePlaceholder'), { target: { value: ' ' } })
|
||||
|
||||
act(() => {
|
||||
|
||||
@@ -16,6 +16,14 @@ vi.mock('react-i18next', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/config', async (importOriginal) => {
|
||||
const actual = await importOriginal() as object
|
||||
return {
|
||||
...actual,
|
||||
IS_CLOUD_EDITION: true,
|
||||
}
|
||||
})
|
||||
|
||||
const mockUseGetTryAppInfo = vi.fn()
|
||||
|
||||
vi.mock('@/service/use-try-app', () => ({
|
||||
|
||||
@@ -14,6 +14,14 @@ vi.mock('react-i18next', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/config', async (importOriginal) => {
|
||||
const actual = await importOriginal() as object
|
||||
return {
|
||||
...actual,
|
||||
IS_CLOUD_EDITION: true,
|
||||
}
|
||||
})
|
||||
|
||||
describe('Tab', () => {
|
||||
afterEach(() => {
|
||||
cleanup()
|
||||
|
||||
@@ -81,4 +81,205 @@ describe('CommandSelector', () => {
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith('/zen')
|
||||
})
|
||||
|
||||
it('should show all slash commands when no filter provided', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="/"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
// Should show the zen command from mock
|
||||
expect(screen.getByText('/zen')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should exclude slash action when in @ mode', () => {
|
||||
const actions = {
|
||||
...createActions(),
|
||||
slash: {
|
||||
key: '/',
|
||||
shortcut: '/',
|
||||
title: 'Slash',
|
||||
search: vi.fn(),
|
||||
description: '',
|
||||
} as ActionItem,
|
||||
}
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="@"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
// Should show @ commands but not /
|
||||
expect(screen.getByText('@app')).toBeInTheDocument()
|
||||
expect(screen.queryByText('/')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show all actions when no filter in @ mode', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="@"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('@app')).toBeInTheDocument()
|
||||
expect(screen.getByText('@plugin')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should set default command value when items exist but value does not', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
const onCommandValueChange = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="@"
|
||||
commandValue="non-existent"
|
||||
onCommandValueChange={onCommandValueChange}
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(onCommandValueChange).toHaveBeenCalledWith('@app')
|
||||
})
|
||||
|
||||
it('should NOT set command value when value already exists in items', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
const onCommandValueChange = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="@"
|
||||
commandValue="@app"
|
||||
onCommandValueChange={onCommandValueChange}
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(onCommandValueChange).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show no matching commands message when filter has no results', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter="nonexistent"
|
||||
originalQuery="@nonexistent"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show no matching commands for slash mode with no results', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter="nonexistentcommand"
|
||||
originalQuery="/nonexistentcommand"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render description for @ commands', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="@"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.actions.searchApplicationsDesc')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.gotoAnything.actions.searchPluginsDesc')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render group header for @ mode', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="@"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.selectSearchType')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render group header for slash mode', () => {
|
||||
const actions = createActions()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
render(
|
||||
<Command>
|
||||
<CommandSelector
|
||||
actions={actions}
|
||||
onCommandSelect={onSelect}
|
||||
searchFilter=""
|
||||
originalQuery="/"
|
||||
/>
|
||||
</Command>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.groups.commands')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
157
web/app/components/goto-anything/components/empty-state.spec.tsx
Normal file
157
web/app/components/goto-anything/components/empty-state.spec.tsx
Normal file
@@ -0,0 +1,157 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import EmptyState from './empty-state'
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { ns?: string, shortcuts?: string }) => {
|
||||
if (options?.shortcuts !== undefined)
|
||||
return `${key}:${options.shortcuts}`
|
||||
return `${options?.ns || 'common'}.${key}`
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('EmptyState', () => {
|
||||
describe('loading variant', () => {
|
||||
it('should render loading spinner', () => {
|
||||
render(<EmptyState variant="loading" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searching')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should have spinner animation class', () => {
|
||||
const { container } = render(<EmptyState variant="loading" />)
|
||||
|
||||
const spinner = container.querySelector('.animate-spin')
|
||||
expect(spinner).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('error variant', () => {
|
||||
it('should render error message when error has message', () => {
|
||||
const error = new Error('Connection failed')
|
||||
render(<EmptyState variant="error" error={error} />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searchFailed')).toBeInTheDocument()
|
||||
expect(screen.getByText('Connection failed')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render generic error when error has no message', () => {
|
||||
render(<EmptyState variant="error" error={null} />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searchTemporarilyUnavailable')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.gotoAnything.servicesUnavailableMessage')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render generic error when error is undefined', () => {
|
||||
render(<EmptyState variant="error" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searchTemporarilyUnavailable')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should have red error text styling', () => {
|
||||
const error = new Error('Test error')
|
||||
const { container } = render(<EmptyState variant="error" error={error} />)
|
||||
|
||||
const errorText = container.querySelector('.text-red-500')
|
||||
expect(errorText).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('default variant', () => {
|
||||
it('should render search title', () => {
|
||||
render(<EmptyState variant="default" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searchTitle')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render all hint messages', () => {
|
||||
render(<EmptyState variant="default" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searchHint')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.gotoAnything.commandHint')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.gotoAnything.slashHint')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('no-results variant', () => {
|
||||
describe('general search mode', () => {
|
||||
it('should render generic no results message', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="general" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show specific search hint with shortcuts', () => {
|
||||
const Actions = {
|
||||
app: { key: '@app', shortcut: '@app' },
|
||||
plugin: { key: '@plugin', shortcut: '@plugin' },
|
||||
} as unknown as Record<string, import('../actions/types').ActionItem>
|
||||
render(<EmptyState variant="no-results" searchMode="general" Actions={Actions} />)
|
||||
|
||||
expect(screen.getByText('gotoAnything.emptyState.trySpecificSearch:@app, @plugin')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('app search mode', () => {
|
||||
it('should render no apps found message', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="@app" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.emptyState.noAppsFound')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show try different term hint', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="@app" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.emptyState.tryDifferentTerm')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('plugin search mode', () => {
|
||||
it('should render no plugins found message', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="@plugin" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.emptyState.noPluginsFound')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('knowledge search mode', () => {
|
||||
it('should render no knowledge bases found message', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="@knowledge" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.emptyState.noKnowledgeBasesFound')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('node search mode', () => {
|
||||
it('should render no workflow nodes found message', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="@node" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.emptyState.noWorkflowNodesFound')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('unknown search mode', () => {
|
||||
it('should fallback to generic no results message', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="@unknown" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('default props', () => {
|
||||
it('should use general as default searchMode', () => {
|
||||
render(<EmptyState variant="no-results" />)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use empty object as default Actions', () => {
|
||||
render(<EmptyState variant="no-results" searchMode="general" />)
|
||||
|
||||
// Should show empty shortcuts
|
||||
expect(screen.getByText('gotoAnything.emptyState.trySpecificSearch:')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
105
web/app/components/goto-anything/components/empty-state.tsx
Normal file
105
web/app/components/goto-anything/components/empty-state.tsx
Normal file
@@ -0,0 +1,105 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { ActionItem } from '../actions/types'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
export type EmptyStateVariant = 'no-results' | 'error' | 'default' | 'loading'
|
||||
|
||||
export type EmptyStateProps = {
|
||||
variant: EmptyStateVariant
|
||||
searchMode?: string
|
||||
error?: Error | null
|
||||
Actions?: Record<string, ActionItem>
|
||||
}
|
||||
|
||||
const EmptyState: FC<EmptyStateProps> = ({
|
||||
variant,
|
||||
searchMode = 'general',
|
||||
error,
|
||||
Actions = {},
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
if (variant === 'loading') {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="h-4 w-4 animate-spin rounded-full border-2 border-gray-300 border-t-gray-600"></div>
|
||||
<span className="text-sm">{t('gotoAnything.searching', { ns: 'app' })}</span>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (variant === 'error') {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
|
||||
<div>
|
||||
<div className="text-sm font-medium text-red-500">
|
||||
{error?.message
|
||||
? t('gotoAnything.searchFailed', { ns: 'app' })
|
||||
: t('gotoAnything.searchTemporarilyUnavailable', { ns: 'app' })}
|
||||
</div>
|
||||
<div className="mt-1 text-xs text-text-quaternary">
|
||||
{error?.message || t('gotoAnything.servicesUnavailableMessage', { ns: 'app' })}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (variant === 'default') {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
|
||||
<div>
|
||||
<div className="text-sm font-medium">{t('gotoAnything.searchTitle', { ns: 'app' })}</div>
|
||||
<div className="mt-3 space-y-1 text-xs text-text-quaternary">
|
||||
<div>{t('gotoAnything.searchHint', { ns: 'app' })}</div>
|
||||
<div>{t('gotoAnything.commandHint', { ns: 'app' })}</div>
|
||||
<div>{t('gotoAnything.slashHint', { ns: 'app' })}</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// variant === 'no-results'
|
||||
const isCommandSearch = searchMode !== 'general'
|
||||
const commandType = isCommandSearch ? searchMode.replace('@', '') : ''
|
||||
|
||||
const getNoResultsMessage = () => {
|
||||
if (!isCommandSearch) {
|
||||
return t('gotoAnything.noResults', { ns: 'app' })
|
||||
}
|
||||
|
||||
const keyMap = {
|
||||
app: 'gotoAnything.emptyState.noAppsFound',
|
||||
plugin: 'gotoAnything.emptyState.noPluginsFound',
|
||||
knowledge: 'gotoAnything.emptyState.noKnowledgeBasesFound',
|
||||
node: 'gotoAnything.emptyState.noWorkflowNodesFound',
|
||||
} as const
|
||||
|
||||
return t(keyMap[commandType as keyof typeof keyMap] || 'gotoAnything.noResults', { ns: 'app' })
|
||||
}
|
||||
|
||||
const getHintMessage = () => {
|
||||
if (isCommandSearch) {
|
||||
return t('gotoAnything.emptyState.tryDifferentTerm', { ns: 'app' })
|
||||
}
|
||||
|
||||
const shortcuts = Object.values(Actions).map(action => action.shortcut).join(', ')
|
||||
return t('gotoAnything.emptyState.trySpecificSearch', { ns: 'app', shortcuts })
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
|
||||
<div>
|
||||
<div className="text-sm font-medium">{getNoResultsMessage()}</div>
|
||||
<div className="mt-1 text-xs text-text-quaternary">{getHintMessage()}</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default EmptyState
|
||||
273
web/app/components/goto-anything/components/footer.spec.tsx
Normal file
273
web/app/components/goto-anything/components/footer.spec.tsx
Normal file
@@ -0,0 +1,273 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import Footer from './footer'
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { ns?: string, count?: number, scope?: string }) => {
|
||||
if (options?.count !== undefined)
|
||||
return `${key}:${options.count}`
|
||||
if (options?.scope)
|
||||
return `${key}:${options.scope}`
|
||||
return `${options?.ns || 'common'}.${key}`
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('Footer', () => {
|
||||
describe('left content', () => {
|
||||
describe('when there are results', () => {
|
||||
it('should show result count', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={5}
|
||||
searchMode="general"
|
||||
isError={false}
|
||||
isCommandsMode={false}
|
||||
hasQuery={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('gotoAnything.resultCount:5')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show scope when not in general mode', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={3}
|
||||
searchMode="@app"
|
||||
isError={false}
|
||||
isCommandsMode={false}
|
||||
hasQuery={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('gotoAnything.inScope:app')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should NOT show scope when in general mode', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={3}
|
||||
searchMode="general"
|
||||
isError={false}
|
||||
isCommandsMode={false}
|
||||
hasQuery={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText(/inScope/)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('when there is an error', () => {
|
||||
it('should show error message', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={0}
|
||||
searchMode="general"
|
||||
isError={true}
|
||||
isCommandsMode={false}
|
||||
hasQuery={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.someServicesUnavailable')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should have red text styling', () => {
|
||||
const { container } = render(
|
||||
<Footer
|
||||
resultCount={0}
|
||||
searchMode="general"
|
||||
isError={true}
|
||||
isCommandsMode={false}
|
||||
hasQuery={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
const errorText = container.querySelector('.text-red-500')
|
||||
expect(errorText).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show error even with results', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={5}
|
||||
searchMode="general"
|
||||
isError={true}
|
||||
isCommandsMode={false}
|
||||
hasQuery={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.someServicesUnavailable')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('when no results and no error', () => {
|
||||
it('should show select to navigate in commands mode', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={0}
|
||||
searchMode="general"
|
||||
isError={false}
|
||||
isCommandsMode={true}
|
||||
hasQuery={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.selectToNavigate')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show searching when has query', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={0}
|
||||
searchMode="general"
|
||||
isError={false}
|
||||
isCommandsMode={false}
|
||||
hasQuery={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searching')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show start typing when no query', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={0}
|
||||
searchMode="general"
|
||||
isError={false}
|
||||
isCommandsMode={false}
|
||||
hasQuery={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.startTyping')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('right content', () => {
|
||||
describe('when there are results or error', () => {
|
||||
it('should show clear to search all when in specific mode', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={5}
|
||||
searchMode="@app"
|
||||
isError={false}
|
||||
isCommandsMode={false}
|
||||
hasQuery={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.clearToSearchAll')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show use @ for specific when in general mode', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={5}
|
||||
searchMode="general"
|
||||
isError={false}
|
||||
isCommandsMode={false}
|
||||
hasQuery={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.useAtForSpecific')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show same hint when error', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={0}
|
||||
searchMode="general"
|
||||
isError={true}
|
||||
isCommandsMode={false}
|
||||
hasQuery={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.useAtForSpecific')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('when no results and no error', () => {
|
||||
it('should show tips when has query', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={0}
|
||||
searchMode="general"
|
||||
isError={false}
|
||||
isCommandsMode={false}
|
||||
hasQuery={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.tips')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show tips when in commands mode', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={0}
|
||||
searchMode="general"
|
||||
isError={false}
|
||||
isCommandsMode={true}
|
||||
hasQuery={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.tips')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show press ESC to close when no query and not in commands mode', () => {
|
||||
render(
|
||||
<Footer
|
||||
resultCount={0}
|
||||
searchMode="general"
|
||||
isError={false}
|
||||
isCommandsMode={false}
|
||||
hasQuery={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.pressEscToClose')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('styling', () => {
|
||||
it('should have border and background classes', () => {
|
||||
const { container } = render(
|
||||
<Footer
|
||||
resultCount={0}
|
||||
searchMode="general"
|
||||
isError={false}
|
||||
isCommandsMode={false}
|
||||
hasQuery={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
const footer = container.firstChild
|
||||
expect(footer).toHaveClass('border-t', 'border-divider-subtle', 'bg-components-panel-bg-blur')
|
||||
})
|
||||
|
||||
it('should have flex layout for content', () => {
|
||||
const { container } = render(
|
||||
<Footer
|
||||
resultCount={0}
|
||||
searchMode="general"
|
||||
isError={false}
|
||||
isCommandsMode={false}
|
||||
hasQuery={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
const flexContainer = container.querySelector('.flex.items-center.justify-between')
|
||||
expect(flexContainer).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
90
web/app/components/goto-anything/components/footer.tsx
Normal file
90
web/app/components/goto-anything/components/footer.tsx
Normal file
@@ -0,0 +1,90 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
export type FooterProps = {
|
||||
resultCount: number
|
||||
searchMode: string
|
||||
isError: boolean
|
||||
isCommandsMode: boolean
|
||||
hasQuery: boolean
|
||||
}
|
||||
|
||||
const Footer: FC<FooterProps> = ({
|
||||
resultCount,
|
||||
searchMode,
|
||||
isError,
|
||||
isCommandsMode,
|
||||
hasQuery,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const renderLeftContent = () => {
|
||||
if (resultCount > 0 || isError) {
|
||||
if (isError) {
|
||||
return (
|
||||
<span className="text-red-500">
|
||||
{t('gotoAnything.someServicesUnavailable', { ns: 'app' })}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{t('gotoAnything.resultCount', { ns: 'app', count: resultCount })}
|
||||
{searchMode !== 'general' && (
|
||||
<span className="ml-2 opacity-60">
|
||||
{t('gotoAnything.inScope', { ns: 'app', scope: searchMode.replace('@', '') })}
|
||||
</span>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<span className="opacity-60">
|
||||
{(() => {
|
||||
if (isCommandsMode)
|
||||
return t('gotoAnything.selectToNavigate', { ns: 'app' })
|
||||
|
||||
if (hasQuery)
|
||||
return t('gotoAnything.searching', { ns: 'app' })
|
||||
|
||||
return t('gotoAnything.startTyping', { ns: 'app' })
|
||||
})()}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
const renderRightContent = () => {
|
||||
if (resultCount > 0 || isError) {
|
||||
return (
|
||||
<span className="opacity-60">
|
||||
{searchMode !== 'general'
|
||||
? t('gotoAnything.clearToSearchAll', { ns: 'app' })
|
||||
: t('gotoAnything.useAtForSpecific', { ns: 'app' })}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<span className="opacity-60">
|
||||
{hasQuery || isCommandsMode
|
||||
? t('gotoAnything.tips', { ns: 'app' })
|
||||
: t('gotoAnything.pressEscToClose', { ns: 'app' })}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="border-t border-divider-subtle bg-components-panel-bg-blur px-4 py-2 text-xs text-text-tertiary">
|
||||
<div className="flex min-h-[16px] items-center justify-between">
|
||||
<span>{renderLeftContent()}</span>
|
||||
{renderRightContent()}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default Footer
|
||||
14
web/app/components/goto-anything/components/index.ts
Normal file
14
web/app/components/goto-anything/components/index.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
export { default as EmptyState } from './empty-state'
|
||||
export type { EmptyStateProps, EmptyStateVariant } from './empty-state'
|
||||
|
||||
export { default as Footer } from './footer'
|
||||
export type { FooterProps } from './footer'
|
||||
|
||||
export { default as ResultItem } from './result-item'
|
||||
export type { ResultItemProps } from './result-item'
|
||||
|
||||
export { default as ResultList } from './result-list'
|
||||
export type { ResultListProps } from './result-list'
|
||||
|
||||
export { default as SearchInput } from './search-input'
|
||||
export type { SearchInputProps } from './search-input'
|
||||
38
web/app/components/goto-anything/components/result-item.tsx
Normal file
38
web/app/components/goto-anything/components/result-item.tsx
Normal file
@@ -0,0 +1,38 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { SearchResult } from '../actions/types'
|
||||
import { Command } from 'cmdk'
|
||||
|
||||
export type ResultItemProps = {
|
||||
result: SearchResult
|
||||
onSelect: () => void
|
||||
}
|
||||
|
||||
const ResultItem: FC<ResultItemProps> = ({ result, onSelect }) => {
|
||||
return (
|
||||
<Command.Item
|
||||
key={`${result.type}-${result.id}`}
|
||||
value={`${result.type}-${result.id}`}
|
||||
className="flex cursor-pointer items-center gap-3 rounded-md p-3 will-change-[background-color] hover:bg-state-base-hover aria-[selected=true]:bg-state-base-hover-alt data-[selected=true]:bg-state-base-hover-alt"
|
||||
onSelect={onSelect}
|
||||
>
|
||||
{result.icon}
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="truncate font-medium text-text-secondary">
|
||||
{result.title}
|
||||
</div>
|
||||
{result.description && (
|
||||
<div className="mt-0.5 truncate text-xs text-text-quaternary">
|
||||
{result.description}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="text-xs capitalize text-text-quaternary">
|
||||
{result.type}
|
||||
</div>
|
||||
</Command.Item>
|
||||
)
|
||||
}
|
||||
|
||||
export default ResultItem
|
||||
49
web/app/components/goto-anything/components/result-list.tsx
Normal file
49
web/app/components/goto-anything/components/result-list.tsx
Normal file
@@ -0,0 +1,49 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { SearchResult } from '../actions/types'
|
||||
import { Command } from 'cmdk'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ResultItem from './result-item'
|
||||
|
||||
export type ResultListProps = {
|
||||
groupedResults: Record<string, SearchResult[]>
|
||||
onSelect: (result: SearchResult) => void
|
||||
}
|
||||
|
||||
const ResultList: FC<ResultListProps> = ({ groupedResults, onSelect }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const getGroupHeading = (type: string) => {
|
||||
const typeMap = {
|
||||
'app': 'gotoAnything.groups.apps',
|
||||
'plugin': 'gotoAnything.groups.plugins',
|
||||
'knowledge': 'gotoAnything.groups.knowledgeBases',
|
||||
'workflow-node': 'gotoAnything.groups.workflowNodes',
|
||||
'command': 'gotoAnything.groups.commands',
|
||||
} as const
|
||||
return t(typeMap[type as keyof typeof typeMap] || `${type}s`, { ns: 'app' })
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{Object.entries(groupedResults).map(([type, results]) => (
|
||||
<Command.Group
|
||||
key={type}
|
||||
heading={getGroupHeading(type)}
|
||||
className="p-2 capitalize text-text-secondary"
|
||||
>
|
||||
{results.map(result => (
|
||||
<ResultItem
|
||||
key={`${result.type}-${result.id}`}
|
||||
result={result}
|
||||
onSelect={() => onSelect(result)}
|
||||
/>
|
||||
))}
|
||||
</Command.Group>
|
||||
))}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default ResultList
|
||||
@@ -0,0 +1,206 @@
|
||||
import type { ChangeEvent, KeyboardEvent, RefObject } from 'react'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import SearchInput from './search-input'
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { ns?: string }) => `${options?.ns || 'common'}.${key}`,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@remixicon/react', () => ({
|
||||
RiSearchLine: ({ className }: { className?: string }) => (
|
||||
<svg data-testid="search-icon" className={className} />
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/shortcuts-name', () => ({
|
||||
default: ({ keys, textColor }: { keys: string[], textColor: string }) => (
|
||||
<div data-testid="shortcuts-name" data-keys={keys.join(',')} data-color={textColor}>
|
||||
{keys.join('+')}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/input', async () => {
|
||||
const { forwardRef } = await import('react')
|
||||
|
||||
type MockInputProps = {
|
||||
value?: string
|
||||
placeholder?: string
|
||||
onChange?: (e: ChangeEvent<HTMLInputElement>) => void
|
||||
onKeyDown?: (e: KeyboardEvent<HTMLInputElement>) => void
|
||||
className?: string
|
||||
wrapperClassName?: string
|
||||
autoFocus?: boolean
|
||||
}
|
||||
|
||||
const MockInput = forwardRef<HTMLInputElement, MockInputProps>(
|
||||
({ value, placeholder, onChange, onKeyDown, className, wrapperClassName, autoFocus }, ref) => (
|
||||
<input
|
||||
ref={ref}
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
onChange={onChange}
|
||||
onKeyDown={onKeyDown}
|
||||
className={className}
|
||||
data-wrapper-class={wrapperClassName}
|
||||
autoFocus={autoFocus}
|
||||
data-testid="search-input"
|
||||
/>
|
||||
),
|
||||
)
|
||||
MockInput.displayName = 'MockInput'
|
||||
|
||||
return { default: MockInput }
|
||||
})
|
||||
|
||||
describe('SearchInput', () => {
|
||||
const defaultProps = {
|
||||
inputRef: { current: null } as RefObject<HTMLInputElement | null>,
|
||||
value: '',
|
||||
onChange: vi.fn(),
|
||||
searchMode: 'general',
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
it('should render search icon', () => {
|
||||
render(<SearchInput {...defaultProps} />)
|
||||
|
||||
expect(screen.getByTestId('search-icon')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render input field', () => {
|
||||
render(<SearchInput {...defaultProps} />)
|
||||
|
||||
expect(screen.getByTestId('search-input')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render shortcuts name', () => {
|
||||
render(<SearchInput {...defaultProps} />)
|
||||
|
||||
const shortcuts = screen.getByTestId('shortcuts-name')
|
||||
expect(shortcuts).toBeInTheDocument()
|
||||
expect(shortcuts).toHaveAttribute('data-keys', 'ctrl,K')
|
||||
expect(shortcuts).toHaveAttribute('data-color', 'secondary')
|
||||
})
|
||||
|
||||
it('should use provided placeholder', () => {
|
||||
render(<SearchInput {...defaultProps} placeholder="Custom placeholder" />)
|
||||
|
||||
expect(screen.getByPlaceholderText('Custom placeholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use default placeholder from translation', () => {
|
||||
render(<SearchInput {...defaultProps} />)
|
||||
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('mode label', () => {
|
||||
it('should NOT show mode badge in general mode', () => {
|
||||
render(<SearchInput {...defaultProps} searchMode="general" />)
|
||||
|
||||
expect(screen.queryByText('GENERAL')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show SCOPES label in scopes mode', () => {
|
||||
render(<SearchInput {...defaultProps} searchMode="scopes" />)
|
||||
|
||||
expect(screen.getByText('SCOPES')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show COMMANDS label in commands mode', () => {
|
||||
render(<SearchInput {...defaultProps} searchMode="commands" />)
|
||||
|
||||
expect(screen.getByText('COMMANDS')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show APP label in @app mode', () => {
|
||||
render(<SearchInput {...defaultProps} searchMode="@app" />)
|
||||
|
||||
expect(screen.getByText('APP')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show PLUGIN label in @plugin mode', () => {
|
||||
render(<SearchInput {...defaultProps} searchMode="@plugin" />)
|
||||
|
||||
expect(screen.getByText('PLUGIN')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show KNOWLEDGE label in @knowledge mode', () => {
|
||||
render(<SearchInput {...defaultProps} searchMode="@knowledge" />)
|
||||
|
||||
expect(screen.getByText('KNOWLEDGE')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show NODE label in @node mode', () => {
|
||||
render(<SearchInput {...defaultProps} searchMode="@node" />)
|
||||
|
||||
expect(screen.getByText('NODE')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should uppercase custom mode label', () => {
|
||||
render(<SearchInput {...defaultProps} searchMode="@custom" />)
|
||||
|
||||
expect(screen.getByText('CUSTOM')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('input interactions', () => {
|
||||
it('should call onChange when typing', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<SearchInput {...defaultProps} onChange={onChange} />)
|
||||
|
||||
const input = screen.getByTestId('search-input')
|
||||
fireEvent.change(input, { target: { value: 'test query' } })
|
||||
|
||||
expect(onChange).toHaveBeenCalledWith('test query')
|
||||
})
|
||||
|
||||
it('should call onKeyDown when pressing keys', () => {
|
||||
const onKeyDown = vi.fn()
|
||||
render(<SearchInput {...defaultProps} onKeyDown={onKeyDown} />)
|
||||
|
||||
const input = screen.getByTestId('search-input')
|
||||
fireEvent.keyDown(input, { key: 'Enter' })
|
||||
|
||||
expect(onKeyDown).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should render with provided value', () => {
|
||||
render(<SearchInput {...defaultProps} value="existing query" />)
|
||||
|
||||
expect(screen.getByDisplayValue('existing query')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should NOT throw when onKeyDown is undefined', () => {
|
||||
render(<SearchInput {...defaultProps} onKeyDown={undefined} />)
|
||||
|
||||
const input = screen.getByTestId('search-input')
|
||||
expect(() => fireEvent.keyDown(input, { key: 'Enter' })).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('styling', () => {
|
||||
it('should have search icon styling', () => {
|
||||
render(<SearchInput {...defaultProps} />)
|
||||
|
||||
const icon = screen.getByTestId('search-icon')
|
||||
expect(icon).toHaveClass('h-4', 'w-4', 'text-text-quaternary')
|
||||
})
|
||||
|
||||
it('should have mode badge styling when visible', () => {
|
||||
const { container } = render(<SearchInput {...defaultProps} searchMode="@app" />)
|
||||
|
||||
const badge = container.querySelector('.bg-gray-100')
|
||||
expect(badge).toBeInTheDocument()
|
||||
expect(badge).toHaveClass('rounded', 'px-2', 'text-xs', 'font-medium')
|
||||
})
|
||||
})
|
||||
})
|
||||
62
web/app/components/goto-anything/components/search-input.tsx
Normal file
62
web/app/components/goto-anything/components/search-input.tsx
Normal file
@@ -0,0 +1,62 @@
|
||||
'use client'
|
||||
|
||||
import type { FC, KeyboardEvent, RefObject } from 'react'
|
||||
import { RiSearchLine } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Input from '@/app/components/base/input'
|
||||
import ShortcutsName from '@/app/components/workflow/shortcuts-name'
|
||||
|
||||
export type SearchInputProps = {
|
||||
inputRef: RefObject<HTMLInputElement | null>
|
||||
value: string
|
||||
onChange: (value: string) => void
|
||||
onKeyDown?: (e: KeyboardEvent<HTMLInputElement>) => void
|
||||
searchMode: string
|
||||
placeholder?: string
|
||||
}
|
||||
|
||||
const SearchInput: FC<SearchInputProps> = ({
|
||||
inputRef,
|
||||
value,
|
||||
onChange,
|
||||
onKeyDown,
|
||||
searchMode,
|
||||
placeholder,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const getModeLabel = () => {
|
||||
if (searchMode === 'scopes')
|
||||
return 'SCOPES'
|
||||
else if (searchMode === 'commands')
|
||||
return 'COMMANDS'
|
||||
else
|
||||
return searchMode.replace('@', '').toUpperCase()
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-3 border-b border-divider-subtle bg-components-panel-bg-blur px-4 py-3">
|
||||
<RiSearchLine className="h-4 w-4 text-text-quaternary" />
|
||||
<div className="flex flex-1 items-center gap-2">
|
||||
<Input
|
||||
ref={inputRef}
|
||||
value={value}
|
||||
placeholder={placeholder || t('gotoAnything.searchPlaceholder', { ns: 'app' })}
|
||||
onChange={e => onChange(e.target.value)}
|
||||
onKeyDown={onKeyDown}
|
||||
className="flex-1 !border-0 !bg-transparent !shadow-none"
|
||||
wrapperClassName="flex-1 !border-0 !bg-transparent"
|
||||
autoFocus
|
||||
/>
|
||||
{searchMode !== 'general' && (
|
||||
<div className="flex items-center gap-1 rounded bg-gray-100 px-2 py-[2px] text-xs font-medium text-gray-700 dark:bg-gray-800 dark:text-gray-300">
|
||||
<span>{getModeLabel()}</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<ShortcutsName keys={['ctrl', 'K']} textColor="secondary" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default SearchInput
|
||||
@@ -2,7 +2,7 @@ import { render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { GotoAnythingProvider, useGotoAnythingContext } from './context'
|
||||
|
||||
let pathnameMock = '/'
|
||||
let pathnameMock: string | null | undefined = '/'
|
||||
vi.mock('next/navigation', () => ({
|
||||
usePathname: () => pathnameMock,
|
||||
}))
|
||||
@@ -57,4 +57,79 @@ describe('GotoAnythingProvider', () => {
|
||||
expect(screen.getByTestId('status')).toHaveTextContent('false|true')
|
||||
})
|
||||
})
|
||||
|
||||
it('should set both flags to false when pathname is null', async () => {
|
||||
pathnameMock = null
|
||||
|
||||
render(
|
||||
<GotoAnythingProvider>
|
||||
<ContextConsumer />
|
||||
</GotoAnythingProvider>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('status')).toHaveTextContent('false|false')
|
||||
})
|
||||
})
|
||||
|
||||
it('should set both flags to false when pathname is undefined', async () => {
|
||||
pathnameMock = undefined
|
||||
|
||||
render(
|
||||
<GotoAnythingProvider>
|
||||
<ContextConsumer />
|
||||
</GotoAnythingProvider>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('status')).toHaveTextContent('false|false')
|
||||
})
|
||||
})
|
||||
|
||||
it('should set both flags to false for regular paths', async () => {
|
||||
pathnameMock = '/apps'
|
||||
|
||||
render(
|
||||
<GotoAnythingProvider>
|
||||
<ContextConsumer />
|
||||
</GotoAnythingProvider>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('status')).toHaveTextContent('false|false')
|
||||
})
|
||||
})
|
||||
|
||||
it('should NOT match non-pipeline dataset paths', async () => {
|
||||
pathnameMock = '/datasets/abc/documents'
|
||||
|
||||
render(
|
||||
<GotoAnythingProvider>
|
||||
<ContextConsumer />
|
||||
</GotoAnythingProvider>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('status')).toHaveTextContent('false|false')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('useGotoAnythingContext', () => {
|
||||
it('should return default values when used outside provider', () => {
|
||||
const TestComponent = () => {
|
||||
const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext()
|
||||
return (
|
||||
<div data-testid="context">
|
||||
{String(isWorkflowPage)}
|
||||
|
|
||||
{String(isRagPipelinePage)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
render(<TestComponent />)
|
||||
|
||||
expect(screen.getByTestId('context')).toHaveTextContent('false|false')
|
||||
})
|
||||
})
|
||||
|
||||
11
web/app/components/goto-anything/hooks/index.ts
Normal file
11
web/app/components/goto-anything/hooks/index.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
export { useGotoAnythingModal } from './use-goto-anything-modal'
|
||||
export type { UseGotoAnythingModalReturn } from './use-goto-anything-modal'
|
||||
|
||||
export { useGotoAnythingNavigation } from './use-goto-anything-navigation'
|
||||
export type { UseGotoAnythingNavigationOptions, UseGotoAnythingNavigationReturn } from './use-goto-anything-navigation'
|
||||
|
||||
export { useGotoAnythingResults } from './use-goto-anything-results'
|
||||
export type { UseGotoAnythingResultsOptions, UseGotoAnythingResultsReturn } from './use-goto-anything-results'
|
||||
|
||||
export { useGotoAnythingSearch } from './use-goto-anything-search'
|
||||
export type { UseGotoAnythingSearchReturn } from './use-goto-anything-search'
|
||||
@@ -0,0 +1,291 @@
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { useGotoAnythingModal } from './use-goto-anything-modal'
|
||||
|
||||
type KeyPressEvent = {
|
||||
preventDefault: () => void
|
||||
target?: EventTarget
|
||||
}
|
||||
|
||||
const keyPressHandlers: Record<string, (event: KeyPressEvent) => void> = {}
|
||||
let mockIsEventTargetInputArea = false
|
||||
|
||||
vi.mock('ahooks', () => ({
|
||||
useKeyPress: (keys: string | string[], handler: (event: KeyPressEvent) => void) => {
|
||||
const keyList = Array.isArray(keys) ? keys : [keys]
|
||||
keyList.forEach((key) => {
|
||||
keyPressHandlers[key] = handler
|
||||
})
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/utils/common', () => ({
|
||||
getKeyboardKeyCodeBySystem: () => 'ctrl',
|
||||
isEventTargetInputArea: () => mockIsEventTargetInputArea,
|
||||
}))
|
||||
|
||||
describe('useGotoAnythingModal', () => {
|
||||
beforeEach(() => {
|
||||
Object.keys(keyPressHandlers).forEach(key => delete keyPressHandlers[key])
|
||||
mockIsEventTargetInputArea = false
|
||||
vi.useFakeTimers()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
describe('initialization', () => {
|
||||
it('should initialize with show=false', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
expect(result.current.show).toBe(false)
|
||||
})
|
||||
|
||||
it('should provide inputRef initialized to null', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
expect(result.current.inputRef).toBeDefined()
|
||||
expect(result.current.inputRef.current).toBe(null)
|
||||
})
|
||||
|
||||
it('should provide setShow function', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
expect(typeof result.current.setShow).toBe('function')
|
||||
})
|
||||
|
||||
it('should provide handleClose function', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
expect(typeof result.current.handleClose).toBe('function')
|
||||
})
|
||||
})
|
||||
|
||||
describe('keyboard shortcuts', () => {
|
||||
it('should toggle show state when Ctrl+K is triggered', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
|
||||
expect(result.current.show).toBe(false)
|
||||
|
||||
act(() => {
|
||||
keyPressHandlers['ctrl.k']?.({ preventDefault: vi.fn(), target: document.body })
|
||||
})
|
||||
|
||||
expect(result.current.show).toBe(true)
|
||||
})
|
||||
|
||||
it('should toggle back to closed when Ctrl+K is triggered twice', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
|
||||
act(() => {
|
||||
keyPressHandlers['ctrl.k']?.({ preventDefault: vi.fn(), target: document.body })
|
||||
})
|
||||
expect(result.current.show).toBe(true)
|
||||
|
||||
act(() => {
|
||||
keyPressHandlers['ctrl.k']?.({ preventDefault: vi.fn(), target: document.body })
|
||||
})
|
||||
expect(result.current.show).toBe(false)
|
||||
})
|
||||
|
||||
it('should NOT toggle when focus is in input area and modal is closed', () => {
|
||||
mockIsEventTargetInputArea = true
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
|
||||
expect(result.current.show).toBe(false)
|
||||
|
||||
act(() => {
|
||||
keyPressHandlers['ctrl.k']?.({ preventDefault: vi.fn(), target: document.body })
|
||||
})
|
||||
|
||||
// Should remain closed because focus is in input area
|
||||
expect(result.current.show).toBe(false)
|
||||
})
|
||||
|
||||
it('should close modal when escape is pressed and modal is open', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
|
||||
// Open modal first
|
||||
act(() => {
|
||||
result.current.setShow(true)
|
||||
})
|
||||
expect(result.current.show).toBe(true)
|
||||
|
||||
// Press escape
|
||||
act(() => {
|
||||
keyPressHandlers.esc?.({ preventDefault: vi.fn() })
|
||||
})
|
||||
|
||||
expect(result.current.show).toBe(false)
|
||||
})
|
||||
|
||||
it('should NOT do anything when escape is pressed and modal is already closed', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
|
||||
expect(result.current.show).toBe(false)
|
||||
|
||||
const preventDefaultMock = vi.fn()
|
||||
act(() => {
|
||||
keyPressHandlers.esc?.({ preventDefault: preventDefaultMock })
|
||||
})
|
||||
|
||||
// Should remain closed, and preventDefault should not be called
|
||||
expect(result.current.show).toBe(false)
|
||||
expect(preventDefaultMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call preventDefault when Ctrl+K is triggered', () => {
|
||||
renderHook(() => useGotoAnythingModal())
|
||||
|
||||
const preventDefaultMock = vi.fn()
|
||||
act(() => {
|
||||
keyPressHandlers['ctrl.k']?.({ preventDefault: preventDefaultMock, target: document.body })
|
||||
})
|
||||
|
||||
expect(preventDefaultMock).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleClose', () => {
|
||||
it('should close modal when handleClose is called', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
|
||||
// Open modal first
|
||||
act(() => {
|
||||
result.current.setShow(true)
|
||||
})
|
||||
expect(result.current.show).toBe(true)
|
||||
|
||||
// Close via handleClose
|
||||
act(() => {
|
||||
result.current.handleClose()
|
||||
})
|
||||
|
||||
expect(result.current.show).toBe(false)
|
||||
})
|
||||
|
||||
it('should be safe to call handleClose when modal is already closed', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
|
||||
expect(result.current.show).toBe(false)
|
||||
|
||||
act(() => {
|
||||
result.current.handleClose()
|
||||
})
|
||||
|
||||
expect(result.current.show).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('setShow', () => {
|
||||
it('should accept boolean value', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
|
||||
act(() => {
|
||||
result.current.setShow(true)
|
||||
})
|
||||
expect(result.current.show).toBe(true)
|
||||
|
||||
act(() => {
|
||||
result.current.setShow(false)
|
||||
})
|
||||
expect(result.current.show).toBe(false)
|
||||
})
|
||||
|
||||
it('should accept function value', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
|
||||
act(() => {
|
||||
result.current.setShow(prev => !prev)
|
||||
})
|
||||
expect(result.current.show).toBe(true)
|
||||
|
||||
act(() => {
|
||||
result.current.setShow(prev => !prev)
|
||||
})
|
||||
expect(result.current.show).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('focus management', () => {
|
||||
it('should call requestAnimationFrame when modal opens', () => {
|
||||
const rafSpy = vi.spyOn(window, 'requestAnimationFrame')
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
|
||||
act(() => {
|
||||
result.current.setShow(true)
|
||||
})
|
||||
|
||||
expect(rafSpy).toHaveBeenCalled()
|
||||
rafSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should not call requestAnimationFrame when modal closes', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
|
||||
// First open
|
||||
act(() => {
|
||||
result.current.setShow(true)
|
||||
})
|
||||
|
||||
const rafSpy = vi.spyOn(window, 'requestAnimationFrame')
|
||||
|
||||
// Then close
|
||||
act(() => {
|
||||
result.current.setShow(false)
|
||||
})
|
||||
|
||||
expect(rafSpy).not.toHaveBeenCalled()
|
||||
rafSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should focus input when modal opens and inputRef.current exists', () => {
|
||||
// Mock requestAnimationFrame to execute callback immediately
|
||||
const originalRAF = window.requestAnimationFrame
|
||||
window.requestAnimationFrame = (callback: FrameRequestCallback) => {
|
||||
callback(0)
|
||||
return 0
|
||||
}
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
|
||||
// Create a mock input element with focus method
|
||||
const mockFocus = vi.fn()
|
||||
const mockInput = { focus: mockFocus } as unknown as HTMLInputElement
|
||||
|
||||
// Manually set the inputRef
|
||||
Object.defineProperty(result.current.inputRef, 'current', {
|
||||
value: mockInput,
|
||||
writable: true,
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.setShow(true)
|
||||
})
|
||||
|
||||
expect(mockFocus).toHaveBeenCalled()
|
||||
|
||||
// Restore original requestAnimationFrame
|
||||
window.requestAnimationFrame = originalRAF
|
||||
})
|
||||
|
||||
it('should not throw when inputRef.current is null when modal opens', () => {
|
||||
// Mock requestAnimationFrame to execute callback immediately
|
||||
const originalRAF = window.requestAnimationFrame
|
||||
window.requestAnimationFrame = (callback: FrameRequestCallback) => {
|
||||
callback(0)
|
||||
return 0
|
||||
}
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingModal())
|
||||
|
||||
// inputRef.current is already null by default
|
||||
|
||||
// Should not throw
|
||||
act(() => {
|
||||
result.current.setShow(true)
|
||||
})
|
||||
|
||||
expect(result.current.show).toBe(true)
|
||||
|
||||
// Restore original requestAnimationFrame
|
||||
window.requestAnimationFrame = originalRAF
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,59 @@
|
||||
'use client'
|
||||
|
||||
import type { RefObject } from 'react'
|
||||
import { useKeyPress } from 'ahooks'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { getKeyboardKeyCodeBySystem, isEventTargetInputArea } from '@/app/components/workflow/utils/common'
|
||||
|
||||
export type UseGotoAnythingModalReturn = {
|
||||
show: boolean
|
||||
setShow: (show: boolean | ((prev: boolean) => boolean)) => void
|
||||
inputRef: RefObject<HTMLInputElement | null>
|
||||
handleClose: () => void
|
||||
}
|
||||
|
||||
export const useGotoAnythingModal = (): UseGotoAnythingModalReturn => {
|
||||
const [show, setShow] = useState<boolean>(false)
|
||||
const inputRef = useRef<HTMLInputElement>(null)
|
||||
|
||||
// Handle keyboard shortcuts
|
||||
const handleToggleModal = useCallback((e: KeyboardEvent) => {
|
||||
// Allow closing when modal is open, even if focus is in the search input
|
||||
if (!show && isEventTargetInputArea(e.target as HTMLElement))
|
||||
return
|
||||
e.preventDefault()
|
||||
setShow(prev => !prev)
|
||||
}, [show])
|
||||
|
||||
useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.k`, handleToggleModal, {
|
||||
exactMatch: true,
|
||||
useCapture: true,
|
||||
})
|
||||
|
||||
useKeyPress(['esc'], (e) => {
|
||||
if (show) {
|
||||
e.preventDefault()
|
||||
setShow(false)
|
||||
}
|
||||
})
|
||||
|
||||
const handleClose = useCallback(() => {
|
||||
setShow(false)
|
||||
}, [])
|
||||
|
||||
// Focus input when modal opens
|
||||
useEffect(() => {
|
||||
if (show) {
|
||||
requestAnimationFrame(() => {
|
||||
inputRef.current?.focus()
|
||||
})
|
||||
}
|
||||
}, [show])
|
||||
|
||||
return {
|
||||
show,
|
||||
setShow,
|
||||
inputRef,
|
||||
handleClose,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,391 @@
|
||||
import type * as React from 'react'
|
||||
import type { Plugin } from '../../plugins/types'
|
||||
import type { CommonNodeType } from '../../workflow/types'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import type { App } from '@/types/app'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { useGotoAnythingNavigation } from './use-goto-anything-navigation'
|
||||
|
||||
const mockRouterPush = vi.fn()
|
||||
const mockSelectWorkflowNode = vi.fn()
|
||||
|
||||
type MockCommandResult = {
|
||||
mode: string
|
||||
execute?: () => void
|
||||
} | null
|
||||
|
||||
let mockFindCommandResult: MockCommandResult = null
|
||||
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
push: mockRouterPush,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/utils/node-navigation', () => ({
|
||||
selectWorkflowNode: (...args: unknown[]) => mockSelectWorkflowNode(...args),
|
||||
}))
|
||||
|
||||
vi.mock('../actions/commands/registry', () => ({
|
||||
slashCommandRegistry: {
|
||||
findCommand: () => mockFindCommandResult,
|
||||
},
|
||||
}))
|
||||
|
||||
const createMockActionItem = (
|
||||
key: '@app' | '@knowledge' | '@plugin' | '@node' | '/',
|
||||
extra: Record<string, unknown> = {},
|
||||
) => ({
|
||||
key,
|
||||
shortcut: key,
|
||||
title: `${key} title`,
|
||||
description: `${key} description`,
|
||||
search: vi.fn().mockResolvedValue([]),
|
||||
...extra,
|
||||
})
|
||||
|
||||
const createMockOptions = (overrides = {}) => ({
|
||||
Actions: {
|
||||
slash: createMockActionItem('/', { action: vi.fn() }),
|
||||
app: createMockActionItem('@app'),
|
||||
},
|
||||
setSearchQuery: vi.fn(),
|
||||
clearSelection: vi.fn(),
|
||||
inputRef: { current: { focus: vi.fn() } } as unknown as React.RefObject<HTMLInputElement>,
|
||||
onClose: vi.fn(),
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('useGotoAnythingNavigation', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockFindCommandResult = null
|
||||
vi.useFakeTimers()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
describe('initialization', () => {
|
||||
it('should return handleCommandSelect function', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(createMockOptions()))
|
||||
expect(typeof result.current.handleCommandSelect).toBe('function')
|
||||
})
|
||||
|
||||
it('should return handleNavigate function', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(createMockOptions()))
|
||||
expect(typeof result.current.handleNavigate).toBe('function')
|
||||
})
|
||||
|
||||
it('should initialize activePlugin as undefined', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(createMockOptions()))
|
||||
expect(result.current.activePlugin).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should return setActivePlugin function', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(createMockOptions()))
|
||||
expect(typeof result.current.setActivePlugin).toBe('function')
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleCommandSelect', () => {
|
||||
it('should execute direct mode slash command immediately', () => {
|
||||
const execute = vi.fn()
|
||||
mockFindCommandResult = { mode: 'direct', execute }
|
||||
const options = createMockOptions()
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
act(() => {
|
||||
result.current.handleCommandSelect('/theme')
|
||||
})
|
||||
|
||||
expect(execute).toHaveBeenCalled()
|
||||
expect(options.onClose).toHaveBeenCalled()
|
||||
expect(options.setSearchQuery).toHaveBeenCalledWith('')
|
||||
})
|
||||
|
||||
it('should NOT execute when handler has no execute function', () => {
|
||||
mockFindCommandResult = { mode: 'direct', execute: undefined }
|
||||
const options = createMockOptions()
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
act(() => {
|
||||
result.current.handleCommandSelect('/theme')
|
||||
})
|
||||
|
||||
expect(options.onClose).not.toHaveBeenCalled()
|
||||
// Should proceed with submenu mode
|
||||
expect(options.setSearchQuery).toHaveBeenCalledWith('/theme ')
|
||||
})
|
||||
|
||||
it('should proceed with submenu mode for non-direct commands', () => {
|
||||
mockFindCommandResult = { mode: 'submenu' }
|
||||
const options = createMockOptions()
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
act(() => {
|
||||
result.current.handleCommandSelect('/language')
|
||||
})
|
||||
|
||||
expect(options.setSearchQuery).toHaveBeenCalledWith('/language ')
|
||||
expect(options.clearSelection).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle @ commands (scopes)', () => {
|
||||
const options = createMockOptions()
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
act(() => {
|
||||
result.current.handleCommandSelect('@app')
|
||||
})
|
||||
|
||||
expect(options.setSearchQuery).toHaveBeenCalledWith('@app ')
|
||||
expect(options.clearSelection).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should focus input after setting search query', () => {
|
||||
const focusMock = vi.fn()
|
||||
const options = createMockOptions({
|
||||
inputRef: { current: { focus: focusMock } },
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
act(() => {
|
||||
result.current.handleCommandSelect('@app')
|
||||
})
|
||||
|
||||
act(() => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(focusMock).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle null handler from registry', () => {
|
||||
mockFindCommandResult = null
|
||||
const options = createMockOptions()
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
act(() => {
|
||||
result.current.handleCommandSelect('/unknown')
|
||||
})
|
||||
|
||||
// Should proceed with submenu mode
|
||||
expect(options.setSearchQuery).toHaveBeenCalledWith('/unknown ')
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleNavigate', () => {
|
||||
it('should navigate to path for default result types', () => {
|
||||
const options = createMockOptions()
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
act(() => {
|
||||
result.current.handleNavigate({
|
||||
id: '1',
|
||||
type: 'app' as const,
|
||||
title: 'My App',
|
||||
path: '/apps/1',
|
||||
data: { id: '1', name: 'My App' } as unknown as App,
|
||||
})
|
||||
})
|
||||
|
||||
expect(options.onClose).toHaveBeenCalled()
|
||||
expect(options.setSearchQuery).toHaveBeenCalledWith('')
|
||||
expect(mockRouterPush).toHaveBeenCalledWith('/apps/1')
|
||||
})
|
||||
|
||||
it('should NOT call router.push when path is empty', () => {
|
||||
const options = createMockOptions()
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
act(() => {
|
||||
result.current.handleNavigate({
|
||||
id: '1',
|
||||
type: 'app' as const,
|
||||
title: 'My App',
|
||||
path: '',
|
||||
data: { id: '1', name: 'My App' } as unknown as App,
|
||||
})
|
||||
})
|
||||
|
||||
expect(mockRouterPush).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should execute slash command action for command type', () => {
|
||||
const actionMock = vi.fn()
|
||||
const options = createMockOptions({
|
||||
Actions: {
|
||||
slash: { key: '/', shortcut: '/', action: actionMock },
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
const commandResult = {
|
||||
id: 'cmd-1',
|
||||
type: 'command' as const,
|
||||
title: 'Theme Dark',
|
||||
data: { command: 'theme.set', args: { theme: 'dark' } },
|
||||
}
|
||||
|
||||
act(() => {
|
||||
result.current.handleNavigate(commandResult)
|
||||
})
|
||||
|
||||
expect(actionMock).toHaveBeenCalledWith(commandResult)
|
||||
})
|
||||
|
||||
it('should set activePlugin for plugin type', () => {
|
||||
const options = createMockOptions()
|
||||
const pluginData = { name: 'My Plugin', latest_package_identifier: 'pkg' } as unknown as Plugin
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
act(() => {
|
||||
result.current.handleNavigate({
|
||||
id: 'plugin-1',
|
||||
type: 'plugin' as const,
|
||||
title: 'My Plugin',
|
||||
data: pluginData,
|
||||
})
|
||||
})
|
||||
|
||||
expect(result.current.activePlugin).toEqual(pluginData)
|
||||
})
|
||||
|
||||
it('should select workflow node for workflow-node type', () => {
|
||||
const options = createMockOptions()
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
act(() => {
|
||||
result.current.handleNavigate({
|
||||
id: 'node-1',
|
||||
type: 'workflow-node' as const,
|
||||
title: 'Start Node',
|
||||
metadata: { nodeId: 'node-123', nodeData: {} as CommonNodeType },
|
||||
data: { id: 'node-1' } as unknown as CommonNodeType,
|
||||
})
|
||||
})
|
||||
|
||||
expect(mockSelectWorkflowNode).toHaveBeenCalledWith('node-123', true)
|
||||
})
|
||||
|
||||
it('should NOT select workflow node when metadata.nodeId is missing', () => {
|
||||
const options = createMockOptions()
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
act(() => {
|
||||
result.current.handleNavigate({
|
||||
id: 'node-1',
|
||||
type: 'workflow-node' as const,
|
||||
title: 'Start Node',
|
||||
metadata: undefined,
|
||||
data: { id: 'node-1' } as unknown as CommonNodeType,
|
||||
})
|
||||
})
|
||||
|
||||
expect(mockSelectWorkflowNode).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle knowledge type (default case with path)', () => {
|
||||
const options = createMockOptions()
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
act(() => {
|
||||
result.current.handleNavigate({
|
||||
id: 'kb-1',
|
||||
type: 'knowledge' as const,
|
||||
title: 'My Knowledge Base',
|
||||
path: '/datasets/kb-1',
|
||||
data: { id: 'kb-1', name: 'My Knowledge Base' } as unknown as DataSet,
|
||||
})
|
||||
})
|
||||
|
||||
expect(mockRouterPush).toHaveBeenCalledWith('/datasets/kb-1')
|
||||
})
|
||||
})
|
||||
|
||||
describe('setActivePlugin', () => {
|
||||
it('should update activePlugin state', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(createMockOptions()))
|
||||
|
||||
const plugin = { name: 'Test Plugin', latest_package_identifier: 'test-pkg' } as unknown as Plugin
|
||||
act(() => {
|
||||
result.current.setActivePlugin(plugin)
|
||||
})
|
||||
|
||||
expect(result.current.activePlugin).toEqual(plugin)
|
||||
})
|
||||
|
||||
it('should clear activePlugin when set to undefined', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(createMockOptions()))
|
||||
|
||||
// First set a plugin
|
||||
act(() => {
|
||||
result.current.setActivePlugin({ name: 'Plugin', latest_package_identifier: 'pkg' } as unknown as Plugin)
|
||||
})
|
||||
expect(result.current.activePlugin).toBeDefined()
|
||||
|
||||
// Then clear it
|
||||
act(() => {
|
||||
result.current.setActivePlugin(undefined)
|
||||
})
|
||||
|
||||
expect(result.current.activePlugin).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle undefined inputRef.current', () => {
|
||||
const options = createMockOptions({
|
||||
inputRef: { current: null },
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
// Should not throw
|
||||
act(() => {
|
||||
result.current.handleCommandSelect('@app')
|
||||
})
|
||||
|
||||
act(() => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
// No error should occur
|
||||
})
|
||||
|
||||
it('should handle missing slash action', () => {
|
||||
const options = createMockOptions({
|
||||
Actions: {},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingNavigation(options))
|
||||
|
||||
// Should not throw
|
||||
act(() => {
|
||||
result.current.handleNavigate({
|
||||
id: 'cmd-1',
|
||||
type: 'command' as const,
|
||||
title: 'Command',
|
||||
data: { command: 'test-command' },
|
||||
})
|
||||
})
|
||||
|
||||
// No error should occur
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,96 @@
|
||||
'use client'
|
||||
|
||||
import type { RefObject } from 'react'
|
||||
import type { Plugin } from '../../plugins/types'
|
||||
import type { ActionItem, SearchResult } from '../actions/types'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { selectWorkflowNode } from '@/app/components/workflow/utils/node-navigation'
|
||||
import { slashCommandRegistry } from '../actions/commands/registry'
|
||||
|
||||
export type UseGotoAnythingNavigationReturn = {
|
||||
handleCommandSelect: (commandKey: string) => void
|
||||
handleNavigate: (result: SearchResult) => void
|
||||
activePlugin: Plugin | undefined
|
||||
setActivePlugin: (plugin: Plugin | undefined) => void
|
||||
}
|
||||
|
||||
export type UseGotoAnythingNavigationOptions = {
|
||||
Actions: Record<string, ActionItem>
|
||||
setSearchQuery: (query: string) => void
|
||||
clearSelection: () => void
|
||||
inputRef: RefObject<HTMLInputElement | null>
|
||||
onClose: () => void
|
||||
}
|
||||
|
||||
export const useGotoAnythingNavigation = (
|
||||
options: UseGotoAnythingNavigationOptions,
|
||||
): UseGotoAnythingNavigationReturn => {
|
||||
const {
|
||||
Actions,
|
||||
setSearchQuery,
|
||||
clearSelection,
|
||||
inputRef,
|
||||
onClose,
|
||||
} = options
|
||||
|
||||
const router = useRouter()
|
||||
const [activePlugin, setActivePlugin] = useState<Plugin>()
|
||||
|
||||
const handleCommandSelect = useCallback((commandKey: string) => {
|
||||
// Check if it's a slash command
|
||||
if (commandKey.startsWith('/')) {
|
||||
const commandName = commandKey.substring(1)
|
||||
const handler = slashCommandRegistry.findCommand(commandName)
|
||||
|
||||
// If it's a direct mode command, execute immediately
|
||||
if (handler?.mode === 'direct' && handler.execute) {
|
||||
handler.execute()
|
||||
onClose()
|
||||
setSearchQuery('')
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, proceed with the normal flow (submenu mode)
|
||||
setSearchQuery(`${commandKey} `)
|
||||
clearSelection()
|
||||
setTimeout(() => {
|
||||
inputRef.current?.focus()
|
||||
}, 0)
|
||||
}, [onClose, setSearchQuery, clearSelection, inputRef])
|
||||
|
||||
// Handle navigation to selected result
|
||||
const handleNavigate = useCallback((result: SearchResult) => {
|
||||
onClose()
|
||||
setSearchQuery('')
|
||||
|
||||
switch (result.type) {
|
||||
case 'command': {
|
||||
// Execute slash commands
|
||||
const action = Actions.slash
|
||||
action?.action?.(result)
|
||||
break
|
||||
}
|
||||
case 'plugin':
|
||||
setActivePlugin(result.data)
|
||||
break
|
||||
case 'workflow-node':
|
||||
// Handle workflow node selection and navigation
|
||||
if (result.metadata?.nodeId)
|
||||
selectWorkflowNode(result.metadata.nodeId, true)
|
||||
|
||||
break
|
||||
default:
|
||||
if (result.path)
|
||||
router.push(result.path)
|
||||
}
|
||||
}, [router, Actions, onClose, setSearchQuery])
|
||||
|
||||
return {
|
||||
handleCommandSelect,
|
||||
handleNavigate,
|
||||
activePlugin,
|
||||
setActivePlugin,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,354 @@
|
||||
import type { SearchResult } from '../actions/types'
|
||||
import { renderHook } from '@testing-library/react'
|
||||
import { useGotoAnythingResults } from './use-goto-anything-results'
|
||||
|
||||
type MockQueryResult = {
|
||||
data: Array<{ id: string, type: string, title: string }> | undefined
|
||||
isLoading: boolean
|
||||
isError: boolean
|
||||
error: Error | null
|
||||
}
|
||||
|
||||
type UseQueryOptions = {
|
||||
queryFn: () => Promise<SearchResult[]>
|
||||
}
|
||||
|
||||
let mockQueryResult: MockQueryResult = { data: [], isLoading: false, isError: false, error: null }
|
||||
let capturedQueryFn: (() => Promise<SearchResult[]>) | null = null
|
||||
|
||||
vi.mock('@tanstack/react-query', () => ({
|
||||
useQuery: (options: UseQueryOptions) => {
|
||||
capturedQueryFn = options.queryFn
|
||||
return mockQueryResult
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useGetLanguage: () => 'en_US',
|
||||
}))
|
||||
|
||||
const mockMatchAction = vi.fn()
|
||||
const mockSearchAnything = vi.fn()
|
||||
|
||||
vi.mock('../actions', () => ({
|
||||
matchAction: (...args: unknown[]) => mockMatchAction(...args),
|
||||
searchAnything: (...args: unknown[]) => mockSearchAnything(...args),
|
||||
}))
|
||||
|
||||
const createMockActionItem = (key: '@app' | '@knowledge' | '@plugin' | '@node' | '/') => ({
|
||||
key,
|
||||
shortcut: key,
|
||||
title: `${key} title`,
|
||||
description: `${key} description`,
|
||||
search: vi.fn().mockResolvedValue([]),
|
||||
})
|
||||
|
||||
const createMockOptions = (overrides = {}) => ({
|
||||
searchQueryDebouncedValue: '',
|
||||
searchMode: 'general',
|
||||
isCommandsMode: false,
|
||||
Actions: { app: createMockActionItem('@app') },
|
||||
isWorkflowPage: false,
|
||||
isRagPipelinePage: false,
|
||||
cmdVal: '_',
|
||||
setCmdVal: vi.fn(),
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('useGotoAnythingResults', () => {
|
||||
beforeEach(() => {
|
||||
mockQueryResult = { data: [], isLoading: false, isError: false, error: null }
|
||||
capturedQueryFn = null
|
||||
mockMatchAction.mockReset()
|
||||
mockSearchAnything.mockReset()
|
||||
})
|
||||
|
||||
describe('initialization', () => {
|
||||
it('should return empty arrays when no results', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
|
||||
|
||||
expect(result.current.searchResults).toEqual([])
|
||||
expect(result.current.dedupedResults).toEqual([])
|
||||
expect(result.current.groupedResults).toEqual({})
|
||||
})
|
||||
|
||||
it('should return loading state', () => {
|
||||
mockQueryResult = { data: [], isLoading: true, isError: false, error: null }
|
||||
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
|
||||
|
||||
expect(result.current.isLoading).toBe(true)
|
||||
})
|
||||
|
||||
it('should return error state', () => {
|
||||
const error = new Error('Test error')
|
||||
mockQueryResult = { data: [], isLoading: false, isError: true, error }
|
||||
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
|
||||
|
||||
expect(result.current.isError).toBe(true)
|
||||
expect(result.current.error).toBe(error)
|
||||
})
|
||||
})
|
||||
|
||||
describe('dedupedResults', () => {
|
||||
it('should remove duplicate results', () => {
|
||||
mockQueryResult = {
|
||||
data: [
|
||||
{ id: '1', type: 'app', title: 'App 1' },
|
||||
{ id: '1', type: 'app', title: 'App 1 Duplicate' },
|
||||
{ id: '2', type: 'app', title: 'App 2' },
|
||||
],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
|
||||
|
||||
expect(result.current.dedupedResults).toHaveLength(2)
|
||||
expect(result.current.dedupedResults[0].id).toBe('1')
|
||||
expect(result.current.dedupedResults[1].id).toBe('2')
|
||||
})
|
||||
|
||||
it('should keep first occurrence when duplicates exist', () => {
|
||||
mockQueryResult = {
|
||||
data: [
|
||||
{ id: '1', type: 'app', title: 'First' },
|
||||
{ id: '1', type: 'app', title: 'Second' },
|
||||
],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
|
||||
|
||||
expect(result.current.dedupedResults).toHaveLength(1)
|
||||
expect(result.current.dedupedResults[0].title).toBe('First')
|
||||
})
|
||||
|
||||
it('should handle different types with same id', () => {
|
||||
mockQueryResult = {
|
||||
data: [
|
||||
{ id: '1', type: 'app', title: 'App' },
|
||||
{ id: '1', type: 'plugin', title: 'Plugin' },
|
||||
],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
|
||||
|
||||
// Different types, same id = different keys, so both should remain
|
||||
expect(result.current.dedupedResults).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('groupedResults', () => {
|
||||
it('should group results by type', () => {
|
||||
mockQueryResult = {
|
||||
data: [
|
||||
{ id: '1', type: 'app', title: 'App 1' },
|
||||
{ id: '2', type: 'app', title: 'App 2' },
|
||||
{ id: '3', type: 'plugin', title: 'Plugin 1' },
|
||||
],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
|
||||
|
||||
expect(result.current.groupedResults.app).toHaveLength(2)
|
||||
expect(result.current.groupedResults.plugin).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('should handle single type', () => {
|
||||
mockQueryResult = {
|
||||
data: [
|
||||
{ id: '1', type: 'knowledge', title: 'KB 1' },
|
||||
{ id: '2', type: 'knowledge', title: 'KB 2' },
|
||||
],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
|
||||
|
||||
expect(Object.keys(result.current.groupedResults)).toEqual(['knowledge'])
|
||||
expect(result.current.groupedResults.knowledge).toHaveLength(2)
|
||||
})
|
||||
|
||||
it('should return empty object when no results', () => {
|
||||
mockQueryResult = { data: [], isLoading: false, isError: false, error: null }
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
|
||||
|
||||
expect(result.current.groupedResults).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('auto-select first result', () => {
|
||||
it('should call setCmdVal when results change and current value does not exist', () => {
|
||||
const setCmdVal = vi.fn()
|
||||
mockQueryResult = {
|
||||
data: [{ id: '1', type: 'app', title: 'App 1' }],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
renderHook(() => useGotoAnythingResults(createMockOptions({
|
||||
cmdVal: 'non-existent',
|
||||
setCmdVal,
|
||||
})))
|
||||
|
||||
expect(setCmdVal).toHaveBeenCalledWith('app-1')
|
||||
})
|
||||
|
||||
it('should NOT call setCmdVal when in commands mode', () => {
|
||||
const setCmdVal = vi.fn()
|
||||
mockQueryResult = {
|
||||
data: [{ id: '1', type: 'app', title: 'App 1' }],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
renderHook(() => useGotoAnythingResults(createMockOptions({
|
||||
isCommandsMode: true,
|
||||
setCmdVal,
|
||||
})))
|
||||
|
||||
expect(setCmdVal).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should NOT call setCmdVal when results are empty', () => {
|
||||
const setCmdVal = vi.fn()
|
||||
mockQueryResult = { data: [], isLoading: false, isError: false, error: null }
|
||||
|
||||
renderHook(() => useGotoAnythingResults(createMockOptions({
|
||||
setCmdVal,
|
||||
})))
|
||||
|
||||
expect(setCmdVal).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should NOT call setCmdVal when current value exists in results', () => {
|
||||
const setCmdVal = vi.fn()
|
||||
mockQueryResult = {
|
||||
data: [
|
||||
{ id: '1', type: 'app', title: 'App 1' },
|
||||
{ id: '2', type: 'app', title: 'App 2' },
|
||||
],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
renderHook(() => useGotoAnythingResults(createMockOptions({
|
||||
cmdVal: 'app-2',
|
||||
setCmdVal,
|
||||
})))
|
||||
|
||||
expect(setCmdVal).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('error handling', () => {
|
||||
it('should return error as Error | null', () => {
|
||||
const error = new Error('Search failed')
|
||||
mockQueryResult = { data: [], isLoading: false, isError: true, error }
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
|
||||
|
||||
expect(result.current.error).toBeInstanceOf(Error)
|
||||
expect(result.current.error?.message).toBe('Search failed')
|
||||
})
|
||||
|
||||
it('should return null error when no error', () => {
|
||||
mockQueryResult = { data: [], isLoading: false, isError: false, error: null }
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
|
||||
|
||||
expect(result.current.error).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('searchResults', () => {
|
||||
it('should return raw search results', () => {
|
||||
const mockData = [
|
||||
{ id: '1', type: 'app', title: 'App 1' },
|
||||
{ id: '2', type: 'plugin', title: 'Plugin 1' },
|
||||
]
|
||||
mockQueryResult = { data: mockData, isLoading: false, isError: false, error: null }
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
|
||||
|
||||
expect(result.current.searchResults).toEqual(mockData)
|
||||
})
|
||||
|
||||
it('should default to empty array when data is undefined', () => {
|
||||
mockQueryResult = { data: undefined, isLoading: false, isError: false, error: null }
|
||||
|
||||
const { result } = renderHook(() => useGotoAnythingResults(createMockOptions()))
|
||||
|
||||
expect(result.current.searchResults).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('queryFn execution', () => {
|
||||
it('should call matchAction with lowercased query', async () => {
|
||||
const mockActions = { app: createMockActionItem('@app') }
|
||||
mockMatchAction.mockReturnValue({ key: '@app' })
|
||||
mockSearchAnything.mockResolvedValue([])
|
||||
|
||||
renderHook(() => useGotoAnythingResults(createMockOptions({
|
||||
searchQueryDebouncedValue: 'TEST QUERY',
|
||||
Actions: mockActions,
|
||||
})))
|
||||
|
||||
expect(capturedQueryFn).toBeDefined()
|
||||
await capturedQueryFn!()
|
||||
|
||||
expect(mockMatchAction).toHaveBeenCalledWith('test query', mockActions)
|
||||
})
|
||||
|
||||
it('should call searchAnything with correct parameters', async () => {
|
||||
const mockActions = { app: createMockActionItem('@app') }
|
||||
const mockAction = { key: '@app' }
|
||||
mockMatchAction.mockReturnValue(mockAction)
|
||||
mockSearchAnything.mockResolvedValue([{ id: '1', type: 'app', title: 'Result' }])
|
||||
|
||||
renderHook(() => useGotoAnythingResults(createMockOptions({
|
||||
searchQueryDebouncedValue: 'My Query',
|
||||
Actions: mockActions,
|
||||
})))
|
||||
|
||||
expect(capturedQueryFn).toBeDefined()
|
||||
const result = await capturedQueryFn!()
|
||||
|
||||
expect(mockSearchAnything).toHaveBeenCalledWith('en_US', 'my query', mockAction, mockActions)
|
||||
expect(result).toEqual([{ id: '1', type: 'app', title: 'Result' }])
|
||||
})
|
||||
|
||||
it('should handle searchAnything returning results', async () => {
|
||||
const expectedResults = [
|
||||
{ id: '1', type: 'app', title: 'App 1' },
|
||||
{ id: '2', type: 'plugin', title: 'Plugin 1' },
|
||||
]
|
||||
mockMatchAction.mockReturnValue(null)
|
||||
mockSearchAnything.mockResolvedValue(expectedResults)
|
||||
|
||||
renderHook(() => useGotoAnythingResults(createMockOptions({
|
||||
searchQueryDebouncedValue: 'search term',
|
||||
})))
|
||||
|
||||
expect(capturedQueryFn).toBeDefined()
|
||||
const result = await capturedQueryFn!()
|
||||
|
||||
expect(result).toEqual(expectedResults)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,115 @@
|
||||
'use client'
|
||||
|
||||
import type { ActionItem, SearchResult } from '../actions/types'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { useEffect, useMemo } from 'react'
|
||||
import { useGetLanguage } from '@/context/i18n'
|
||||
import { matchAction, searchAnything } from '../actions'
|
||||
|
||||
export type UseGotoAnythingResultsReturn = {
|
||||
searchResults: SearchResult[]
|
||||
dedupedResults: SearchResult[]
|
||||
groupedResults: Record<string, SearchResult[]>
|
||||
isLoading: boolean
|
||||
isError: boolean
|
||||
error: Error | null
|
||||
}
|
||||
|
||||
export type UseGotoAnythingResultsOptions = {
|
||||
searchQueryDebouncedValue: string
|
||||
searchMode: string
|
||||
isCommandsMode: boolean
|
||||
Actions: Record<string, ActionItem>
|
||||
isWorkflowPage: boolean
|
||||
isRagPipelinePage: boolean
|
||||
cmdVal: string
|
||||
setCmdVal: (val: string) => void
|
||||
}
|
||||
|
||||
export const useGotoAnythingResults = (
|
||||
options: UseGotoAnythingResultsOptions,
|
||||
): UseGotoAnythingResultsReturn => {
|
||||
const {
|
||||
searchQueryDebouncedValue,
|
||||
searchMode,
|
||||
isCommandsMode,
|
||||
Actions,
|
||||
isWorkflowPage,
|
||||
isRagPipelinePage,
|
||||
cmdVal,
|
||||
setCmdVal,
|
||||
} = options
|
||||
|
||||
const defaultLocale = useGetLanguage()
|
||||
|
||||
// Use action keys as stable cache key instead of the full Actions object
|
||||
// (Actions contains functions which are not serializable)
|
||||
const actionKeys = useMemo(() => Object.keys(Actions).sort(), [Actions])
|
||||
|
||||
const { data: searchResults = [], isLoading, isError, error } = useQuery(
|
||||
{
|
||||
// eslint-disable-next-line @tanstack/query/exhaustive-deps -- Actions intentionally excluded: contains non-serializable functions; actionKeys provides stable representation
|
||||
queryKey: [
|
||||
'goto-anything',
|
||||
'search-result',
|
||||
searchQueryDebouncedValue,
|
||||
searchMode,
|
||||
isWorkflowPage,
|
||||
isRagPipelinePage,
|
||||
defaultLocale,
|
||||
actionKeys,
|
||||
],
|
||||
queryFn: async () => {
|
||||
const query = searchQueryDebouncedValue.toLowerCase()
|
||||
const action = matchAction(query, Actions)
|
||||
return await searchAnything(defaultLocale, query, action, Actions)
|
||||
},
|
||||
enabled: !!searchQueryDebouncedValue && !isCommandsMode,
|
||||
staleTime: 30000,
|
||||
gcTime: 300000,
|
||||
},
|
||||
)
|
||||
|
||||
const dedupedResults = useMemo(() => {
|
||||
const seen = new Set<string>()
|
||||
return searchResults.filter((result) => {
|
||||
const key = `${result.type}-${result.id}`
|
||||
if (seen.has(key))
|
||||
return false
|
||||
seen.add(key)
|
||||
return true
|
||||
})
|
||||
}, [searchResults])
|
||||
|
||||
// Group results by type
|
||||
const groupedResults = useMemo(() => dedupedResults.reduce((acc, result) => {
|
||||
if (!acc[result.type])
|
||||
acc[result.type] = []
|
||||
|
||||
acc[result.type].push(result)
|
||||
return acc
|
||||
}, {} as Record<string, SearchResult[]>), [dedupedResults])
|
||||
|
||||
// Auto-select first result when results change
|
||||
useEffect(() => {
|
||||
if (isCommandsMode)
|
||||
return
|
||||
|
||||
if (!dedupedResults.length)
|
||||
return
|
||||
|
||||
const currentValueExists = dedupedResults.some(result => `${result.type}-${result.id}` === cmdVal)
|
||||
|
||||
if (!currentValueExists)
|
||||
setCmdVal(`${dedupedResults[0].type}-${dedupedResults[0].id}`)
|
||||
}, [isCommandsMode, dedupedResults, cmdVal, setCmdVal])
|
||||
|
||||
return {
|
||||
searchResults,
|
||||
dedupedResults,
|
||||
groupedResults,
|
||||
isLoading,
|
||||
isError,
|
||||
error: error as Error | null,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,301 @@
|
||||
import type { ActionItem } from '../actions/types'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { useGotoAnythingSearch } from './use-goto-anything-search'
|
||||
|
||||
let mockContextValue = { isWorkflowPage: false, isRagPipelinePage: false }
|
||||
let mockMatchActionResult: Partial<ActionItem> | undefined
|
||||
|
||||
vi.mock('ahooks', () => ({
|
||||
useDebounce: <T>(value: T) => value,
|
||||
}))
|
||||
|
||||
vi.mock('../context', () => ({
|
||||
useGotoAnythingContext: () => mockContextValue,
|
||||
}))
|
||||
|
||||
vi.mock('../actions', () => ({
|
||||
createActions: (isWorkflowPage: boolean, isRagPipelinePage: boolean) => {
|
||||
const base = {
|
||||
slash: { key: '/', shortcut: '/' },
|
||||
app: { key: '@app', shortcut: '@app' },
|
||||
knowledge: { key: '@knowledge', shortcut: '@kb' },
|
||||
}
|
||||
if (isWorkflowPage) {
|
||||
return { ...base, node: { key: '@node', shortcut: '@node' } }
|
||||
}
|
||||
if (isRagPipelinePage) {
|
||||
return { ...base, ragNode: { key: '@node', shortcut: '@node' } }
|
||||
}
|
||||
return base
|
||||
},
|
||||
matchAction: () => mockMatchActionResult,
|
||||
}))
|
||||
|
||||
describe('useGotoAnythingSearch', () => {
|
||||
beforeEach(() => {
|
||||
mockContextValue = { isWorkflowPage: false, isRagPipelinePage: false }
|
||||
mockMatchActionResult = undefined
|
||||
})
|
||||
|
||||
describe('initialization', () => {
|
||||
it('should initialize with empty search query', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
expect(result.current.searchQuery).toBe('')
|
||||
})
|
||||
|
||||
it('should initialize cmdVal with "_"', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
expect(result.current.cmdVal).toBe('_')
|
||||
})
|
||||
|
||||
it('should initialize searchMode as "general"', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
expect(result.current.searchMode).toBe('general')
|
||||
})
|
||||
|
||||
it('should initialize isCommandsMode as false', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
expect(result.current.isCommandsMode).toBe(false)
|
||||
})
|
||||
|
||||
it('should provide setSearchQuery function', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
expect(typeof result.current.setSearchQuery).toBe('function')
|
||||
})
|
||||
|
||||
it('should provide setCmdVal function', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
expect(typeof result.current.setCmdVal).toBe('function')
|
||||
})
|
||||
|
||||
it('should provide clearSelection function', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
expect(typeof result.current.clearSelection).toBe('function')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Actions', () => {
|
||||
it('should provide Actions based on context', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
expect(result.current.Actions).toBeDefined()
|
||||
expect(typeof result.current.Actions).toBe('object')
|
||||
})
|
||||
|
||||
it('should include node action when on workflow page', () => {
|
||||
mockContextValue = { isWorkflowPage: true, isRagPipelinePage: false }
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
expect(result.current.Actions.node).toBeDefined()
|
||||
})
|
||||
|
||||
it('should include ragNode action when on RAG pipeline page', () => {
|
||||
mockContextValue = { isWorkflowPage: false, isRagPipelinePage: true }
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
expect(result.current.Actions.ragNode).toBeDefined()
|
||||
})
|
||||
|
||||
it('should not include node actions when on regular page', () => {
|
||||
mockContextValue = { isWorkflowPage: false, isRagPipelinePage: false }
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
expect(result.current.Actions.node).toBeUndefined()
|
||||
expect(result.current.Actions.ragNode).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('isCommandsMode', () => {
|
||||
it('should return true when query is exactly "@"', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('@')
|
||||
})
|
||||
|
||||
expect(result.current.isCommandsMode).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true when query is exactly "/"', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('/')
|
||||
})
|
||||
|
||||
expect(result.current.isCommandsMode).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true when query starts with "@" and no action matches', () => {
|
||||
mockMatchActionResult = undefined
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('@unknown')
|
||||
})
|
||||
|
||||
expect(result.current.isCommandsMode).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true when query starts with "/" and no action matches', () => {
|
||||
mockMatchActionResult = undefined
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('/unknown')
|
||||
})
|
||||
|
||||
expect(result.current.isCommandsMode).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false when query starts with "@" and action matches', () => {
|
||||
mockMatchActionResult = { key: '@app', shortcut: '@app' }
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('@app test')
|
||||
})
|
||||
|
||||
expect(result.current.isCommandsMode).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for regular search query', () => {
|
||||
mockMatchActionResult = undefined
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('hello world')
|
||||
})
|
||||
|
||||
expect(result.current.isCommandsMode).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('searchMode', () => {
|
||||
it('should return "general" when query is empty', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
expect(result.current.searchMode).toBe('general')
|
||||
})
|
||||
|
||||
it('should return "scopes" when in commands mode and query starts with "@"', () => {
|
||||
mockMatchActionResult = undefined
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('@')
|
||||
})
|
||||
|
||||
expect(result.current.searchMode).toBe('scopes')
|
||||
})
|
||||
|
||||
it('should return "commands" when in commands mode and query starts with "/"', () => {
|
||||
mockMatchActionResult = undefined
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('/')
|
||||
})
|
||||
|
||||
expect(result.current.searchMode).toBe('commands')
|
||||
})
|
||||
|
||||
it('should return "general" when no action matches', () => {
|
||||
mockMatchActionResult = undefined
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('hello')
|
||||
})
|
||||
|
||||
expect(result.current.searchMode).toBe('general')
|
||||
})
|
||||
|
||||
it('should return action key when action matches', () => {
|
||||
mockMatchActionResult = { key: '@app', shortcut: '@app' }
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('@app test')
|
||||
})
|
||||
|
||||
expect(result.current.searchMode).toBe('@app')
|
||||
})
|
||||
|
||||
it('should return "@command" when action key is "/"', () => {
|
||||
mockMatchActionResult = { key: '/', shortcut: '/' }
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('/theme dark')
|
||||
})
|
||||
|
||||
expect(result.current.searchMode).toBe('@command')
|
||||
})
|
||||
})
|
||||
|
||||
describe('clearSelection', () => {
|
||||
it('should reset cmdVal to "_"', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
// First change cmdVal
|
||||
act(() => {
|
||||
result.current.setCmdVal('app-1')
|
||||
})
|
||||
expect(result.current.cmdVal).toBe('app-1')
|
||||
|
||||
// Then clear
|
||||
act(() => {
|
||||
result.current.clearSelection()
|
||||
})
|
||||
|
||||
expect(result.current.cmdVal).toBe('_')
|
||||
})
|
||||
})
|
||||
|
||||
describe('setSearchQuery', () => {
|
||||
it('should update search query', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('test query')
|
||||
})
|
||||
|
||||
expect(result.current.searchQuery).toBe('test query')
|
||||
})
|
||||
|
||||
it('should handle empty string', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('test')
|
||||
})
|
||||
expect(result.current.searchQuery).toBe('test')
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery('')
|
||||
})
|
||||
expect(result.current.searchQuery).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
describe('setCmdVal', () => {
|
||||
it('should update cmdVal', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setCmdVal('plugin-2')
|
||||
})
|
||||
|
||||
expect(result.current.cmdVal).toBe('plugin-2')
|
||||
})
|
||||
})
|
||||
|
||||
describe('searchQueryDebouncedValue', () => {
|
||||
it('should return trimmed debounced value', () => {
|
||||
const { result } = renderHook(() => useGotoAnythingSearch())
|
||||
|
||||
act(() => {
|
||||
result.current.setSearchQuery(' test ')
|
||||
})
|
||||
|
||||
// Since we mock useDebounce to return value directly
|
||||
expect(result.current.searchQueryDebouncedValue).toBe('test')
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,77 @@
|
||||
'use client'
|
||||
|
||||
import type { ActionItem } from '../actions/types'
|
||||
import { useDebounce } from 'ahooks'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { createActions, matchAction } from '../actions'
|
||||
import { useGotoAnythingContext } from '../context'
|
||||
|
||||
export type UseGotoAnythingSearchReturn = {
|
||||
searchQuery: string
|
||||
setSearchQuery: (query: string) => void
|
||||
searchQueryDebouncedValue: string
|
||||
searchMode: string
|
||||
isCommandsMode: boolean
|
||||
cmdVal: string
|
||||
setCmdVal: (val: string) => void
|
||||
clearSelection: () => void
|
||||
Actions: Record<string, ActionItem>
|
||||
}
|
||||
|
||||
export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => {
|
||||
const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext()
|
||||
const [searchQuery, setSearchQuery] = useState<string>('')
|
||||
const [cmdVal, setCmdVal] = useState<string>('_')
|
||||
|
||||
// Filter actions based on context
|
||||
const Actions = useMemo(() => {
|
||||
return createActions(isWorkflowPage, isRagPipelinePage)
|
||||
}, [isWorkflowPage, isRagPipelinePage])
|
||||
|
||||
const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), {
|
||||
wait: 300,
|
||||
})
|
||||
|
||||
const isCommandsMode = useMemo(() => {
|
||||
const trimmed = searchQuery.trim()
|
||||
return trimmed === '@' || trimmed === '/'
|
||||
|| (trimmed.startsWith('@') && !matchAction(trimmed, Actions))
|
||||
|| (trimmed.startsWith('/') && !matchAction(trimmed, Actions))
|
||||
}, [searchQuery, Actions])
|
||||
|
||||
const searchMode = useMemo(() => {
|
||||
if (isCommandsMode) {
|
||||
// Distinguish between @ (scopes) and / (commands) mode
|
||||
if (searchQuery.trim().startsWith('@'))
|
||||
return 'scopes'
|
||||
else if (searchQuery.trim().startsWith('/'))
|
||||
return 'commands'
|
||||
return 'commands' // default fallback
|
||||
}
|
||||
|
||||
const query = searchQueryDebouncedValue.toLowerCase()
|
||||
const action = matchAction(query, Actions)
|
||||
|
||||
if (!action)
|
||||
return 'general'
|
||||
|
||||
return action.key === '/' ? '@command' : action.key
|
||||
}, [searchQueryDebouncedValue, Actions, isCommandsMode, searchQuery])
|
||||
|
||||
// Prevent automatic selection of the first option when cmdVal is not set
|
||||
const clearSelection = useCallback(() => {
|
||||
setCmdVal('_')
|
||||
}, [])
|
||||
|
||||
return {
|
||||
searchQuery,
|
||||
setSearchQuery,
|
||||
searchQueryDebouncedValue,
|
||||
searchMode,
|
||||
isCommandsMode,
|
||||
cmdVal,
|
||||
setCmdVal,
|
||||
clearSelection,
|
||||
Actions,
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,27 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import type { ActionItem, SearchResult } from './actions/types'
|
||||
import { act, render, screen } from '@testing-library/react'
|
||||
import { act, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import * as React from 'react'
|
||||
import GotoAnything from './index'
|
||||
|
||||
// Test helper type that matches SearchResult but allows ReactNode for icon and flexible data
|
||||
type TestSearchResult = Omit<SearchResult, 'icon' | 'data'> & {
|
||||
icon?: ReactNode
|
||||
data?: Record<string, unknown>
|
||||
}
|
||||
|
||||
// Mock react-i18next to return namespace.key format
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { ns?: string }) => {
|
||||
const ns = options?.ns || 'common'
|
||||
return `${ns}.${key}`
|
||||
},
|
||||
i18n: { language: 'en' },
|
||||
}),
|
||||
}))
|
||||
|
||||
const routerPush = vi.fn()
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
@@ -12,10 +30,15 @@ vi.mock('next/navigation', () => ({
|
||||
usePathname: () => '/',
|
||||
}))
|
||||
|
||||
const keyPressHandlers: Record<string, (event: any) => void> = {}
|
||||
type KeyPressEvent = {
|
||||
preventDefault: () => void
|
||||
target?: EventTarget
|
||||
}
|
||||
|
||||
const keyPressHandlers: Record<string, (event: KeyPressEvent) => void> = {}
|
||||
vi.mock('ahooks', () => ({
|
||||
useDebounce: (value: any) => value,
|
||||
useKeyPress: (keys: string | string[], handler: (event: any) => void) => {
|
||||
useDebounce: <T,>(value: T) => value,
|
||||
useKeyPress: (keys: string | string[], handler: (event: KeyPressEvent) => void) => {
|
||||
const keyList = Array.isArray(keys) ? keys : [keys]
|
||||
keyList.forEach((key) => {
|
||||
keyPressHandlers[key] = handler
|
||||
@@ -32,7 +55,7 @@ const triggerKeyPress = (combo: string) => {
|
||||
}
|
||||
}
|
||||
|
||||
let mockQueryResult = { data: [] as SearchResult[], isLoading: false, isError: false, error: null as Error | null }
|
||||
let mockQueryResult = { data: [] as TestSearchResult[], isLoading: false, isError: false, error: null as Error | null }
|
||||
vi.mock('@tanstack/react-query', () => ({
|
||||
useQuery: () => mockQueryResult,
|
||||
}))
|
||||
@@ -76,9 +99,16 @@ vi.mock('./actions/commands', () => ({
|
||||
SlashCommandProvider: () => null,
|
||||
}))
|
||||
|
||||
type MockSlashCommand = {
|
||||
mode: string
|
||||
execute?: () => void
|
||||
isAvailable?: () => boolean
|
||||
} | null
|
||||
|
||||
let mockFindCommand: MockSlashCommand = null
|
||||
vi.mock('./actions/commands/registry', () => ({
|
||||
slashCommandRegistry: {
|
||||
findCommand: () => null,
|
||||
findCommand: () => mockFindCommand,
|
||||
getAvailableCommands: () => [],
|
||||
getAllCommands: () => [],
|
||||
},
|
||||
@@ -86,6 +116,7 @@ vi.mock('./actions/commands/registry', () => ({
|
||||
|
||||
vi.mock('@/app/components/workflow/utils/common', () => ({
|
||||
getKeyboardKeyCodeBySystem: () => 'ctrl',
|
||||
getKeyboardKeyNameBySystem: (key: string) => key,
|
||||
isEventTargetInputArea: () => false,
|
||||
isMac: () => false,
|
||||
}))
|
||||
@@ -95,10 +126,11 @@ vi.mock('@/app/components/workflow/utils/node-navigation', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('../plugins/install-plugin/install-from-marketplace', () => ({
|
||||
default: (props: { manifest?: { name?: string }, onClose: () => void }) => (
|
||||
default: (props: { manifest?: { name?: string }, onClose: () => void, onSuccess: () => void }) => (
|
||||
<div data-testid="install-modal">
|
||||
<span>{props.manifest?.name}</span>
|
||||
<button onClick={props.onClose}>close</button>
|
||||
<button onClick={props.onClose} data-testid="close-install">close</button>
|
||||
<button onClick={props.onSuccess} data-testid="success-install">success</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
@@ -110,65 +142,504 @@ describe('GotoAnything', () => {
|
||||
mockQueryResult = { data: [], isLoading: false, isError: false, error: null }
|
||||
matchActionMock.mockReset()
|
||||
searchAnythingMock.mockClear()
|
||||
mockFindCommand = null
|
||||
})
|
||||
|
||||
it('should open modal via shortcut and navigate to selected result', async () => {
|
||||
mockQueryResult = {
|
||||
data: [{
|
||||
id: 'app-1',
|
||||
type: 'app',
|
||||
title: 'Sample App',
|
||||
description: 'desc',
|
||||
path: '/apps/1',
|
||||
icon: <div data-testid="icon">🧩</div>,
|
||||
data: {},
|
||||
} as any],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
describe('modal behavior', () => {
|
||||
it('should open modal via Ctrl+K shortcut', async () => {
|
||||
render(<GotoAnything />)
|
||||
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
triggerKeyPress('ctrl.k')
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
const input = await screen.findByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await userEvent.type(input, 'app')
|
||||
it('should close modal via ESC key', async () => {
|
||||
render(<GotoAnything />)
|
||||
|
||||
const result = await screen.findByText('Sample App')
|
||||
await userEvent.click(result)
|
||||
triggerKeyPress('ctrl.k')
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(routerPush).toHaveBeenCalledWith('/apps/1')
|
||||
triggerKeyPress('esc')
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByPlaceholderText('app.gotoAnything.searchPlaceholder')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should toggle modal when pressing Ctrl+K twice', async () => {
|
||||
render(<GotoAnything />)
|
||||
|
||||
triggerKeyPress('ctrl.k')
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
triggerKeyPress('ctrl.k')
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByPlaceholderText('app.gotoAnything.searchPlaceholder')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onHide when modal closes', async () => {
|
||||
const onHide = vi.fn()
|
||||
render(<GotoAnything onHide={onHide} />)
|
||||
|
||||
triggerKeyPress('ctrl.k')
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
triggerKeyPress('esc')
|
||||
await waitFor(() => {
|
||||
expect(onHide).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should reset search query when modal opens', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<GotoAnything />)
|
||||
|
||||
// Open modal first time
|
||||
triggerKeyPress('ctrl.k')
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Type something
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, 'test')
|
||||
|
||||
// Close modal
|
||||
triggerKeyPress('esc')
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByPlaceholderText('app.gotoAnything.searchPlaceholder')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Open modal again - should be empty
|
||||
triggerKeyPress('ctrl.k')
|
||||
await waitFor(() => {
|
||||
const newInput = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
expect(newInput).toHaveValue('')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
it('should open plugin installer when selecting plugin result', async () => {
|
||||
mockQueryResult = {
|
||||
data: [{
|
||||
id: 'plugin-1',
|
||||
type: 'plugin',
|
||||
title: 'Plugin Item',
|
||||
description: 'desc',
|
||||
path: '',
|
||||
icon: <div />,
|
||||
data: {
|
||||
name: 'Plugin Item',
|
||||
latest_package_identifier: 'pkg',
|
||||
},
|
||||
} as any],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
describe('search functionality', () => {
|
||||
it('should navigate to selected result', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockQueryResult = {
|
||||
data: [{
|
||||
id: 'app-1',
|
||||
type: 'app',
|
||||
title: 'Sample App',
|
||||
description: 'desc',
|
||||
path: '/apps/1',
|
||||
icon: <div data-testid="icon">🧩</div>,
|
||||
data: {},
|
||||
}],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
render(<GotoAnything />)
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
triggerKeyPress('ctrl.k')
|
||||
const input = await screen.findByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await userEvent.type(input, 'plugin')
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const pluginItem = await screen.findByText('Plugin Item')
|
||||
await userEvent.click(pluginItem)
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, 'app')
|
||||
|
||||
expect(await screen.findByTestId('install-modal')).toHaveTextContent('Plugin Item')
|
||||
const result = await screen.findByText('Sample App')
|
||||
await user.click(result)
|
||||
|
||||
expect(routerPush).toHaveBeenCalledWith('/apps/1')
|
||||
})
|
||||
|
||||
it('should clear selection when typing without prefix', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, 'test query')
|
||||
|
||||
// Should not throw and input should have value
|
||||
expect(input).toHaveValue('test query')
|
||||
})
|
||||
})
|
||||
|
||||
describe('empty states', () => {
|
||||
it('should show loading state', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockQueryResult = {
|
||||
data: [],
|
||||
isLoading: true,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, 'search')
|
||||
|
||||
// Loading state shows in both EmptyState (spinner) and Footer
|
||||
const searchingTexts = screen.getAllByText('app.gotoAnything.searching')
|
||||
expect(searchingTexts.length).toBeGreaterThanOrEqual(1)
|
||||
})
|
||||
|
||||
it('should show error state', async () => {
|
||||
const user = userEvent.setup()
|
||||
const testError = new Error('Search failed')
|
||||
mockQueryResult = {
|
||||
data: [],
|
||||
isLoading: false,
|
||||
isError: true,
|
||||
error: testError,
|
||||
}
|
||||
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, 'search')
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searchFailed')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show default state when no query', async () => {
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.searchTitle')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show no results state when search returns empty', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockQueryResult = {
|
||||
data: [],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, 'nonexistent')
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('plugin installation', () => {
|
||||
it('should open plugin installer when selecting plugin result', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockQueryResult = {
|
||||
data: [{
|
||||
id: 'plugin-1',
|
||||
type: 'plugin',
|
||||
title: 'Plugin Item',
|
||||
description: 'desc',
|
||||
path: '',
|
||||
icon: <div />,
|
||||
data: {
|
||||
name: 'Plugin Item',
|
||||
latest_package_identifier: 'pkg',
|
||||
},
|
||||
}],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, 'plugin')
|
||||
|
||||
const pluginItem = await screen.findByText('Plugin Item')
|
||||
await user.click(pluginItem)
|
||||
|
||||
expect(await screen.findByTestId('install-modal')).toHaveTextContent('Plugin Item')
|
||||
})
|
||||
|
||||
it('should close plugin installer via close button', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockQueryResult = {
|
||||
data: [{
|
||||
id: 'plugin-1',
|
||||
type: 'plugin',
|
||||
title: 'Plugin Item',
|
||||
description: 'desc',
|
||||
path: '',
|
||||
icon: <div />,
|
||||
data: {
|
||||
name: 'Plugin Item',
|
||||
latest_package_identifier: 'pkg',
|
||||
},
|
||||
}],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, 'plugin')
|
||||
|
||||
const pluginItem = await screen.findByText('Plugin Item')
|
||||
await user.click(pluginItem)
|
||||
|
||||
const closeBtn = await screen.findByTestId('close-install')
|
||||
await user.click(closeBtn)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('install-modal')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should close plugin installer on success', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockQueryResult = {
|
||||
data: [{
|
||||
id: 'plugin-1',
|
||||
type: 'plugin',
|
||||
title: 'Plugin Item',
|
||||
description: 'desc',
|
||||
path: '',
|
||||
icon: <div />,
|
||||
data: {
|
||||
name: 'Plugin Item',
|
||||
latest_package_identifier: 'pkg',
|
||||
},
|
||||
}],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, 'plugin')
|
||||
|
||||
const pluginItem = await screen.findByText('Plugin Item')
|
||||
await user.click(pluginItem)
|
||||
|
||||
const successBtn = await screen.findByTestId('success-install')
|
||||
await user.click(successBtn)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('install-modal')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('slash command handling', () => {
|
||||
it('should execute direct slash command on Enter', async () => {
|
||||
const user = userEvent.setup()
|
||||
const executeMock = vi.fn()
|
||||
mockFindCommand = {
|
||||
mode: 'direct',
|
||||
execute: executeMock,
|
||||
isAvailable: () => true,
|
||||
}
|
||||
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, '/theme')
|
||||
await user.keyboard('{Enter}')
|
||||
|
||||
expect(executeMock).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should NOT execute unavailable slash command', async () => {
|
||||
const user = userEvent.setup()
|
||||
const executeMock = vi.fn()
|
||||
mockFindCommand = {
|
||||
mode: 'direct',
|
||||
execute: executeMock,
|
||||
isAvailable: () => false,
|
||||
}
|
||||
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, '/theme')
|
||||
await user.keyboard('{Enter}')
|
||||
|
||||
expect(executeMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should NOT execute non-direct mode slash command on Enter', async () => {
|
||||
const user = userEvent.setup()
|
||||
const executeMock = vi.fn()
|
||||
mockFindCommand = {
|
||||
mode: 'submenu',
|
||||
execute: executeMock,
|
||||
}
|
||||
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, '/language')
|
||||
await user.keyboard('{Enter}')
|
||||
|
||||
expect(executeMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should close modal after executing direct slash command', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockFindCommand = {
|
||||
mode: 'direct',
|
||||
execute: vi.fn(),
|
||||
isAvailable: () => true,
|
||||
}
|
||||
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, '/theme')
|
||||
await user.keyboard('{Enter}')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByPlaceholderText('app.gotoAnything.searchPlaceholder')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('result navigation', () => {
|
||||
it('should handle knowledge result navigation', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockQueryResult = {
|
||||
data: [{
|
||||
id: 'kb-1',
|
||||
type: 'knowledge',
|
||||
title: 'Knowledge Base',
|
||||
description: 'desc',
|
||||
path: '/datasets/kb-1',
|
||||
icon: <div />,
|
||||
data: {},
|
||||
}],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, 'knowledge')
|
||||
|
||||
const result = await screen.findByText('Knowledge Base')
|
||||
await user.click(result)
|
||||
|
||||
expect(routerPush).toHaveBeenCalledWith('/datasets/kb-1')
|
||||
})
|
||||
|
||||
it('should NOT navigate when result has no path', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockQueryResult = {
|
||||
data: [{
|
||||
id: 'item-1',
|
||||
type: 'app',
|
||||
title: 'No Path Item',
|
||||
description: 'desc',
|
||||
path: '',
|
||||
icon: <div />,
|
||||
data: {},
|
||||
}],
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
}
|
||||
|
||||
render(<GotoAnything />)
|
||||
triggerKeyPress('ctrl.k')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const input = screen.getByPlaceholderText('app.gotoAnything.searchPlaceholder')
|
||||
await user.type(input, 'no path')
|
||||
|
||||
const result = await screen.findByText('No Path Item')
|
||||
await user.click(result)
|
||||
|
||||
expect(routerPush).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,300 +1,149 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { Plugin } from '../plugins/types'
|
||||
import type { SearchResult } from './actions'
|
||||
import { RiSearchLine } from '@remixicon/react'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { useDebounce, useKeyPress } from 'ahooks'
|
||||
import type { FC, KeyboardEvent } from 'react'
|
||||
import { Command } from 'cmdk'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useCallback, useEffect, useMemo, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Modal from '@/app/components/base/modal'
|
||||
import ShortcutsName from '@/app/components/workflow/shortcuts-name'
|
||||
import { getKeyboardKeyCodeBySystem, isEventTargetInputArea } from '@/app/components/workflow/utils/common'
|
||||
import { selectWorkflowNode } from '@/app/components/workflow/utils/node-navigation'
|
||||
import { useGetLanguage } from '@/context/i18n'
|
||||
import InstallFromMarketplace from '../plugins/install-plugin/install-from-marketplace'
|
||||
import { createActions, matchAction, searchAnything } from './actions'
|
||||
import { SlashCommandProvider } from './actions/commands'
|
||||
import { slashCommandRegistry } from './actions/commands/registry'
|
||||
import CommandSelector from './command-selector'
|
||||
import { EmptyState, Footer, ResultList, SearchInput } from './components'
|
||||
import { GotoAnythingProvider, useGotoAnythingContext } from './context'
|
||||
import {
|
||||
useGotoAnythingModal,
|
||||
useGotoAnythingNavigation,
|
||||
useGotoAnythingResults,
|
||||
useGotoAnythingSearch,
|
||||
} from './hooks'
|
||||
|
||||
type Props = {
|
||||
onHide?: () => void
|
||||
}
|
||||
|
||||
const GotoAnything: FC<Props> = ({
|
||||
onHide,
|
||||
}) => {
|
||||
const router = useRouter()
|
||||
const defaultLocale = useGetLanguage()
|
||||
const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext()
|
||||
const { t } = useTranslation()
|
||||
const [show, setShow] = useState<boolean>(false)
|
||||
const [searchQuery, setSearchQuery] = useState<string>('')
|
||||
const [cmdVal, setCmdVal] = useState<string>('_')
|
||||
const inputRef = useRef<HTMLInputElement>(null)
|
||||
const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext()
|
||||
const prevShowRef = useRef(false)
|
||||
|
||||
// Filter actions based on context
|
||||
const Actions = useMemo(() => {
|
||||
// Create actions based on current page context
|
||||
return createActions(isWorkflowPage, isRagPipelinePage)
|
||||
}, [isWorkflowPage, isRagPipelinePage])
|
||||
// Search state management (called first so setSearchQuery is available)
|
||||
const {
|
||||
searchQuery,
|
||||
setSearchQuery,
|
||||
searchQueryDebouncedValue,
|
||||
searchMode,
|
||||
isCommandsMode,
|
||||
cmdVal,
|
||||
setCmdVal,
|
||||
clearSelection,
|
||||
Actions,
|
||||
} = useGotoAnythingSearch()
|
||||
|
||||
const [activePlugin, setActivePlugin] = useState<Plugin>()
|
||||
// Modal state management
|
||||
const {
|
||||
show,
|
||||
setShow,
|
||||
inputRef,
|
||||
handleClose: modalClose,
|
||||
} = useGotoAnythingModal()
|
||||
|
||||
// Handle keyboard shortcuts
|
||||
const handleToggleModal = useCallback((e: KeyboardEvent) => {
|
||||
// Allow closing when modal is open, even if focus is in the search input
|
||||
if (!show && isEventTargetInputArea(e.target as HTMLElement))
|
||||
return
|
||||
e.preventDefault()
|
||||
setShow((prev) => {
|
||||
if (!prev) {
|
||||
// Opening modal - reset search state
|
||||
setSearchQuery('')
|
||||
}
|
||||
return !prev
|
||||
})
|
||||
}, [show])
|
||||
|
||||
useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.k`, handleToggleModal, {
|
||||
exactMatch: true,
|
||||
useCapture: true,
|
||||
})
|
||||
|
||||
useKeyPress(['esc'], (e) => {
|
||||
if (show) {
|
||||
e.preventDefault()
|
||||
setShow(false)
|
||||
// Reset state when modal opens/closes
|
||||
useEffect(() => {
|
||||
if (show && !prevShowRef.current) {
|
||||
// Modal just opened - reset search
|
||||
setSearchQuery('')
|
||||
}
|
||||
else if (!show && prevShowRef.current) {
|
||||
// Modal just closed
|
||||
setSearchQuery('')
|
||||
clearSelection()
|
||||
onHide?.()
|
||||
}
|
||||
prevShowRef.current = show
|
||||
}, [show, setSearchQuery, clearSelection, onHide])
|
||||
|
||||
// Results fetching and processing
|
||||
const {
|
||||
dedupedResults,
|
||||
groupedResults,
|
||||
isLoading,
|
||||
isError,
|
||||
error,
|
||||
} = useGotoAnythingResults({
|
||||
searchQueryDebouncedValue,
|
||||
searchMode,
|
||||
isCommandsMode,
|
||||
Actions,
|
||||
isWorkflowPage,
|
||||
isRagPipelinePage,
|
||||
cmdVal,
|
||||
setCmdVal,
|
||||
})
|
||||
|
||||
const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), {
|
||||
wait: 300,
|
||||
// Navigation handlers
|
||||
const {
|
||||
handleCommandSelect,
|
||||
handleNavigate,
|
||||
activePlugin,
|
||||
setActivePlugin,
|
||||
} = useGotoAnythingNavigation({
|
||||
Actions,
|
||||
setSearchQuery,
|
||||
clearSelection,
|
||||
inputRef,
|
||||
onClose: () => setShow(false),
|
||||
})
|
||||
|
||||
const isCommandsMode = searchQuery.trim() === '@' || searchQuery.trim() === '/'
|
||||
|| (searchQuery.trim().startsWith('@') && !matchAction(searchQuery.trim(), Actions))
|
||||
|| (searchQuery.trim().startsWith('/') && !matchAction(searchQuery.trim(), Actions))
|
||||
// Handle search input change
|
||||
const handleSearchChange = useCallback((value: string) => {
|
||||
setSearchQuery(value)
|
||||
if (!value.startsWith('@') && !value.startsWith('/'))
|
||||
clearSelection()
|
||||
}, [setSearchQuery, clearSelection])
|
||||
|
||||
const searchMode = useMemo(() => {
|
||||
if (isCommandsMode) {
|
||||
// Distinguish between @ (scopes) and / (commands) mode
|
||||
if (searchQuery.trim().startsWith('@'))
|
||||
return 'scopes'
|
||||
else if (searchQuery.trim().startsWith('/'))
|
||||
return 'commands'
|
||||
return 'commands' // default fallback
|
||||
}
|
||||
// Handle search input keydown for slash commands
|
||||
const handleSearchKeyDown = useCallback((e: KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === 'Enter') {
|
||||
const query = searchQuery.trim()
|
||||
// Check if it's a complete slash command
|
||||
if (query.startsWith('/')) {
|
||||
const commandName = query.substring(1).split(' ')[0]
|
||||
const handler = slashCommandRegistry.findCommand(commandName)
|
||||
|
||||
const query = searchQueryDebouncedValue.toLowerCase()
|
||||
const action = matchAction(query, Actions)
|
||||
|
||||
if (!action)
|
||||
return 'general'
|
||||
|
||||
return action.key === '/' ? '@command' : action.key
|
||||
}, [searchQueryDebouncedValue, Actions, isCommandsMode, searchQuery])
|
||||
|
||||
const { data: searchResults = [], isLoading, isError, error } = useQuery(
|
||||
{
|
||||
queryKey: [
|
||||
'goto-anything',
|
||||
'search-result',
|
||||
searchQueryDebouncedValue,
|
||||
searchMode,
|
||||
isWorkflowPage,
|
||||
isRagPipelinePage,
|
||||
defaultLocale,
|
||||
Actions,
|
||||
],
|
||||
queryFn: async () => {
|
||||
const query = searchQueryDebouncedValue.toLowerCase()
|
||||
const action = matchAction(query, Actions)
|
||||
return await searchAnything(defaultLocale, query, action, Actions)
|
||||
},
|
||||
enabled: !!searchQueryDebouncedValue && !isCommandsMode,
|
||||
staleTime: 30000,
|
||||
gcTime: 300000,
|
||||
},
|
||||
)
|
||||
|
||||
// Prevent automatic selection of the first option when cmdVal is not set
|
||||
const clearSelection = () => {
|
||||
setCmdVal('_')
|
||||
}
|
||||
|
||||
const handleCommandSelect = useCallback((commandKey: string) => {
|
||||
// Check if it's a slash command
|
||||
if (commandKey.startsWith('/')) {
|
||||
const commandName = commandKey.substring(1)
|
||||
const handler = slashCommandRegistry.findCommand(commandName)
|
||||
|
||||
// If it's a direct mode command, execute immediately
|
||||
if (handler?.mode === 'direct' && handler.execute) {
|
||||
handler.execute()
|
||||
setShow(false)
|
||||
setSearchQuery('')
|
||||
return
|
||||
// If it's a direct mode command, execute immediately
|
||||
const isAvailable = handler?.isAvailable?.() ?? true
|
||||
if (handler?.mode === 'direct' && handler.execute && isAvailable) {
|
||||
e.preventDefault()
|
||||
handler.execute()
|
||||
setShow(false)
|
||||
setSearchQuery('')
|
||||
}
|
||||
}
|
||||
}
|
||||
}, [searchQuery, setShow, setSearchQuery])
|
||||
|
||||
// Otherwise, proceed with the normal flow (submenu mode)
|
||||
setSearchQuery(`${commandKey} `)
|
||||
clearSelection()
|
||||
setTimeout(() => {
|
||||
inputRef.current?.focus()
|
||||
}, 0)
|
||||
}, [])
|
||||
|
||||
// Handle navigation to selected result
|
||||
const handleNavigate = useCallback((result: SearchResult) => {
|
||||
setShow(false)
|
||||
setSearchQuery('')
|
||||
|
||||
switch (result.type) {
|
||||
case 'command': {
|
||||
// Execute slash commands
|
||||
const action = Actions.slash
|
||||
action?.action?.(result)
|
||||
break
|
||||
}
|
||||
case 'plugin':
|
||||
setActivePlugin(result.data)
|
||||
break
|
||||
case 'workflow-node':
|
||||
// Handle workflow node selection and navigation
|
||||
if (result.metadata?.nodeId)
|
||||
selectWorkflowNode(result.metadata.nodeId, true)
|
||||
|
||||
break
|
||||
default:
|
||||
if (result.path)
|
||||
router.push(result.path)
|
||||
}
|
||||
}, [router])
|
||||
|
||||
const dedupedResults = useMemo(() => {
|
||||
const seen = new Set<string>()
|
||||
return searchResults.filter((result) => {
|
||||
const key = `${result.type}-${result.id}`
|
||||
if (seen.has(key))
|
||||
return false
|
||||
seen.add(key)
|
||||
return true
|
||||
})
|
||||
}, [searchResults])
|
||||
|
||||
// Group results by type
|
||||
const groupedResults = useMemo(() => dedupedResults.reduce((acc, result) => {
|
||||
if (!acc[result.type])
|
||||
acc[result.type] = []
|
||||
|
||||
acc[result.type].push(result)
|
||||
return acc
|
||||
}, {} as { [key: string]: SearchResult[] }), [dedupedResults])
|
||||
|
||||
useEffect(() => {
|
||||
if (isCommandsMode)
|
||||
return
|
||||
|
||||
if (!dedupedResults.length)
|
||||
return
|
||||
|
||||
const currentValueExists = dedupedResults.some(result => `${result.type}-${result.id}` === cmdVal)
|
||||
|
||||
if (!currentValueExists)
|
||||
setCmdVal(`${dedupedResults[0].type}-${dedupedResults[0].id}`)
|
||||
}, [isCommandsMode, dedupedResults, cmdVal])
|
||||
|
||||
const emptyResult = useMemo(() => {
|
||||
if (dedupedResults.length || !searchQuery.trim() || isLoading || isCommandsMode)
|
||||
return null
|
||||
|
||||
const isCommandSearch = searchMode !== 'general'
|
||||
const commandType = isCommandSearch ? searchMode.replace('@', '') : ''
|
||||
|
||||
if (isError) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
|
||||
<div>
|
||||
<div className="text-sm font-medium text-red-500">{t('gotoAnything.searchTemporarilyUnavailable', { ns: 'app' })}</div>
|
||||
<div className="mt-1 text-xs text-text-quaternary">
|
||||
{t('gotoAnything.servicesUnavailableMessage', { ns: 'app' })}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
|
||||
<div>
|
||||
<div className="text-sm font-medium">
|
||||
{isCommandSearch
|
||||
? (() => {
|
||||
const keyMap = {
|
||||
app: 'gotoAnything.emptyState.noAppsFound',
|
||||
plugin: 'gotoAnything.emptyState.noPluginsFound',
|
||||
knowledge: 'gotoAnything.emptyState.noKnowledgeBasesFound',
|
||||
node: 'gotoAnything.emptyState.noWorkflowNodesFound',
|
||||
} as const
|
||||
return t(keyMap[commandType as keyof typeof keyMap] || 'gotoAnything.noResults', { ns: 'app' })
|
||||
})()
|
||||
: t('gotoAnything.noResults', { ns: 'app' })}
|
||||
</div>
|
||||
<div className="mt-1 text-xs text-text-quaternary">
|
||||
{isCommandSearch
|
||||
? t('gotoAnything.emptyState.tryDifferentTerm', { ns: 'app' })
|
||||
: t('gotoAnything.emptyState.trySpecificSearch', { ns: 'app', shortcuts: Object.values(Actions).map(action => action.shortcut).join(', ') })}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}, [dedupedResults, searchQuery, Actions, searchMode, isLoading, isError, isCommandsMode])
|
||||
|
||||
const defaultUI = useMemo(() => {
|
||||
if (searchQuery.trim())
|
||||
return null
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
|
||||
<div>
|
||||
<div className="text-sm font-medium">{t('gotoAnything.searchTitle', { ns: 'app' })}</div>
|
||||
<div className="mt-3 space-y-1 text-xs text-text-quaternary">
|
||||
<div>{t('gotoAnything.searchHint', { ns: 'app' })}</div>
|
||||
<div>{t('gotoAnything.commandHint', { ns: 'app' })}</div>
|
||||
<div>{t('gotoAnything.slashHint', { ns: 'app' })}</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}, [searchQuery, Actions])
|
||||
|
||||
useEffect(() => {
|
||||
if (show) {
|
||||
requestAnimationFrame(() => {
|
||||
inputRef.current?.focus()
|
||||
})
|
||||
}
|
||||
}, [show])
|
||||
// Determine which empty state to show
|
||||
const emptyStateVariant = useMemo(() => {
|
||||
if (isLoading)
|
||||
return 'loading'
|
||||
if (isError)
|
||||
return 'error'
|
||||
if (!searchQuery.trim())
|
||||
return 'default'
|
||||
if (dedupedResults.length === 0 && !isCommandsMode)
|
||||
return 'no-results'
|
||||
return null
|
||||
}, [isLoading, isError, searchQuery, dedupedResults.length, isCommandsMode])
|
||||
|
||||
return (
|
||||
<>
|
||||
<SlashCommandProvider />
|
||||
<Modal
|
||||
isShow={show}
|
||||
onClose={() => {
|
||||
setShow(false)
|
||||
setSearchQuery('')
|
||||
clearSelection()
|
||||
onHide?.()
|
||||
}}
|
||||
onClose={modalClose}
|
||||
closable={false}
|
||||
className="!w-[480px] !p-0"
|
||||
highPriority={true}
|
||||
@@ -307,78 +156,24 @@ const GotoAnything: FC<Props> = ({
|
||||
disablePointerSelection
|
||||
loop
|
||||
>
|
||||
<div className="flex items-center gap-3 border-b border-divider-subtle bg-components-panel-bg-blur px-4 py-3">
|
||||
<RiSearchLine className="h-4 w-4 text-text-quaternary" />
|
||||
<div className="flex flex-1 items-center gap-2">
|
||||
<Input
|
||||
ref={inputRef}
|
||||
value={searchQuery}
|
||||
placeholder={t('gotoAnything.searchPlaceholder', { ns: 'app' })}
|
||||
onChange={(e) => {
|
||||
setSearchQuery(e.target.value)
|
||||
if (!e.target.value.startsWith('@') && !e.target.value.startsWith('/'))
|
||||
clearSelection()
|
||||
}}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
const query = searchQuery.trim()
|
||||
// Check if it's a complete slash command
|
||||
if (query.startsWith('/')) {
|
||||
const commandName = query.substring(1).split(' ')[0]
|
||||
const handler = slashCommandRegistry.findCommand(commandName)
|
||||
|
||||
// If it's a direct mode command, execute immediately
|
||||
const isAvailable = handler?.isAvailable?.() ?? true
|
||||
if (handler?.mode === 'direct' && handler.execute && isAvailable) {
|
||||
e.preventDefault()
|
||||
handler.execute()
|
||||
setShow(false)
|
||||
setSearchQuery('')
|
||||
}
|
||||
}
|
||||
}
|
||||
}}
|
||||
className="flex-1 !border-0 !bg-transparent !shadow-none"
|
||||
wrapperClassName="flex-1 !border-0 !bg-transparent"
|
||||
autoFocus
|
||||
/>
|
||||
{searchMode !== 'general' && (
|
||||
<div className="flex items-center gap-1 rounded bg-gray-100 px-2 py-[2px] text-xs font-medium text-gray-700 dark:bg-gray-800 dark:text-gray-300">
|
||||
<span>
|
||||
{(() => {
|
||||
if (searchMode === 'scopes')
|
||||
return 'SCOPES'
|
||||
else if (searchMode === 'commands')
|
||||
return 'COMMANDS'
|
||||
else
|
||||
return searchMode.replace('@', '').toUpperCase()
|
||||
})()}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<ShortcutsName keys={['ctrl', 'K']} textColor="secondary" />
|
||||
</div>
|
||||
<SearchInput
|
||||
inputRef={inputRef}
|
||||
value={searchQuery}
|
||||
onChange={handleSearchChange}
|
||||
onKeyDown={handleSearchKeyDown}
|
||||
searchMode={searchMode}
|
||||
placeholder={t('gotoAnything.searchPlaceholder', { ns: 'app' })}
|
||||
/>
|
||||
|
||||
<Command.List className="h-[240px] overflow-y-auto">
|
||||
{isLoading && (
|
||||
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="h-4 w-4 animate-spin rounded-full border-2 border-gray-300 border-t-gray-600"></div>
|
||||
<span className="text-sm">{t('gotoAnything.searching', { ns: 'app' })}</span>
|
||||
</div>
|
||||
</div>
|
||||
{emptyStateVariant === 'loading' && (
|
||||
<EmptyState variant="loading" />
|
||||
)}
|
||||
{isError && (
|
||||
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
|
||||
<div>
|
||||
<div className="text-sm font-medium text-red-500">{t('gotoAnything.searchFailed', { ns: 'app' })}</div>
|
||||
<div className="mt-1 text-xs text-text-quaternary">
|
||||
{error.message}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{emptyStateVariant === 'error' && (
|
||||
<EmptyState variant="error" error={error} />
|
||||
)}
|
||||
|
||||
{!isLoading && !isError && (
|
||||
<>
|
||||
{isCommandsMode
|
||||
@@ -393,118 +188,46 @@ const GotoAnything: FC<Props> = ({
|
||||
/>
|
||||
)
|
||||
: (
|
||||
Object.entries(groupedResults).map(([type, results], groupIndex) => (
|
||||
<Command.Group
|
||||
key={groupIndex}
|
||||
heading={(() => {
|
||||
const typeMap = {
|
||||
'app': 'gotoAnything.groups.apps',
|
||||
'plugin': 'gotoAnything.groups.plugins',
|
||||
'knowledge': 'gotoAnything.groups.knowledgeBases',
|
||||
'workflow-node': 'gotoAnything.groups.workflowNodes',
|
||||
'command': 'gotoAnything.groups.commands',
|
||||
} as const
|
||||
return t(typeMap[type as keyof typeof typeMap] || `${type}s`, { ns: 'app' })
|
||||
})()}
|
||||
className="p-2 capitalize text-text-secondary"
|
||||
>
|
||||
{results.map(result => (
|
||||
<Command.Item
|
||||
key={`${result.type}-${result.id}`}
|
||||
value={`${result.type}-${result.id}`}
|
||||
className="flex cursor-pointer items-center gap-3 rounded-md p-3 will-change-[background-color] hover:bg-state-base-hover aria-[selected=true]:bg-state-base-hover-alt data-[selected=true]:bg-state-base-hover-alt"
|
||||
onSelect={() => handleNavigate(result)}
|
||||
>
|
||||
{result.icon}
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="truncate font-medium text-text-secondary">
|
||||
{result.title}
|
||||
</div>
|
||||
{result.description && (
|
||||
<div className="mt-0.5 truncate text-xs text-text-quaternary">
|
||||
{result.description}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="text-xs capitalize text-text-quaternary">
|
||||
{result.type}
|
||||
</div>
|
||||
</Command.Item>
|
||||
))}
|
||||
</Command.Group>
|
||||
))
|
||||
<ResultList
|
||||
groupedResults={groupedResults}
|
||||
onSelect={handleNavigate}
|
||||
/>
|
||||
)}
|
||||
{!isCommandsMode && emptyResult}
|
||||
{!isCommandsMode && defaultUI}
|
||||
|
||||
{!isCommandsMode && emptyStateVariant === 'no-results' && (
|
||||
<EmptyState
|
||||
variant="no-results"
|
||||
searchMode={searchMode}
|
||||
Actions={Actions}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isCommandsMode && emptyStateVariant === 'default' && (
|
||||
<EmptyState variant="default" />
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Command.List>
|
||||
|
||||
{/* Always show footer to prevent height jumping */}
|
||||
<div className="border-t border-divider-subtle bg-components-panel-bg-blur px-4 py-2 text-xs text-text-tertiary">
|
||||
<div className="flex min-h-[16px] items-center justify-between">
|
||||
{(!!dedupedResults.length || isError)
|
||||
? (
|
||||
<>
|
||||
<span>
|
||||
{isError
|
||||
? (
|
||||
<span className="text-red-500">{t('gotoAnything.someServicesUnavailable', { ns: 'app' })}</span>
|
||||
)
|
||||
: (
|
||||
<>
|
||||
{t('gotoAnything.resultCount', { ns: 'app', count: dedupedResults.length })}
|
||||
{searchMode !== 'general' && (
|
||||
<span className="ml-2 opacity-60">
|
||||
{t('gotoAnything.inScope', { ns: 'app', scope: searchMode.replace('@', '') })}
|
||||
</span>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</span>
|
||||
<span className="opacity-60">
|
||||
{searchMode !== 'general'
|
||||
? t('gotoAnything.clearToSearchAll', { ns: 'app' })
|
||||
: t('gotoAnything.useAtForSpecific', { ns: 'app' })}
|
||||
</span>
|
||||
</>
|
||||
)
|
||||
: (
|
||||
<>
|
||||
<span className="opacity-60">
|
||||
{(() => {
|
||||
if (isCommandsMode)
|
||||
return t('gotoAnything.selectToNavigate', { ns: 'app' })
|
||||
|
||||
if (searchQuery.trim())
|
||||
return t('gotoAnything.searching', { ns: 'app' })
|
||||
|
||||
return t('gotoAnything.startTyping', { ns: 'app' })
|
||||
})()}
|
||||
</span>
|
||||
<span className="opacity-60">
|
||||
{searchQuery.trim() || isCommandsMode
|
||||
? t('gotoAnything.tips', { ns: 'app' })
|
||||
: t('gotoAnything.pressEscToClose', { ns: 'app' })}
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<Footer
|
||||
resultCount={dedupedResults.length}
|
||||
searchMode={searchMode}
|
||||
isError={isError}
|
||||
isCommandsMode={isCommandsMode}
|
||||
hasQuery={!!searchQuery.trim()}
|
||||
/>
|
||||
</Command>
|
||||
</div>
|
||||
|
||||
</Modal>
|
||||
{
|
||||
activePlugin && (
|
||||
<InstallFromMarketplace
|
||||
manifest={activePlugin}
|
||||
uniqueIdentifier={activePlugin.latest_package_identifier}
|
||||
onClose={() => setActivePlugin(undefined)}
|
||||
onSuccess={() => setActivePlugin(undefined)}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
{activePlugin && (
|
||||
<InstallFromMarketplace
|
||||
manifest={activePlugin}
|
||||
uniqueIdentifier={activePlugin.latest_package_identifier}
|
||||
onClose={() => setActivePlugin(undefined)}
|
||||
onSuccess={() => setActivePlugin(undefined)}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,27 +1,19 @@
|
||||
'use client'
|
||||
import type { FileUpload } from '@/app/components/base/features/types'
|
||||
import type { App } from '@/types/app'
|
||||
import * as React from 'react'
|
||||
import { useMemo, useRef } from 'react'
|
||||
import { useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
|
||||
import AppInputsForm from '@/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-form'
|
||||
import { BlockEnum, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types'
|
||||
import { useAppDetail } from '@/service/use-apps'
|
||||
import { useFileUploadConfig } from '@/service/use-common'
|
||||
import { useAppWorkflow } from '@/service/use-workflow'
|
||||
import { AppModeEnum, Resolution } from '@/types/app'
|
||||
|
||||
import { useAppInputsFormSchema } from '@/app/components/plugins/plugin-detail-panel/app-selector/hooks/use-app-inputs-form-schema'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type Props = {
|
||||
value?: {
|
||||
app_id: string
|
||||
inputs: Record<string, any>
|
||||
inputs: Record<string, unknown>
|
||||
}
|
||||
appDetail: App
|
||||
onFormChange: (value: Record<string, any>) => void
|
||||
onFormChange: (value: Record<string, unknown>) => void
|
||||
}
|
||||
|
||||
const AppInputsPanel = ({
|
||||
@@ -30,155 +22,33 @@ const AppInputsPanel = ({
|
||||
onFormChange,
|
||||
}: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const inputsRef = useRef<any>(value?.inputs || {})
|
||||
const isBasicApp = appDetail.mode !== AppModeEnum.ADVANCED_CHAT && appDetail.mode !== AppModeEnum.WORKFLOW
|
||||
const { data: fileUploadConfig } = useFileUploadConfig()
|
||||
const { data: currentApp, isFetching: isAppLoading } = useAppDetail(appDetail.id)
|
||||
const { data: currentWorkflow, isFetching: isWorkflowLoading } = useAppWorkflow(isBasicApp ? '' : appDetail.id)
|
||||
const isLoading = isAppLoading || isWorkflowLoading
|
||||
const inputsRef = useRef<Record<string, unknown>>(value?.inputs || {})
|
||||
|
||||
const basicAppFileConfig = useMemo(() => {
|
||||
let fileConfig: FileUpload
|
||||
if (isBasicApp)
|
||||
fileConfig = currentApp?.model_config?.file_upload as FileUpload
|
||||
else
|
||||
fileConfig = currentWorkflow?.features?.file_upload as FileUpload
|
||||
return {
|
||||
image: {
|
||||
detail: fileConfig?.image?.detail || Resolution.high,
|
||||
enabled: !!fileConfig?.image?.enabled,
|
||||
number_limits: fileConfig?.image?.number_limits || 3,
|
||||
transfer_methods: fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
|
||||
},
|
||||
enabled: !!(fileConfig?.enabled || fileConfig?.image?.enabled),
|
||||
allowed_file_types: fileConfig?.allowed_file_types || [SupportUploadFileTypes.image],
|
||||
allowed_file_extensions: fileConfig?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image]].map(ext => `.${ext}`),
|
||||
allowed_file_upload_methods: fileConfig?.allowed_file_upload_methods || fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
|
||||
number_limits: fileConfig?.number_limits || fileConfig?.image?.number_limits || 3,
|
||||
}
|
||||
}, [currentApp?.model_config?.file_upload, currentWorkflow?.features?.file_upload, isBasicApp])
|
||||
const { inputFormSchema, isLoading } = useAppInputsFormSchema({ appDetail })
|
||||
|
||||
const inputFormSchema = useMemo(() => {
|
||||
if (!currentApp)
|
||||
return []
|
||||
let inputFormSchema = []
|
||||
if (isBasicApp) {
|
||||
inputFormSchema = currentApp.model_config?.user_input_form?.filter((item: any) => !item.external_data_tool).map((item: any) => {
|
||||
if (item.paragraph) {
|
||||
return {
|
||||
...item.paragraph,
|
||||
type: 'paragraph',
|
||||
required: false,
|
||||
}
|
||||
}
|
||||
if (item.number) {
|
||||
return {
|
||||
...item.number,
|
||||
type: 'number',
|
||||
required: false,
|
||||
}
|
||||
}
|
||||
if (item.checkbox) {
|
||||
return {
|
||||
...item.checkbox,
|
||||
type: 'checkbox',
|
||||
required: false,
|
||||
}
|
||||
}
|
||||
if (item.select) {
|
||||
return {
|
||||
...item.select,
|
||||
type: 'select',
|
||||
required: false,
|
||||
}
|
||||
}
|
||||
|
||||
if (item['file-list']) {
|
||||
return {
|
||||
...item['file-list'],
|
||||
type: 'file-list',
|
||||
required: false,
|
||||
fileUploadConfig,
|
||||
}
|
||||
}
|
||||
|
||||
if (item.file) {
|
||||
return {
|
||||
...item.file,
|
||||
type: 'file',
|
||||
required: false,
|
||||
fileUploadConfig,
|
||||
}
|
||||
}
|
||||
|
||||
if (item.json_object) {
|
||||
return {
|
||||
...item.json_object,
|
||||
type: 'json_object',
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
...item['text-input'],
|
||||
type: 'text-input',
|
||||
required: false,
|
||||
}
|
||||
}) || []
|
||||
}
|
||||
else {
|
||||
const startNode = currentWorkflow?.graph?.nodes.find(node => node.data.type === BlockEnum.Start) as any
|
||||
inputFormSchema = startNode?.data.variables.map((variable: any) => {
|
||||
if (variable.type === InputVarType.multiFiles) {
|
||||
return {
|
||||
...variable,
|
||||
required: false,
|
||||
fileUploadConfig,
|
||||
}
|
||||
}
|
||||
|
||||
if (variable.type === InputVarType.singleFile) {
|
||||
return {
|
||||
...variable,
|
||||
required: false,
|
||||
fileUploadConfig,
|
||||
}
|
||||
}
|
||||
return {
|
||||
...variable,
|
||||
required: false,
|
||||
}
|
||||
}) || []
|
||||
}
|
||||
if ((currentApp.mode === AppModeEnum.COMPLETION || currentApp.mode === AppModeEnum.WORKFLOW) && basicAppFileConfig.enabled) {
|
||||
inputFormSchema.push({
|
||||
label: 'Image Upload',
|
||||
variable: '#image#',
|
||||
type: InputVarType.singleFile,
|
||||
required: false,
|
||||
...basicAppFileConfig,
|
||||
fileUploadConfig,
|
||||
})
|
||||
}
|
||||
return inputFormSchema || []
|
||||
}, [basicAppFileConfig, currentApp, currentWorkflow, fileUploadConfig, isBasicApp])
|
||||
|
||||
const handleFormChange = (value: Record<string, any>) => {
|
||||
inputsRef.current = value
|
||||
onFormChange(value)
|
||||
const handleFormChange = (newValue: Record<string, unknown>) => {
|
||||
inputsRef.current = newValue
|
||||
onFormChange(newValue)
|
||||
}
|
||||
|
||||
const hasInputs = inputFormSchema.length > 0
|
||||
|
||||
return (
|
||||
<div className={cn('flex max-h-[240px] flex-col rounded-b-2xl border-t border-divider-subtle pb-4')}>
|
||||
{isLoading && <div className="pt-3"><Loading type="app" /></div>}
|
||||
{!isLoading && (
|
||||
<div className="system-sm-semibold mb-2 mt-3 flex h-6 shrink-0 items-center px-4 text-text-secondary">{t('appSelector.params', { ns: 'app' })}</div>
|
||||
)}
|
||||
{!isLoading && !inputFormSchema.length && (
|
||||
<div className="flex h-16 flex-col items-center justify-center">
|
||||
<div className="system-sm-regular text-text-tertiary">{t('appSelector.noParams', { ns: 'app' })}</div>
|
||||
<div className="system-sm-semibold mb-2 mt-3 flex h-6 shrink-0 items-center px-4 text-text-secondary">
|
||||
{t('appSelector.params', { ns: 'app' })}
|
||||
</div>
|
||||
)}
|
||||
{!isLoading && !!inputFormSchema.length && (
|
||||
{!isLoading && !hasInputs && (
|
||||
<div className="flex h-16 flex-col items-center justify-center">
|
||||
<div className="system-sm-regular text-text-tertiary">
|
||||
{t('appSelector.noParams', { ns: 'app' })}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{!isLoading && hasInputs && (
|
||||
<div className="grow overflow-y-auto">
|
||||
<AppInputsForm
|
||||
inputs={value?.inputs || {}}
|
||||
|
||||
@@ -0,0 +1,211 @@
|
||||
'use client'
|
||||
import type { FileUpload } from '@/app/components/base/features/types'
|
||||
import type { FileUploadConfigResponse } from '@/models/common'
|
||||
import type { App } from '@/types/app'
|
||||
import type { FetchWorkflowDraftResponse } from '@/types/workflow'
|
||||
import { useMemo } from 'react'
|
||||
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
|
||||
import { BlockEnum, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types'
|
||||
import { useAppDetail } from '@/service/use-apps'
|
||||
import { useFileUploadConfig } from '@/service/use-common'
|
||||
import { useAppWorkflow } from '@/service/use-workflow'
|
||||
import { AppModeEnum, Resolution } from '@/types/app'
|
||||
|
||||
const BASIC_INPUT_TYPE_MAP: Record<string, string> = {
|
||||
'paragraph': 'paragraph',
|
||||
'number': 'number',
|
||||
'checkbox': 'checkbox',
|
||||
'select': 'select',
|
||||
'file-list': 'file-list',
|
||||
'file': 'file',
|
||||
'json_object': 'json_object',
|
||||
}
|
||||
|
||||
const FILE_INPUT_TYPES = new Set(['file-list', 'file'])
|
||||
|
||||
const WORKFLOW_FILE_VAR_TYPES = new Set([InputVarType.multiFiles, InputVarType.singleFile])
|
||||
|
||||
type InputSchemaItem = {
|
||||
label?: string
|
||||
variable?: string
|
||||
type: string
|
||||
required: boolean
|
||||
fileUploadConfig?: FileUploadConfigResponse
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
function isBasicAppMode(mode: string): boolean {
|
||||
return mode !== AppModeEnum.ADVANCED_CHAT && mode !== AppModeEnum.WORKFLOW
|
||||
}
|
||||
|
||||
function supportsImageUpload(mode: string): boolean {
|
||||
return mode === AppModeEnum.COMPLETION || mode === AppModeEnum.WORKFLOW
|
||||
}
|
||||
|
||||
function buildFileConfig(fileConfig: FileUpload | undefined) {
|
||||
return {
|
||||
image: {
|
||||
detail: fileConfig?.image?.detail || Resolution.high,
|
||||
enabled: !!fileConfig?.image?.enabled,
|
||||
number_limits: fileConfig?.image?.number_limits || 3,
|
||||
transfer_methods: fileConfig?.image?.transfer_methods || ['local_file', 'remote_url'],
|
||||
},
|
||||
enabled: !!(fileConfig?.enabled || fileConfig?.image?.enabled),
|
||||
allowed_file_types: fileConfig?.allowed_file_types || [SupportUploadFileTypes.image],
|
||||
allowed_file_extensions: fileConfig?.allowed_file_extensions
|
||||
|| [...FILE_EXTS[SupportUploadFileTypes.image]].map(ext => `.${ext}`),
|
||||
allowed_file_upload_methods: fileConfig?.allowed_file_upload_methods
|
||||
|| fileConfig?.image?.transfer_methods
|
||||
|| ['local_file', 'remote_url'],
|
||||
number_limits: fileConfig?.number_limits || fileConfig?.image?.number_limits || 3,
|
||||
}
|
||||
}
|
||||
|
||||
function mapBasicAppInputItem(
|
||||
item: Record<string, unknown>,
|
||||
fileUploadConfig?: FileUploadConfigResponse,
|
||||
): InputSchemaItem | null {
|
||||
for (const [key, type] of Object.entries(BASIC_INPUT_TYPE_MAP)) {
|
||||
if (!item[key])
|
||||
continue
|
||||
|
||||
const inputData = item[key] as Record<string, unknown>
|
||||
const needsFileConfig = FILE_INPUT_TYPES.has(key)
|
||||
|
||||
return {
|
||||
...inputData,
|
||||
type,
|
||||
required: false,
|
||||
...(needsFileConfig && { fileUploadConfig }),
|
||||
}
|
||||
}
|
||||
|
||||
const textInput = item['text-input'] as Record<string, unknown> | undefined
|
||||
if (!textInput)
|
||||
return null
|
||||
|
||||
return {
|
||||
...textInput,
|
||||
type: 'text-input',
|
||||
required: false,
|
||||
}
|
||||
}
|
||||
|
||||
function mapWorkflowVariable(
|
||||
variable: Record<string, unknown>,
|
||||
fileUploadConfig?: FileUploadConfigResponse,
|
||||
): InputSchemaItem {
|
||||
const needsFileConfig = WORKFLOW_FILE_VAR_TYPES.has(variable.type as InputVarType)
|
||||
|
||||
return {
|
||||
...variable,
|
||||
type: variable.type as string,
|
||||
required: false,
|
||||
...(needsFileConfig && { fileUploadConfig }),
|
||||
}
|
||||
}
|
||||
|
||||
function createImageUploadSchema(
|
||||
basicFileConfig: ReturnType<typeof buildFileConfig>,
|
||||
fileUploadConfig?: FileUploadConfigResponse,
|
||||
): InputSchemaItem {
|
||||
return {
|
||||
label: 'Image Upload',
|
||||
variable: '#image#',
|
||||
type: InputVarType.singleFile,
|
||||
required: false,
|
||||
...basicFileConfig,
|
||||
fileUploadConfig,
|
||||
}
|
||||
}
|
||||
|
||||
function buildBasicAppSchema(
|
||||
currentApp: App,
|
||||
fileUploadConfig?: FileUploadConfigResponse,
|
||||
): InputSchemaItem[] {
|
||||
const userInputForm = currentApp.model_config?.user_input_form as Array<Record<string, unknown>> | undefined
|
||||
if (!userInputForm)
|
||||
return []
|
||||
|
||||
return userInputForm
|
||||
.filter((item: Record<string, unknown>) => !item.external_data_tool)
|
||||
.map((item: Record<string, unknown>) => mapBasicAppInputItem(item, fileUploadConfig))
|
||||
.filter((item): item is InputSchemaItem => item !== null)
|
||||
}
|
||||
|
||||
function buildWorkflowSchema(
|
||||
workflow: FetchWorkflowDraftResponse,
|
||||
fileUploadConfig?: FileUploadConfigResponse,
|
||||
): InputSchemaItem[] {
|
||||
const startNode = workflow.graph?.nodes.find(
|
||||
node => node.data.type === BlockEnum.Start,
|
||||
) as { data: { variables: Array<Record<string, unknown>> } } | undefined
|
||||
|
||||
if (!startNode?.data.variables)
|
||||
return []
|
||||
|
||||
return startNode.data.variables.map(
|
||||
variable => mapWorkflowVariable(variable, fileUploadConfig),
|
||||
)
|
||||
}
|
||||
|
||||
type UseAppInputsFormSchemaParams = {
|
||||
appDetail: App
|
||||
}
|
||||
|
||||
type UseAppInputsFormSchemaResult = {
|
||||
inputFormSchema: InputSchemaItem[]
|
||||
isLoading: boolean
|
||||
fileUploadConfig?: FileUploadConfigResponse
|
||||
}
|
||||
|
||||
export function useAppInputsFormSchema({
|
||||
appDetail,
|
||||
}: UseAppInputsFormSchemaParams): UseAppInputsFormSchemaResult {
|
||||
const isBasicApp = isBasicAppMode(appDetail.mode)
|
||||
|
||||
const { data: fileUploadConfig } = useFileUploadConfig()
|
||||
const { data: currentApp, isFetching: isAppLoading } = useAppDetail(appDetail.id)
|
||||
const { data: currentWorkflow, isFetching: isWorkflowLoading } = useAppWorkflow(
|
||||
isBasicApp ? '' : appDetail.id,
|
||||
)
|
||||
|
||||
const isLoading = isAppLoading || isWorkflowLoading
|
||||
|
||||
const inputFormSchema = useMemo(() => {
|
||||
if (!currentApp)
|
||||
return []
|
||||
|
||||
if (!isBasicApp && !currentWorkflow)
|
||||
return []
|
||||
|
||||
// Build base schema based on app type
|
||||
// Note: currentWorkflow is guaranteed to be defined here due to the early return above
|
||||
const baseSchema = isBasicApp
|
||||
? buildBasicAppSchema(currentApp, fileUploadConfig)
|
||||
: buildWorkflowSchema(currentWorkflow!, fileUploadConfig)
|
||||
|
||||
if (!supportsImageUpload(currentApp.mode))
|
||||
return baseSchema
|
||||
|
||||
const rawFileConfig = isBasicApp
|
||||
? currentApp.model_config?.file_upload as FileUpload
|
||||
: currentWorkflow?.features?.file_upload as FileUpload
|
||||
|
||||
const basicFileConfig = buildFileConfig(rawFileConfig)
|
||||
|
||||
if (!basicFileConfig.enabled)
|
||||
return baseSchema
|
||||
|
||||
return [
|
||||
...baseSchema,
|
||||
createImageUploadSchema(basicFileConfig, fileUploadConfig),
|
||||
]
|
||||
}, [currentApp, currentWorkflow, fileUploadConfig, isBasicApp])
|
||||
|
||||
return {
|
||||
inputFormSchema,
|
||||
isLoading,
|
||||
fileUploadConfig,
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import Toast from '@/app/components/base/toast'
|
||||
import { PluginSource } from '../types'
|
||||
import DetailHeader from './detail-header'
|
||||
|
||||
// Use vi.hoisted for mock functions used in vi.mock factories
|
||||
const {
|
||||
mockSetShowUpdatePluginModal,
|
||||
mockRefreshModelProviders,
|
||||
|
||||
@@ -1,416 +1,2 @@
|
||||
import type { PluginDetail } from '../types'
|
||||
import {
|
||||
RiArrowLeftRightLine,
|
||||
RiBugLine,
|
||||
RiCloseLine,
|
||||
RiHardDrive3Line,
|
||||
} from '@remixicon/react'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import { Github } from '@/app/components/base/icons/src/public/common'
|
||||
import { BoxSparkleFill } from '@/app/components/base/icons/src/vender/plugin'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { AuthCategory, PluginAuth } from '@/app/components/plugins/plugin-auth'
|
||||
import OperationDropdown from '@/app/components/plugins/plugin-detail-panel/operation-dropdown'
|
||||
import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info'
|
||||
import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place'
|
||||
import PluginVersionPicker from '@/app/components/plugins/update-plugin/plugin-version-picker'
|
||||
import { API_PREFIX } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useGetLanguage, useLocale } from '@/context/i18n'
|
||||
import { useModalContext } from '@/context/modal-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import useTheme from '@/hooks/use-theme'
|
||||
import { uninstallPlugin } from '@/service/plugins'
|
||||
import { useAllToolProviders, useInvalidateAllToolProviders } from '@/service/use-tools'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { getMarketplaceUrl } from '@/utils/var'
|
||||
import { AutoUpdateLine } from '../../base/icons/src/vender/system'
|
||||
import Verified from '../base/badges/verified'
|
||||
import DeprecationNotice from '../base/deprecation-notice'
|
||||
import Icon from '../card/base/card-icon'
|
||||
import Description from '../card/base/description'
|
||||
import OrgInfo from '../card/base/org-info'
|
||||
import Title from '../card/base/title'
|
||||
import { useGitHubReleases } from '../install-plugin/hooks'
|
||||
import useReferenceSetting from '../plugin-page/use-reference-setting'
|
||||
import { AUTO_UPDATE_MODE } from '../reference-setting-modal/auto-update-setting/types'
|
||||
import { convertUTCDaySecondsToLocalSeconds, timeOfDayToDayjs } from '../reference-setting-modal/auto-update-setting/utils'
|
||||
import { PluginCategoryEnum, PluginSource } from '../types'
|
||||
|
||||
const i18nPrefix = 'action'
|
||||
|
||||
type Props = {
|
||||
detail: PluginDetail
|
||||
isReadmeView?: boolean
|
||||
onHide?: () => void
|
||||
onUpdate?: (isDelete?: boolean) => void
|
||||
}
|
||||
|
||||
const DetailHeader = ({
|
||||
detail,
|
||||
isReadmeView = false,
|
||||
onHide,
|
||||
onUpdate,
|
||||
}: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const { userProfile: { timezone } } = useAppContext()
|
||||
|
||||
const { theme } = useTheme()
|
||||
const locale = useGetLanguage()
|
||||
const currentLocale = useLocale()
|
||||
const { checkForUpdates, fetchReleases } = useGitHubReleases()
|
||||
const { setShowUpdatePluginModal } = useModalContext()
|
||||
const { refreshModelProviders } = useProviderContext()
|
||||
const invalidateAllToolProviders = useInvalidateAllToolProviders()
|
||||
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
|
||||
|
||||
const {
|
||||
id,
|
||||
source,
|
||||
tenant_id,
|
||||
version,
|
||||
latest_unique_identifier,
|
||||
latest_version,
|
||||
meta,
|
||||
plugin_id,
|
||||
status,
|
||||
deprecated_reason,
|
||||
alternative_plugin_id,
|
||||
} = detail
|
||||
|
||||
const { author, category, name, label, description, icon, icon_dark, verified, tool } = detail.declaration || detail
|
||||
const isTool = category === PluginCategoryEnum.tool
|
||||
const providerBriefInfo = tool?.identity
|
||||
const providerKey = `${plugin_id}/${providerBriefInfo?.name}`
|
||||
const { data: collectionList = [] } = useAllToolProviders(isTool)
|
||||
const provider = useMemo(() => {
|
||||
return collectionList.find(collection => collection.name === providerKey)
|
||||
}, [collectionList, providerKey])
|
||||
const isFromGitHub = source === PluginSource.github
|
||||
const isFromMarketplace = source === PluginSource.marketplace
|
||||
|
||||
const [isShow, setIsShow] = useState(false)
|
||||
const [targetVersion, setTargetVersion] = useState({
|
||||
version: latest_version,
|
||||
unique_identifier: latest_unique_identifier,
|
||||
})
|
||||
const hasNewVersion = useMemo(() => {
|
||||
if (isFromMarketplace)
|
||||
return !!latest_version && latest_version !== version
|
||||
|
||||
return false
|
||||
}, [isFromMarketplace, latest_version, version])
|
||||
|
||||
const iconFileName = theme === 'dark' && icon_dark ? icon_dark : icon
|
||||
const iconSrc = iconFileName
|
||||
? (iconFileName.startsWith('http') ? iconFileName : `${API_PREFIX}/workspaces/current/plugin/icon?tenant_id=${tenant_id}&filename=${iconFileName}`)
|
||||
: ''
|
||||
|
||||
const detailUrl = useMemo(() => {
|
||||
if (isFromGitHub)
|
||||
return `https://github.com/${meta!.repo}`
|
||||
if (isFromMarketplace)
|
||||
return getMarketplaceUrl(`/plugins/${author}/${name}`, { language: currentLocale, theme })
|
||||
return ''
|
||||
}, [author, isFromGitHub, isFromMarketplace, meta, name, theme])
|
||||
|
||||
const [isShowUpdateModal, {
|
||||
setTrue: showUpdateModal,
|
||||
setFalse: hideUpdateModal,
|
||||
}] = useBoolean(false)
|
||||
|
||||
const { referenceSetting } = useReferenceSetting()
|
||||
const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {}
|
||||
const isAutoUpgradeEnabled = useMemo(() => {
|
||||
if (!enable_marketplace)
|
||||
return false
|
||||
if (!autoUpgradeInfo || !isFromMarketplace)
|
||||
return false
|
||||
if (autoUpgradeInfo.strategy_setting === 'disabled')
|
||||
return false
|
||||
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.update_all)
|
||||
return true
|
||||
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.partial && autoUpgradeInfo.include_plugins.includes(plugin_id))
|
||||
return true
|
||||
if (autoUpgradeInfo.upgrade_mode === AUTO_UPDATE_MODE.exclude && !autoUpgradeInfo.exclude_plugins.includes(plugin_id))
|
||||
return true
|
||||
return false
|
||||
}, [autoUpgradeInfo, plugin_id, isFromMarketplace])
|
||||
|
||||
const [isDowngrade, setIsDowngrade] = useState(false)
|
||||
const handleUpdate = async (isDowngrade?: boolean) => {
|
||||
if (isFromMarketplace) {
|
||||
setIsDowngrade(!!isDowngrade)
|
||||
showUpdateModal()
|
||||
return
|
||||
}
|
||||
|
||||
const owner = meta!.repo.split('/')[0] || author
|
||||
const repo = meta!.repo.split('/')[1] || name
|
||||
const fetchedReleases = await fetchReleases(owner, repo)
|
||||
if (fetchedReleases.length === 0)
|
||||
return
|
||||
const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta!.version)
|
||||
Toast.notify(toastProps)
|
||||
if (needUpdate) {
|
||||
setShowUpdatePluginModal({
|
||||
onSaveCallback: () => {
|
||||
onUpdate?.()
|
||||
},
|
||||
payload: {
|
||||
type: PluginSource.github,
|
||||
category: detail.declaration.category,
|
||||
github: {
|
||||
originalPackageInfo: {
|
||||
id: detail.plugin_unique_identifier,
|
||||
repo: meta!.repo,
|
||||
version: meta!.version,
|
||||
package: meta!.package,
|
||||
releases: fetchedReleases,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const handleUpdatedFromMarketplace = () => {
|
||||
onUpdate?.()
|
||||
hideUpdateModal()
|
||||
}
|
||||
|
||||
const [isShowPluginInfo, {
|
||||
setTrue: showPluginInfo,
|
||||
setFalse: hidePluginInfo,
|
||||
}] = useBoolean(false)
|
||||
|
||||
const [isShowDeleteConfirm, {
|
||||
setTrue: showDeleteConfirm,
|
||||
setFalse: hideDeleteConfirm,
|
||||
}] = useBoolean(false)
|
||||
|
||||
const [deleting, {
|
||||
setTrue: showDeleting,
|
||||
setFalse: hideDeleting,
|
||||
}] = useBoolean(false)
|
||||
|
||||
const handleDelete = useCallback(async () => {
|
||||
showDeleting()
|
||||
const res = await uninstallPlugin(id)
|
||||
hideDeleting()
|
||||
if (res.success) {
|
||||
hideDeleteConfirm()
|
||||
onUpdate?.(true)
|
||||
if (PluginCategoryEnum.model.includes(category))
|
||||
refreshModelProviders()
|
||||
if (PluginCategoryEnum.tool.includes(category))
|
||||
invalidateAllToolProviders()
|
||||
trackEvent('plugin_uninstalled', { plugin_id, plugin_name: name })
|
||||
}
|
||||
}, [showDeleting, id, hideDeleting, hideDeleteConfirm, onUpdate, category, refreshModelProviders, invalidateAllToolProviders, plugin_id, name])
|
||||
|
||||
return (
|
||||
<div className={cn('shrink-0 border-b border-divider-subtle bg-components-panel-bg p-4 pb-3', isReadmeView && 'border-b-0 bg-transparent p-0')}>
|
||||
<div className="flex">
|
||||
<div className={cn('overflow-hidden rounded-xl border border-components-panel-border-subtle', isReadmeView && 'bg-components-panel-bg')}>
|
||||
<Icon src={iconSrc} />
|
||||
</div>
|
||||
<div className="ml-3 w-0 grow">
|
||||
<div className="flex h-5 items-center">
|
||||
<Title title={label[locale]} />
|
||||
{verified && !isReadmeView && <Verified className="ml-0.5 h-4 w-4" text={t('marketplace.verifiedTip', { ns: 'plugin' })} />}
|
||||
{!!version && (
|
||||
<PluginVersionPicker
|
||||
disabled={!isFromMarketplace || isReadmeView}
|
||||
isShow={isShow}
|
||||
onShowChange={setIsShow}
|
||||
pluginID={plugin_id}
|
||||
currentVersion={version}
|
||||
onSelect={(state) => {
|
||||
setTargetVersion(state)
|
||||
handleUpdate(state.isDowngrade)
|
||||
}}
|
||||
trigger={(
|
||||
<Badge
|
||||
className={cn(
|
||||
'mx-1',
|
||||
isShow && 'bg-state-base-hover',
|
||||
(isShow || isFromMarketplace) && 'hover:bg-state-base-hover',
|
||||
)}
|
||||
uppercase={false}
|
||||
text={(
|
||||
<>
|
||||
<div>{isFromGitHub ? meta!.version : version}</div>
|
||||
{isFromMarketplace && !isReadmeView && <RiArrowLeftRightLine className="ml-1 h-3 w-3 text-text-tertiary" />}
|
||||
</>
|
||||
)}
|
||||
hasRedCornerMark={hasNewVersion}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
{/* Auto update info */}
|
||||
{isAutoUpgradeEnabled && !isReadmeView && (
|
||||
<Tooltip popupContent={t('autoUpdate.nextUpdateTime', { ns: 'plugin', time: timeOfDayToDayjs(convertUTCDaySecondsToLocalSeconds(autoUpgradeInfo?.upgrade_time_of_day || 0, timezone!)).format('hh:mm A') })}>
|
||||
{/* add a a div to fix tooltip hover not show problem */}
|
||||
<div>
|
||||
<Badge className="mr-1 cursor-pointer px-1">
|
||||
<AutoUpdateLine className="size-3" />
|
||||
</Badge>
|
||||
</div>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
{(hasNewVersion || isFromGitHub) && (
|
||||
<Button
|
||||
variant="secondary-accent"
|
||||
size="small"
|
||||
className="!h-5"
|
||||
onClick={() => {
|
||||
if (isFromMarketplace) {
|
||||
setTargetVersion({
|
||||
version: latest_version,
|
||||
unique_identifier: latest_unique_identifier,
|
||||
})
|
||||
}
|
||||
handleUpdate()
|
||||
}}
|
||||
>
|
||||
{t('detailPanel.operation.update', { ns: 'plugin' })}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
<div className="mb-1 flex h-4 items-center justify-between">
|
||||
<div className="mt-0.5 flex items-center">
|
||||
<OrgInfo
|
||||
packageNameClassName="w-auto"
|
||||
orgName={author}
|
||||
packageName={name?.includes('/') ? (name.split('/').pop() || '') : name}
|
||||
/>
|
||||
{!!source && (
|
||||
<>
|
||||
<div className="system-xs-regular ml-1 mr-0.5 text-text-quaternary">·</div>
|
||||
{source === PluginSource.marketplace && (
|
||||
<Tooltip popupContent={t('detailPanel.categoryTip.marketplace', { ns: 'plugin' })}>
|
||||
<div><BoxSparkleFill className="h-3.5 w-3.5 text-text-tertiary hover:text-text-accent" /></div>
|
||||
</Tooltip>
|
||||
)}
|
||||
{source === PluginSource.github && (
|
||||
<Tooltip popupContent={t('detailPanel.categoryTip.github', { ns: 'plugin' })}>
|
||||
<div><Github className="h-3.5 w-3.5 text-text-secondary hover:text-text-primary" /></div>
|
||||
</Tooltip>
|
||||
)}
|
||||
{source === PluginSource.local && (
|
||||
<Tooltip popupContent={t('detailPanel.categoryTip.local', { ns: 'plugin' })}>
|
||||
<div><RiHardDrive3Line className="h-3.5 w-3.5 text-text-tertiary" /></div>
|
||||
</Tooltip>
|
||||
)}
|
||||
{source === PluginSource.debugging && (
|
||||
<Tooltip popupContent={t('detailPanel.categoryTip.debugging', { ns: 'plugin' })}>
|
||||
<div><RiBugLine className="h-3.5 w-3.5 text-text-tertiary hover:text-text-warning" /></div>
|
||||
</Tooltip>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{!isReadmeView && (
|
||||
<div className="flex gap-1">
|
||||
<OperationDropdown
|
||||
source={source}
|
||||
onInfo={showPluginInfo}
|
||||
onCheckVersion={handleUpdate}
|
||||
onRemove={showDeleteConfirm}
|
||||
detailUrl={detailUrl}
|
||||
/>
|
||||
<ActionButton onClick={onHide}>
|
||||
<RiCloseLine className="h-4 w-4" />
|
||||
</ActionButton>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{isFromMarketplace && (
|
||||
<DeprecationNotice
|
||||
status={status}
|
||||
deprecatedReason={deprecated_reason}
|
||||
alternativePluginId={alternative_plugin_id}
|
||||
alternativePluginURL={getMarketplaceUrl(`/plugins/${alternative_plugin_id}`, { language: currentLocale, theme })}
|
||||
className="mt-3"
|
||||
/>
|
||||
)}
|
||||
{!isReadmeView && <Description className="mb-2 mt-3 h-auto" text={description[locale]} descriptionLineRows={2}></Description>}
|
||||
{
|
||||
category === PluginCategoryEnum.tool && !isReadmeView && (
|
||||
<PluginAuth
|
||||
pluginPayload={{
|
||||
provider: provider?.name || '',
|
||||
category: AuthCategory.tool,
|
||||
providerType: provider?.type || '',
|
||||
detail,
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{isShowPluginInfo && (
|
||||
<PluginInfo
|
||||
repository={isFromGitHub ? meta?.repo : ''}
|
||||
release={version}
|
||||
packageName={meta?.package || ''}
|
||||
onHide={hidePluginInfo}
|
||||
/>
|
||||
)}
|
||||
{isShowDeleteConfirm && (
|
||||
<Confirm
|
||||
isShow
|
||||
title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
|
||||
content={(
|
||||
<div>
|
||||
{t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })}
|
||||
<span className="system-md-semibold">{label[locale]}</span>
|
||||
{t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })}
|
||||
<br />
|
||||
</div>
|
||||
)}
|
||||
onCancel={hideDeleteConfirm}
|
||||
onConfirm={handleDelete}
|
||||
isLoading={deleting}
|
||||
isDisabled={deleting}
|
||||
/>
|
||||
)}
|
||||
{
|
||||
isShowUpdateModal && (
|
||||
<UpdateFromMarketplace
|
||||
pluginId={plugin_id}
|
||||
payload={{
|
||||
category: detail.declaration.category,
|
||||
originalPackageInfo: {
|
||||
id: detail.plugin_unique_identifier,
|
||||
payload: detail.declaration,
|
||||
},
|
||||
targetPackageInfo: {
|
||||
id: targetVersion.unique_identifier,
|
||||
version: targetVersion.version,
|
||||
},
|
||||
}}
|
||||
onCancel={hideUpdateModal}
|
||||
onSave={handleUpdatedFromMarketplace}
|
||||
isShowDowngradeWarningModal={isDowngrade && isAutoUpgradeEnabled}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default DetailHeader
|
||||
// Re-export from refactored module for backward compatibility
|
||||
export { default } from './detail-header/index'
|
||||
|
||||
@@ -0,0 +1,539 @@
|
||||
import type { PluginDetail } from '../../../types'
|
||||
import type { ModalStates, VersionTarget } from '../hooks'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { PluginSource } from '../../../types'
|
||||
import HeaderModals from './header-modals'
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useGetLanguage: () => 'en_US',
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/confirm', () => ({
|
||||
default: ({ isShow, title, onCancel, onConfirm, isLoading }: {
|
||||
isShow: boolean
|
||||
title: string
|
||||
onCancel: () => void
|
||||
onConfirm: () => void
|
||||
isLoading: boolean
|
||||
}) => isShow
|
||||
? (
|
||||
<div data-testid="delete-confirm">
|
||||
<div data-testid="delete-title">{title}</div>
|
||||
<button data-testid="confirm-cancel" onClick={onCancel}>Cancel</button>
|
||||
<button data-testid="confirm-ok" onClick={onConfirm} disabled={isLoading}>Confirm</button>
|
||||
</div>
|
||||
)
|
||||
: null,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/plugin-page/plugin-info', () => ({
|
||||
default: ({ repository, release, packageName, onHide }: {
|
||||
repository: string
|
||||
release: string
|
||||
packageName: string
|
||||
onHide: () => void
|
||||
}) => (
|
||||
<div data-testid="plugin-info">
|
||||
<div data-testid="plugin-info-repo">{repository}</div>
|
||||
<div data-testid="plugin-info-release">{release}</div>
|
||||
<div data-testid="plugin-info-package">{packageName}</div>
|
||||
<button data-testid="plugin-info-close" onClick={onHide}>Close</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/update-plugin/from-market-place', () => ({
|
||||
default: ({ pluginId, onSave, onCancel, isShowDowngradeWarningModal }: {
|
||||
pluginId: string
|
||||
onSave: () => void
|
||||
onCancel: () => void
|
||||
isShowDowngradeWarningModal: boolean
|
||||
}) => (
|
||||
<div data-testid="update-modal">
|
||||
<div data-testid="update-plugin-id">{pluginId}</div>
|
||||
<div data-testid="update-downgrade-warning">{String(isShowDowngradeWarningModal)}</div>
|
||||
<button data-testid="update-modal-save" onClick={onSave}>Save</button>
|
||||
<button data-testid="update-modal-cancel" onClick={onCancel}>Cancel</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
|
||||
id: 'test-id',
|
||||
created_at: '2024-01-01',
|
||||
updated_at: '2024-01-02',
|
||||
name: 'Test Plugin',
|
||||
plugin_id: 'test-plugin',
|
||||
plugin_unique_identifier: 'test-uid',
|
||||
declaration: {
|
||||
author: 'test-author',
|
||||
name: 'test-plugin-name',
|
||||
category: 'tool',
|
||||
label: { en_US: 'Test Plugin Label' },
|
||||
description: { en_US: 'Test description' },
|
||||
icon: 'icon.png',
|
||||
verified: true,
|
||||
} as unknown as PluginDetail['declaration'],
|
||||
installation_id: 'install-1',
|
||||
tenant_id: 'tenant-1',
|
||||
endpoints_setups: 0,
|
||||
endpoints_active: 0,
|
||||
version: '1.0.0',
|
||||
latest_version: '2.0.0',
|
||||
latest_unique_identifier: 'new-uid',
|
||||
source: PluginSource.marketplace,
|
||||
meta: undefined,
|
||||
status: 'active',
|
||||
deprecated_reason: '',
|
||||
alternative_plugin_id: '',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createModalStatesMock = (overrides: Partial<ModalStates> = {}): ModalStates => ({
|
||||
isShowUpdateModal: false,
|
||||
showUpdateModal: vi.fn<() => void>(),
|
||||
hideUpdateModal: vi.fn<() => void>(),
|
||||
isShowPluginInfo: false,
|
||||
showPluginInfo: vi.fn<() => void>(),
|
||||
hidePluginInfo: vi.fn<() => void>(),
|
||||
isShowDeleteConfirm: false,
|
||||
showDeleteConfirm: vi.fn<() => void>(),
|
||||
hideDeleteConfirm: vi.fn<() => void>(),
|
||||
deleting: false,
|
||||
showDeleting: vi.fn<() => void>(),
|
||||
hideDeleting: vi.fn<() => void>(),
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createTargetVersion = (overrides: Partial<VersionTarget> = {}): VersionTarget => ({
|
||||
version: '2.0.0',
|
||||
unique_identifier: 'new-uid',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('HeaderModals', () => {
|
||||
let mockOnUpdatedFromMarketplace: () => void
|
||||
let mockOnDelete: () => void
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockOnUpdatedFromMarketplace = vi.fn<() => void>()
|
||||
mockOnDelete = vi.fn<() => void>()
|
||||
})
|
||||
|
||||
describe('Plugin Info Modal', () => {
|
||||
it('should not render plugin info modal when isShowPluginInfo is false', () => {
|
||||
const modalStates = createModalStatesMock({ isShowPluginInfo: false })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('plugin-info')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render plugin info modal when isShowPluginInfo is true', () => {
|
||||
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('plugin-info')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass GitHub repo to plugin info for GitHub source', () => {
|
||||
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
|
||||
const detail = createPluginDetail({
|
||||
source: PluginSource.github,
|
||||
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'test-pkg' },
|
||||
})
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={detail}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('owner/repo')
|
||||
})
|
||||
|
||||
it('should pass empty string for repo for non-GitHub source', () => {
|
||||
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail({ source: PluginSource.marketplace })}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('')
|
||||
})
|
||||
|
||||
it('should call hidePluginInfo when close button is clicked', () => {
|
||||
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('plugin-info-close'))
|
||||
|
||||
expect(modalStates.hidePluginInfo).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Delete Confirm Modal', () => {
|
||||
it('should not render delete confirm when isShowDeleteConfirm is false', () => {
|
||||
const modalStates = createModalStatesMock({ isShowDeleteConfirm: false })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('delete-confirm')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render delete confirm when isShowDeleteConfirm is true', () => {
|
||||
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show correct delete title', () => {
|
||||
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('delete-title')).toHaveTextContent('action.delete')
|
||||
})
|
||||
|
||||
it('should call hideDeleteConfirm when cancel is clicked', () => {
|
||||
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('confirm-cancel'))
|
||||
|
||||
expect(modalStates.hideDeleteConfirm).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onDelete when confirm is clicked', () => {
|
||||
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('confirm-ok'))
|
||||
|
||||
expect(mockOnDelete).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should disable confirm button when deleting', () => {
|
||||
const modalStates = createModalStatesMock({ isShowDeleteConfirm: true, deleting: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('confirm-ok')).toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Update Modal', () => {
|
||||
it('should not render update modal when isShowUpdateModal is false', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: false })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('update-modal')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render update modal when isShowUpdateModal is true', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass plugin id to update modal', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail({ plugin_id: 'my-plugin-id' })}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('update-plugin-id')).toHaveTextContent('my-plugin-id')
|
||||
})
|
||||
|
||||
it('should call onUpdatedFromMarketplace when save is clicked', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('update-modal-save'))
|
||||
|
||||
expect(mockOnUpdatedFromMarketplace).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call hideUpdateModal when cancel is clicked', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('update-modal-cancel'))
|
||||
|
||||
expect(modalStates.hideUpdateModal).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show downgrade warning when isDowngrade and isAutoUpgradeEnabled are true', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={true}
|
||||
isAutoUpgradeEnabled={true}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('true')
|
||||
})
|
||||
|
||||
it('should not show downgrade warning when only isDowngrade is true', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={true}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('false')
|
||||
})
|
||||
|
||||
it('should not show downgrade warning when only isAutoUpgradeEnabled is true', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={true}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('update-downgrade-warning')).toHaveTextContent('false')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Multiple Modals', () => {
|
||||
it('should render multiple modals when multiple are open', () => {
|
||||
const modalStates = createModalStatesMock({
|
||||
isShowPluginInfo: true,
|
||||
isShowDeleteConfirm: true,
|
||||
isShowUpdateModal: true,
|
||||
})
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('plugin-info')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle undefined target version values', () => {
|
||||
const modalStates = createModalStatesMock({ isShowUpdateModal: true })
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={createPluginDetail()}
|
||||
modalStates={modalStates}
|
||||
targetVersion={{ version: undefined, unique_identifier: undefined }}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle empty meta for GitHub source', () => {
|
||||
const modalStates = createModalStatesMock({ isShowPluginInfo: true })
|
||||
const detail = createPluginDetail({
|
||||
source: PluginSource.github,
|
||||
meta: undefined,
|
||||
})
|
||||
render(
|
||||
<HeaderModals
|
||||
detail={detail}
|
||||
modalStates={modalStates}
|
||||
targetVersion={createTargetVersion()}
|
||||
isDowngrade={false}
|
||||
isAutoUpgradeEnabled={false}
|
||||
onUpdatedFromMarketplace={mockOnUpdatedFromMarketplace}
|
||||
onDelete={mockOnDelete}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('plugin-info-repo')).toHaveTextContent('')
|
||||
expect(screen.getByTestId('plugin-info-package')).toHaveTextContent('')
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,107 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { PluginDetail } from '../../../types'
|
||||
import type { ModalStates, VersionTarget } from '../hooks'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info'
|
||||
import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place'
|
||||
import { useGetLanguage } from '@/context/i18n'
|
||||
import { PluginSource } from '../../../types'
|
||||
|
||||
const i18nPrefix = 'action'
|
||||
|
||||
type HeaderModalsProps = {
|
||||
detail: PluginDetail
|
||||
modalStates: ModalStates
|
||||
targetVersion: VersionTarget
|
||||
isDowngrade: boolean
|
||||
isAutoUpgradeEnabled: boolean
|
||||
onUpdatedFromMarketplace: () => void
|
||||
onDelete: () => void
|
||||
}
|
||||
|
||||
const HeaderModals: FC<HeaderModalsProps> = ({
|
||||
detail,
|
||||
modalStates,
|
||||
targetVersion,
|
||||
isDowngrade,
|
||||
isAutoUpgradeEnabled,
|
||||
onUpdatedFromMarketplace,
|
||||
onDelete,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const locale = useGetLanguage()
|
||||
|
||||
const { source, version, meta } = detail
|
||||
const { label } = detail.declaration || detail
|
||||
const isFromGitHub = source === PluginSource.github
|
||||
|
||||
const {
|
||||
isShowUpdateModal,
|
||||
hideUpdateModal,
|
||||
isShowPluginInfo,
|
||||
hidePluginInfo,
|
||||
isShowDeleteConfirm,
|
||||
hideDeleteConfirm,
|
||||
deleting,
|
||||
} = modalStates
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* Plugin Info Modal */}
|
||||
{isShowPluginInfo && (
|
||||
<PluginInfo
|
||||
repository={isFromGitHub ? meta?.repo : ''}
|
||||
release={version}
|
||||
packageName={meta?.package || ''}
|
||||
onHide={hidePluginInfo}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Delete Confirm Modal */}
|
||||
{isShowDeleteConfirm && (
|
||||
<Confirm
|
||||
isShow
|
||||
title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
|
||||
content={(
|
||||
<div>
|
||||
{t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })}
|
||||
<span className="system-md-semibold">{label[locale]}</span>
|
||||
{t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })}
|
||||
<br />
|
||||
</div>
|
||||
)}
|
||||
onCancel={hideDeleteConfirm}
|
||||
onConfirm={onDelete}
|
||||
isLoading={deleting}
|
||||
isDisabled={deleting}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Update from Marketplace Modal */}
|
||||
{isShowUpdateModal && (
|
||||
<UpdateFromMarketplace
|
||||
pluginId={detail.plugin_id}
|
||||
payload={{
|
||||
category: detail.declaration?.category ?? '',
|
||||
originalPackageInfo: {
|
||||
id: detail.plugin_unique_identifier,
|
||||
payload: detail.declaration ?? undefined,
|
||||
},
|
||||
targetPackageInfo: {
|
||||
id: targetVersion.unique_identifier || '',
|
||||
version: targetVersion.version || '',
|
||||
},
|
||||
}}
|
||||
onCancel={hideUpdateModal}
|
||||
onSave={onUpdatedFromMarketplace}
|
||||
isShowDowngradeWarningModal={isDowngrade && isAutoUpgradeEnabled}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default HeaderModals
|
||||
@@ -0,0 +1,2 @@
|
||||
export { default as HeaderModals } from './header-modals'
|
||||
export { default as PluginSourceBadge } from './plugin-source-badge'
|
||||
@@ -0,0 +1,200 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { PluginSource } from '../../../types'
|
||||
import PluginSourceBadge from './plugin-source-badge'
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/tooltip', () => ({
|
||||
default: ({ children, popupContent }: { children: React.ReactNode, popupContent: string }) => (
|
||||
<div data-testid="tooltip" data-content={popupContent}>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
describe('PluginSourceBadge', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('Source Icon Rendering', () => {
|
||||
it('should render marketplace source badge', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.marketplace} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
expect(tooltip).toBeInTheDocument()
|
||||
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.marketplace')
|
||||
})
|
||||
|
||||
it('should render github source badge', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.github} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
expect(tooltip).toBeInTheDocument()
|
||||
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.github')
|
||||
})
|
||||
|
||||
it('should render local source badge', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.local} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
expect(tooltip).toBeInTheDocument()
|
||||
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.local')
|
||||
})
|
||||
|
||||
it('should render debugging source badge', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.debugging} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
expect(tooltip).toBeInTheDocument()
|
||||
expect(tooltip).toHaveAttribute('data-content', 'detailPanel.categoryTip.debugging')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Separator Rendering', () => {
|
||||
it('should render separator dot before marketplace badge', () => {
|
||||
const { container } = render(<PluginSourceBadge source={PluginSource.marketplace} />)
|
||||
|
||||
const separator = container.querySelector('.text-text-quaternary')
|
||||
expect(separator).toBeInTheDocument()
|
||||
expect(separator?.textContent).toBe('·')
|
||||
})
|
||||
|
||||
it('should render separator dot before github badge', () => {
|
||||
const { container } = render(<PluginSourceBadge source={PluginSource.github} />)
|
||||
|
||||
const separator = container.querySelector('.text-text-quaternary')
|
||||
expect(separator).toBeInTheDocument()
|
||||
expect(separator?.textContent).toBe('·')
|
||||
})
|
||||
|
||||
it('should render separator dot before local badge', () => {
|
||||
const { container } = render(<PluginSourceBadge source={PluginSource.local} />)
|
||||
|
||||
const separator = container.querySelector('.text-text-quaternary')
|
||||
expect(separator).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render separator dot before debugging badge', () => {
|
||||
const { container } = render(<PluginSourceBadge source={PluginSource.debugging} />)
|
||||
|
||||
const separator = container.querySelector('.text-text-quaternary')
|
||||
expect(separator).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Tooltip Content', () => {
|
||||
it('should show marketplace tooltip', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.marketplace} />)
|
||||
|
||||
expect(screen.getByTestId('tooltip')).toHaveAttribute(
|
||||
'data-content',
|
||||
'detailPanel.categoryTip.marketplace',
|
||||
)
|
||||
})
|
||||
|
||||
it('should show github tooltip', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.github} />)
|
||||
|
||||
expect(screen.getByTestId('tooltip')).toHaveAttribute(
|
||||
'data-content',
|
||||
'detailPanel.categoryTip.github',
|
||||
)
|
||||
})
|
||||
|
||||
it('should show local tooltip', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.local} />)
|
||||
|
||||
expect(screen.getByTestId('tooltip')).toHaveAttribute(
|
||||
'data-content',
|
||||
'detailPanel.categoryTip.local',
|
||||
)
|
||||
})
|
||||
|
||||
it('should show debugging tooltip', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.debugging} />)
|
||||
|
||||
expect(screen.getByTestId('tooltip')).toHaveAttribute(
|
||||
'data-content',
|
||||
'detailPanel.categoryTip.debugging',
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Icon Element Structure', () => {
|
||||
it('should render icon inside tooltip for marketplace', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.marketplace} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
const iconWrapper = tooltip.querySelector('div')
|
||||
expect(iconWrapper).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render icon inside tooltip for github', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.github} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
const iconWrapper = tooltip.querySelector('div')
|
||||
expect(iconWrapper).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render icon inside tooltip for local', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.local} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
const iconWrapper = tooltip.querySelector('div')
|
||||
expect(iconWrapper).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render icon inside tooltip for debugging', () => {
|
||||
render(<PluginSourceBadge source={PluginSource.debugging} />)
|
||||
|
||||
const tooltip = screen.getByTestId('tooltip')
|
||||
const iconWrapper = tooltip.querySelector('div')
|
||||
expect(iconWrapper).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Lookup Table Coverage', () => {
|
||||
it('should handle all PluginSource enum values', () => {
|
||||
const allSources = Object.values(PluginSource)
|
||||
|
||||
allSources.forEach((source) => {
|
||||
const { container } = render(<PluginSourceBadge source={source} />)
|
||||
// Should render either tooltip or nothing
|
||||
expect(container).toBeTruthy()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Invalid Source Handling', () => {
|
||||
it('should return null for unknown source type', () => {
|
||||
// Use type assertion to test invalid source value
|
||||
const invalidSource = 'unknown_source' as PluginSource
|
||||
const { container } = render(<PluginSourceBadge source={invalidSource} />)
|
||||
|
||||
// Should render nothing (empty container)
|
||||
expect(container.firstChild).toBeNull()
|
||||
})
|
||||
|
||||
it('should not render separator for invalid source', () => {
|
||||
const invalidSource = 'invalid' as PluginSource
|
||||
const { container } = render(<PluginSourceBadge source={invalidSource} />)
|
||||
|
||||
const separator = container.querySelector('.text-text-quaternary')
|
||||
expect(separator).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render tooltip for invalid source', () => {
|
||||
const invalidSource = '' as PluginSource
|
||||
render(<PluginSourceBadge source={invalidSource} />)
|
||||
|
||||
expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,59 @@
|
||||
'use client'
|
||||
|
||||
import type { FC, ReactNode } from 'react'
|
||||
import {
|
||||
RiBugLine,
|
||||
RiHardDrive3Line,
|
||||
} from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Github } from '@/app/components/base/icons/src/public/common'
|
||||
import { BoxSparkleFill } from '@/app/components/base/icons/src/vender/plugin'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { PluginSource } from '../../../types'
|
||||
|
||||
type SourceConfig = {
|
||||
icon: ReactNode
|
||||
tipKey: string
|
||||
}
|
||||
|
||||
type PluginSourceBadgeProps = {
|
||||
source: PluginSource
|
||||
}
|
||||
|
||||
const SOURCE_CONFIG_MAP: Record<PluginSource, SourceConfig | null> = {
|
||||
[PluginSource.marketplace]: {
|
||||
icon: <BoxSparkleFill className="h-3.5 w-3.5 text-text-tertiary hover:text-text-accent" />,
|
||||
tipKey: 'detailPanel.categoryTip.marketplace',
|
||||
},
|
||||
[PluginSource.github]: {
|
||||
icon: <Github className="h-3.5 w-3.5 text-text-secondary hover:text-text-primary" />,
|
||||
tipKey: 'detailPanel.categoryTip.github',
|
||||
},
|
||||
[PluginSource.local]: {
|
||||
icon: <RiHardDrive3Line className="h-3.5 w-3.5 text-text-tertiary" />,
|
||||
tipKey: 'detailPanel.categoryTip.local',
|
||||
},
|
||||
[PluginSource.debugging]: {
|
||||
icon: <RiBugLine className="h-3.5 w-3.5 text-text-tertiary hover:text-text-warning" />,
|
||||
tipKey: 'detailPanel.categoryTip.debugging',
|
||||
},
|
||||
}
|
||||
|
||||
const PluginSourceBadge: FC<PluginSourceBadgeProps> = ({ source }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const config = SOURCE_CONFIG_MAP[source]
|
||||
if (!config)
|
||||
return null
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="system-xs-regular ml-1 mr-0.5 text-text-quaternary">·</div>
|
||||
<Tooltip popupContent={t(config.tipKey as never, { ns: 'plugin' })}>
|
||||
<div>{config.icon}</div>
|
||||
</Tooltip>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default PluginSourceBadge
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user