mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 09:39:25 +08:00
Merge branch 'main' into jzh
This commit is contained in:
@@ -8,7 +8,7 @@ from hashlib import sha256
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@@ -144,22 +144,26 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
def load_user(user_id: str) -> None | Account:
|
||||
account = db.session.query(Account).filter_by(id=user_id).first()
|
||||
account = db.session.get(Account, user_id)
|
||||
if not account:
|
||||
return None
|
||||
|
||||
if account.status == AccountStatus.BANNED:
|
||||
raise Unauthorized("Account is banned.")
|
||||
|
||||
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
|
||||
current_tenant = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
|
||||
.limit(1)
|
||||
)
|
||||
if current_tenant:
|
||||
account.set_tenant_id(current_tenant.tenant_id)
|
||||
else:
|
||||
available_ta = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter_by(account_id=account.id)
|
||||
available_ta = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id)
|
||||
.order_by(TenantAccountJoin.id.asc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not available_ta:
|
||||
return None
|
||||
@@ -195,7 +199,7 @@ class AccountService:
|
||||
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
|
||||
"""authenticate account with email and password"""
|
||||
|
||||
account = db.session.query(Account).filter_by(email=email).first()
|
||||
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
if not account:
|
||||
raise AccountPasswordError("Invalid email or password.")
|
||||
|
||||
@@ -371,8 +375,10 @@ class AccountService:
|
||||
"""Link account integrate"""
|
||||
try:
|
||||
# Query whether there is an existing binding record for the same provider
|
||||
account_integrate: AccountIntegrate | None = (
|
||||
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
|
||||
account_integrate: AccountIntegrate | None = db.session.scalar(
|
||||
select(AccountIntegrate)
|
||||
.where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if account_integrate:
|
||||
@@ -416,7 +422,9 @@ class AccountService:
|
||||
def update_account_email(account: Account, email: str) -> Account:
|
||||
"""Update account email"""
|
||||
account.email = email
|
||||
account_integrate = db.session.query(AccountIntegrate).filter_by(account_id=account.id).first()
|
||||
account_integrate = db.session.scalar(
|
||||
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
|
||||
)
|
||||
if account_integrate:
|
||||
db.session.delete(account_integrate)
|
||||
db.session.add(account)
|
||||
@@ -818,7 +826,7 @@ class AccountService:
|
||||
)
|
||||
)
|
||||
|
||||
account = db.session.query(Account).where(Account.email == email).first()
|
||||
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
if not account:
|
||||
return None
|
||||
|
||||
@@ -1018,7 +1026,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
def check_email_unique(email: str) -> bool:
|
||||
return db.session.query(Account).filter_by(email=email).first() is None
|
||||
return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None
|
||||
|
||||
|
||||
class TenantService:
|
||||
@@ -1384,10 +1392,10 @@ class RegisterService:
|
||||
db.session.add(dify_setup)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.query(DifySetup).delete()
|
||||
db.session.query(TenantAccountJoin).delete()
|
||||
db.session.query(Account).delete()
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.execute(delete(DifySetup))
|
||||
db.session.execute(delete(TenantAccountJoin))
|
||||
db.session.execute(delete(Account))
|
||||
db.session.execute(delete(Tenant))
|
||||
db.session.commit()
|
||||
|
||||
logger.exception("Setup account failed, email: %s, name: %s", email, name)
|
||||
@@ -1488,7 +1496,11 @@ class RegisterService:
|
||||
TenantService.switch_tenant(account, tenant.id)
|
||||
else:
|
||||
TenantService.check_member_permission(tenant, inviter, account, "add")
|
||||
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
ta = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not ta:
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
@@ -1545,21 +1557,18 @@ class RegisterService:
|
||||
if not invitation_data:
|
||||
return None
|
||||
|
||||
tenant = (
|
||||
db.session.query(Tenant)
|
||||
.where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
|
||||
.first()
|
||||
tenant = db.session.scalar(
|
||||
select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1)
|
||||
)
|
||||
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
tenant_account = (
|
||||
db.session.query(Account, TenantAccountJoin.role)
|
||||
tenant_account = db.session.execute(
|
||||
select(Account, TenantAccountJoin.role)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
|
||||
.first()
|
||||
)
|
||||
).first()
|
||||
|
||||
if not tenant_account:
|
||||
return None
|
||||
|
||||
@@ -4,6 +4,8 @@ import uuid
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from typing import TypedDict
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import NotFound
|
||||
@@ -23,6 +25,27 @@ from tasks.annotation.enable_annotation_reply_task import enable_annotation_repl
|
||||
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
|
||||
|
||||
|
||||
class AnnotationJobStatusDict(TypedDict):
|
||||
job_id: str
|
||||
job_status: str
|
||||
|
||||
|
||||
class EmbeddingModelDict(TypedDict):
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class AnnotationSettingDict(TypedDict):
|
||||
id: str
|
||||
enabled: bool
|
||||
score_threshold: float
|
||||
embedding_model: EmbeddingModelDict | dict
|
||||
|
||||
|
||||
class AnnotationSettingDisabledDict(TypedDict):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class AppAnnotationService:
|
||||
@classmethod
|
||||
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
@@ -85,7 +108,7 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def enable_app_annotation(cls, args: dict, app_id: str):
|
||||
def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
|
||||
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(enable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
@@ -109,7 +132,7 @@ class AppAnnotationService:
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def disable_app_annotation(cls, app_id: str):
|
||||
def disable_app_annotation(cls, app_id: str) -> AnnotationJobStatusDict:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(disable_app_annotation_key)
|
||||
@@ -567,7 +590,7 @@ class AppAnnotationService:
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_app_annotation_setting_by_app_id(cls, app_id: str):
|
||||
def get_app_annotation_setting_by_app_id(cls, app_id: str) -> AnnotationSettingDict | AnnotationSettingDisabledDict:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# get app info
|
||||
app = (
|
||||
@@ -602,7 +625,9 @@ class AppAnnotationService:
|
||||
return {"enabled": False}
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
|
||||
def update_app_annotation_setting(
|
||||
cls, app_id: str, annotation_setting_id: str, args: dict
|
||||
) -> AnnotationSettingDict:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# get app info
|
||||
app = (
|
||||
|
||||
@@ -5,7 +5,7 @@ from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from graphon.nodes.http_request.exc import InvalidHttpMethodError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.helper import ssrf_proxy
|
||||
@@ -103,8 +103,10 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_external_knowledge_api(external_knowledge_api_id: str, tenant_id: str) -> ExternalKnowledgeApis:
|
||||
external_knowledge_api: ExternalKnowledgeApis | None = (
|
||||
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
|
||||
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if external_knowledge_api is None:
|
||||
raise ValueError("api template not found")
|
||||
@@ -112,8 +114,10 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
|
||||
external_knowledge_api: ExternalKnowledgeApis | None = (
|
||||
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
|
||||
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if external_knowledge_api is None:
|
||||
raise ValueError("api template not found")
|
||||
@@ -132,8 +136,10 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
|
||||
external_knowledge_api = (
|
||||
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
|
||||
external_knowledge_api = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if external_knowledge_api is None:
|
||||
raise ValueError("api template not found")
|
||||
@@ -144,9 +150,12 @@ class ExternalDatasetService:
|
||||
@staticmethod
|
||||
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
|
||||
count = (
|
||||
db.session.query(ExternalKnowledgeBindings)
|
||||
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
|
||||
.count()
|
||||
db.session.scalar(
|
||||
select(func.count(ExternalKnowledgeBindings.id)).where(
|
||||
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
if count > 0:
|
||||
return True, count
|
||||
@@ -154,8 +163,10 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
|
||||
external_knowledge_binding: ExternalKnowledgeBindings | None = (
|
||||
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
|
||||
external_knowledge_binding: ExternalKnowledgeBindings | None = db.session.scalar(
|
||||
select(ExternalKnowledgeBindings)
|
||||
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not external_knowledge_binding:
|
||||
raise ValueError("external knowledge binding not found")
|
||||
@@ -163,8 +174,10 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
|
||||
external_knowledge_api = (
|
||||
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
|
||||
external_knowledge_api = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if external_knowledge_api is None or external_knowledge_api.settings is None:
|
||||
raise ValueError("api template not found")
|
||||
@@ -238,12 +251,17 @@ class ExternalDatasetService:
|
||||
@staticmethod
|
||||
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
|
||||
# check if dataset name already exists
|
||||
if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
|
||||
if db.session.scalar(
|
||||
select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1)
|
||||
):
|
||||
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
|
||||
external_knowledge_api = (
|
||||
db.session.query(ExternalKnowledgeApis)
|
||||
.filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
|
||||
.first()
|
||||
external_knowledge_api = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(
|
||||
ExternalKnowledgeApis.id == args.get("external_knowledge_api_id"),
|
||||
ExternalKnowledgeApis.tenant_id == tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if external_knowledge_api is None:
|
||||
@@ -286,16 +304,18 @@ class ExternalDatasetService:
|
||||
external_retrieval_parameters: dict,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
):
|
||||
external_knowledge_binding = (
|
||||
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
|
||||
external_knowledge_binding = db.session.scalar(
|
||||
select(ExternalKnowledgeBindings)
|
||||
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not external_knowledge_binding:
|
||||
raise ValueError("external knowledge binding not found")
|
||||
|
||||
external_knowledge_api = (
|
||||
db.session.query(ExternalKnowledgeApis)
|
||||
.filter_by(id=external_knowledge_binding.external_knowledge_api_id)
|
||||
.first()
|
||||
external_knowledge_api = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
|
||||
.limit(1)
|
||||
)
|
||||
if external_knowledge_api is None or external_knowledge_api.settings is None:
|
||||
raise ValueError("external api template not found")
|
||||
|
||||
@@ -156,27 +156,27 @@ class RagPipelineService:
|
||||
:param template_id: template id
|
||||
:param template_info: template info
|
||||
"""
|
||||
customized_template: PipelineCustomizedTemplate | None = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.id == template_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not customized_template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
# check template name is exist
|
||||
template_name = template_info.name
|
||||
if template_name:
|
||||
template = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
template = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.name == template_name,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
PipelineCustomizedTemplate.id != template_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if template:
|
||||
raise ValueError("Template name is already exists")
|
||||
@@ -192,13 +192,13 @@ class RagPipelineService:
|
||||
"""
|
||||
Delete customized pipeline template.
|
||||
"""
|
||||
customized_template: PipelineCustomizedTemplate | None = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.id == template_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not customized_template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
@@ -210,14 +210,14 @@ class RagPipelineService:
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by rag pipeline
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == "draft",
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# return draft workflow
|
||||
@@ -232,28 +232,28 @@ class RagPipelineService:
|
||||
return None
|
||||
|
||||
# fetch published workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.id == pipeline.workflow_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
|
||||
"""Fetch a published workflow snapshot by ID for restore operations."""
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.id == workflow_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if workflow and workflow.version == Workflow.VERSION_DRAFT:
|
||||
raise IsDraftWorkflowError("source workflow must be published")
|
||||
@@ -974,7 +974,7 @@ class RagPipelineService:
|
||||
if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
|
||||
document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID)
|
||||
if document_id:
|
||||
document = db.session.query(Document).where(Document.id == document_id.value).first()
|
||||
document = db.session.get(Document, document_id.value)
|
||||
if document:
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = error
|
||||
@@ -1178,12 +1178,12 @@ class RagPipelineService:
|
||||
"""
|
||||
Publish customized pipeline template
|
||||
"""
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
|
||||
pipeline = db.session.get(Pipeline, pipeline_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
if not pipeline.workflow_id:
|
||||
raise ValueError("Pipeline workflow not found")
|
||||
workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
|
||||
workflow = db.session.get(Workflow, pipeline.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
with Session(db.engine) as session:
|
||||
@@ -1194,21 +1194,21 @@ class RagPipelineService:
|
||||
# check template name is exist
|
||||
template_name = args.get("name")
|
||||
if template_name:
|
||||
template = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
template = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.name == template_name,
|
||||
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if template:
|
||||
raise ValueError("Template name is already exists")
|
||||
|
||||
max_position = (
|
||||
db.session.query(func.max(PipelineCustomizedTemplate.position))
|
||||
.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
|
||||
.scalar()
|
||||
max_position = db.session.scalar(
|
||||
select(func.max(PipelineCustomizedTemplate.position)).where(
|
||||
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id
|
||||
)
|
||||
)
|
||||
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
@@ -1239,13 +1239,14 @@ class RagPipelineService:
|
||||
|
||||
def is_workflow_exist(self, pipeline: Pipeline) -> bool:
|
||||
return (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
db.session.scalar(
|
||||
select(func.count(Workflow.id)).where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
) > 0
|
||||
|
||||
def get_node_last_run(
|
||||
@@ -1353,11 +1354,11 @@ class RagPipelineService:
|
||||
|
||||
def get_recommended_plugins(self, type: str) -> dict:
|
||||
# Query active recommended plugins
|
||||
query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
|
||||
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
|
||||
if type and type != "all":
|
||||
query = query.where(PipelineRecommendedPlugin.type == type)
|
||||
stmt = stmt.where(PipelineRecommendedPlugin.type == type)
|
||||
|
||||
pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all()
|
||||
pipeline_recommended_plugins = db.session.scalars(stmt.order_by(PipelineRecommendedPlugin.position.asc())).all()
|
||||
|
||||
if not pipeline_recommended_plugins:
|
||||
return {
|
||||
@@ -1396,14 +1397,12 @@ class RagPipelineService:
|
||||
"""
|
||||
Retry error document
|
||||
"""
|
||||
document_pipeline_execution_log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.where(DocumentPipelineExecutionLog.document_id == document.id)
|
||||
.first()
|
||||
document_pipeline_execution_log = db.session.scalar(
|
||||
select(DocumentPipelineExecutionLog).where(DocumentPipelineExecutionLog.document_id == document.id).limit(1)
|
||||
)
|
||||
if not document_pipeline_execution_log:
|
||||
raise ValueError("Document pipeline execution log not found")
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
|
||||
pipeline = db.session.get(Pipeline, document_pipeline_execution_log.pipeline_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
# convert to app config
|
||||
@@ -1432,23 +1431,23 @@ class RagPipelineService:
|
||||
"""
|
||||
Get datasource plugins
|
||||
"""
|
||||
dataset: Dataset | None = (
|
||||
db.session.query(Dataset)
|
||||
dataset: Dataset | None = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
pipeline: Pipeline | None = (
|
||||
db.session.query(Pipeline)
|
||||
pipeline: Pipeline | None = db.session.scalar(
|
||||
select(Pipeline)
|
||||
.where(
|
||||
Pipeline.id == dataset.pipeline_id,
|
||||
Pipeline.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
@@ -1530,23 +1529,23 @@ class RagPipelineService:
|
||||
"""
|
||||
Get pipeline
|
||||
"""
|
||||
dataset: Dataset | None = (
|
||||
db.session.query(Dataset)
|
||||
dataset: Dataset | None = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
pipeline: Pipeline | None = (
|
||||
db.session.query(Pipeline)
|
||||
pipeline: Pipeline | None = db.session.scalar(
|
||||
select(Pipeline)
|
||||
.where(
|
||||
Pipeline.id == dataset.pipeline_id,
|
||||
Pipeline.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
|
||||
@@ -292,7 +292,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
"""
|
||||
|
||||
api = Mock(spec=ExternalKnowledgeApis)
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
|
||||
mock_db_session.scalar.return_value = api
|
||||
|
||||
result = ExternalDatasetService.get_external_knowledge_api("api-id", "tenant-id")
|
||||
assert result is api
|
||||
@@ -302,7 +302,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
When the record is absent, a ``ValueError`` is raised.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
ExternalDatasetService.get_external_knowledge_api("missing-id", "tenant-id")
|
||||
@@ -320,7 +320,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
existing_api = Mock(spec=ExternalKnowledgeApis)
|
||||
existing_api.settings_dict = {"api_key": "stored-key"}
|
||||
existing_api.settings = '{"api_key":"stored-key"}'
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_api
|
||||
mock_db_session.scalar.return_value = existing_api
|
||||
|
||||
args = {
|
||||
"name": "New Name",
|
||||
@@ -340,7 +340,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
Updating a non‑existent API template should raise ``ValueError``.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
ExternalDatasetService.update_external_knowledge_api(
|
||||
@@ -356,7 +356,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
"""
|
||||
|
||||
api = Mock(spec=ExternalKnowledgeApis)
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
|
||||
mock_db_session.scalar.return_value = api
|
||||
|
||||
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")
|
||||
|
||||
@@ -368,7 +368,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
Deletion of a missing template should raise ``ValueError``.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")
|
||||
@@ -394,7 +394,7 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
When there are bindings, ``external_knowledge_api_use_check`` returns True and count.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 3
|
||||
mock_db_session.scalar.return_value = 3
|
||||
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
|
||||
|
||||
@@ -406,7 +406,7 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
Zero bindings should return ``(False, 0)``.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 0
|
||||
mock_db_session.scalar.return_value = 0
|
||||
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
|
||||
|
||||
@@ -419,7 +419,7 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
"""
|
||||
|
||||
binding = Mock(spec=ExternalKnowledgeBindings)
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = binding
|
||||
mock_db_session.scalar.return_value = binding
|
||||
|
||||
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
|
||||
assert result is binding
|
||||
@@ -429,7 +429,7 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
Missing binding should result in a ``ValueError``.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="external knowledge binding not found"):
|
||||
ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
|
||||
@@ -460,7 +460,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
|
||||
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
|
||||
)
|
||||
# Raw string; the service itself calls json.loads on it
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
|
||||
mock_db_session.scalar.return_value = external_api
|
||||
|
||||
process_parameter = {"foo": "value", "bar": "optional"}
|
||||
|
||||
@@ -474,7 +474,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
|
||||
When the referenced API template is missing, a ``ValueError`` is raised.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})
|
||||
@@ -488,7 +488,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
|
||||
external_api.settings = (
|
||||
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
|
||||
)
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
|
||||
mock_db_session.scalar.return_value = external_api
|
||||
|
||||
process_parameter = {"bar": "present"} # missing "foo"
|
||||
|
||||
@@ -702,7 +702,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
|
||||
}
|
||||
|
||||
# No existing dataset with same name.
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_db_session.scalar.side_effect = [
|
||||
None, # duplicate‑name check
|
||||
Mock(spec=ExternalKnowledgeApis), # external knowledge api
|
||||
]
|
||||
@@ -724,7 +724,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
|
||||
"""
|
||||
|
||||
existing_dataset = Mock(spec=Dataset)
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_dataset
|
||||
mock_db_session.scalar.return_value = existing_dataset
|
||||
|
||||
args = {
|
||||
"name": "Existing",
|
||||
@@ -744,7 +744,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
|
||||
"""
|
||||
|
||||
# First call: duplicate name check – not found.
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_db_session.scalar.side_effect = [
|
||||
None,
|
||||
None, # external knowledge api lookup
|
||||
]
|
||||
@@ -763,8 +763,10 @@ class TestExternalDatasetServiceCreateExternalDataset:
|
||||
``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
|
||||
"""
|
||||
|
||||
# duplicate name check
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
# duplicate name check — two calls to create_external_dataset, each does 2 scalar calls
|
||||
mock_db_session.scalar.side_effect = [
|
||||
None,
|
||||
Mock(spec=ExternalKnowledgeApis),
|
||||
None,
|
||||
Mock(spec=ExternalKnowledgeApis),
|
||||
]
|
||||
@@ -826,7 +828,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
|
||||
|
||||
# First query: binding; second query: api.
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_db_session.scalar.side_effect = [
|
||||
binding,
|
||||
api,
|
||||
]
|
||||
@@ -861,7 +863,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
Missing binding should raise ``ValueError``.
|
||||
"""
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="external knowledge binding not found"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
@@ -878,7 +880,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
"""
|
||||
|
||||
binding = ExternalDatasetTestDataFactory.create_external_binding()
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_db_session.scalar.side_effect = [
|
||||
binding,
|
||||
None,
|
||||
]
|
||||
@@ -901,7 +903,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
api = Mock(spec=ExternalKnowledgeApis)
|
||||
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
|
||||
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_db_session.scalar.side_effect = [
|
||||
binding,
|
||||
api,
|
||||
]
|
||||
|
||||
@@ -117,9 +117,7 @@ def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_serv
|
||||
|
||||
|
||||
def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service) -> None:
|
||||
first_query = mocker.Mock()
|
||||
first_query.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=first_query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Dataset not found"):
|
||||
rag_pipeline_service.get_pipeline("tenant-1", "dataset-1")
|
||||
@@ -131,12 +129,8 @@ def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service
|
||||
def test_update_customized_pipeline_template_success(mocker) -> None:
|
||||
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
|
||||
|
||||
# First query finds the template, second query (duplicate check) returns None
|
||||
query_mock_1 = mocker.Mock()
|
||||
query_mock_1.where.return_value.first.return_value = template
|
||||
query_mock_2 = mocker.Mock()
|
||||
query_mock_2.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", side_effect=[query_mock_1, query_mock_2])
|
||||
# First scalar finds the template, second scalar (duplicate check) returns None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, None])
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
|
||||
@@ -152,9 +146,7 @@ def test_update_customized_pipeline_template_success(mocker) -> None:
|
||||
|
||||
|
||||
def test_update_customized_pipeline_template_not_found(mocker) -> None:
|
||||
query_mock = mocker.Mock()
|
||||
query_mock.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
|
||||
info = PipelineTemplateInfoEntity(name="x", description="d", icon_info=IconInfo(icon="i"))
|
||||
@@ -166,9 +158,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
|
||||
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
|
||||
duplicate = SimpleNamespace(name="dup")
|
||||
|
||||
query_mock = mocker.Mock()
|
||||
query_mock.where.return_value.first.side_effect = [template, duplicate]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, duplicate])
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
|
||||
info = PipelineTemplateInfoEntity(name="dup", description="d", icon_info=IconInfo(icon="i"))
|
||||
@@ -181,9 +171,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
|
||||
|
||||
def test_delete_customized_pipeline_template_success(mocker) -> None:
|
||||
template = SimpleNamespace(id="tpl-1")
|
||||
query_mock = mocker.Mock()
|
||||
query_mock.where.return_value.first.return_value = template
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
|
||||
delete_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.delete")
|
||||
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
|
||||
@@ -196,9 +184,7 @@ def test_delete_customized_pipeline_template_success(mocker) -> None:
|
||||
|
||||
|
||||
def test_delete_customized_pipeline_template_not_found(mocker) -> None:
|
||||
query_mock = mocker.Mock()
|
||||
query_mock.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
|
||||
with pytest.raises(ValueError, match="Customized pipeline template not found"):
|
||||
@@ -397,18 +383,14 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) -
|
||||
|
||||
|
||||
def test_is_workflow_exist_returns_true_when_draft_exists(mocker, rag_pipeline_service) -> None:
|
||||
query_mock = mocker.Mock()
|
||||
query_mock.where.return_value.count.return_value = 1
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=1)
|
||||
|
||||
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
|
||||
assert rag_pipeline_service.is_workflow_exist(pipeline) is True
|
||||
|
||||
|
||||
def test_is_workflow_exist_returns_false_when_no_draft(mocker, rag_pipeline_service) -> None:
|
||||
query_mock = mocker.Mock()
|
||||
query_mock.where.return_value.count.return_value = 0
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=0)
|
||||
|
||||
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
|
||||
assert rag_pipeline_service.is_workflow_exist(pipeline) is False
|
||||
@@ -738,8 +720,7 @@ def test_get_second_step_parameters_success(mocker, rag_pipeline_service) -> Non
|
||||
|
||||
|
||||
def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_service) -> None:
|
||||
from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate
|
||||
from models.workflow import Workflow
|
||||
from models.dataset import Pipeline
|
||||
|
||||
# 1. Setup mocks
|
||||
pipeline = mocker.Mock(spec=Pipeline)
|
||||
@@ -754,36 +735,15 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi
|
||||
# Mock db itself to avoid app context errors
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
|
||||
# Improved mocking for session.query
|
||||
def mock_query_side_effect(model):
|
||||
m = mocker.Mock()
|
||||
if model == Pipeline:
|
||||
m.where.return_value.first.return_value = pipeline
|
||||
elif model == Workflow:
|
||||
m.where.return_value.first.return_value = workflow
|
||||
elif model == PipelineCustomizedTemplate:
|
||||
m.where.return_value.first.return_value = None
|
||||
elif model == Dataset:
|
||||
m.where.return_value.first.return_value = mocker.Mock()
|
||||
else:
|
||||
# For func.max cases
|
||||
m.where.return_value.scalar.return_value = 5
|
||||
m.where.return_value.first.return_value = mocker.Mock()
|
||||
return m
|
||||
|
||||
mock_db.session.query.side_effect = mock_query_side_effect
|
||||
# Mock get() for Pipeline and Workflow PK lookups
|
||||
mock_db.session.get.side_effect = [pipeline, workflow]
|
||||
# Mock scalar() for template name check (None) and max position (5)
|
||||
mock_db.session.scalar.side_effect = [None, 5]
|
||||
|
||||
# Mock retrieve_dataset
|
||||
dataset = mocker.Mock()
|
||||
pipeline.retrieve_dataset.return_value = dataset
|
||||
|
||||
# Mock max position
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.func.max", return_value=1)
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline.db.session.query.return_value.where.return_value.scalar",
|
||||
return_value=5,
|
||||
)
|
||||
|
||||
# Mock RagPipelineDslService
|
||||
mock_dsl_service = mocker.Mock()
|
||||
mock_dsl_service.export_rag_pipeline_dsl.return_value = {"dsl": "content"}
|
||||
@@ -839,9 +799,7 @@ def test_get_datasource_plugins_success(mocker, rag_pipeline_service) -> None:
|
||||
workflow.rag_pipeline_variables = []
|
||||
|
||||
# Mock queries
|
||||
mock_query = mocker.Mock()
|
||||
mock_query.where.return_value.first.side_effect = [dataset, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
|
||||
|
||||
@@ -881,11 +839,9 @@ def test_retry_error_document_success(mocker, rag_pipeline_service) -> None:
|
||||
|
||||
workflow = mocker.Mock()
|
||||
|
||||
# Mock queries
|
||||
mock_query = mocker.Mock()
|
||||
# Log lookup, then Pipeline lookup
|
||||
mock_query.where.return_value.first.side_effect = [log, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
|
||||
# Mock queries: Log lookup via scalar, Pipeline lookup via get
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=log)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
|
||||
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
|
||||
|
||||
@@ -913,7 +869,7 @@ def test_set_datasource_variables_success(mocker, rag_pipeline_service) -> None:
|
||||
# Mock db aggressively
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.engine = mocker.Mock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mocker.Mock()
|
||||
mock_db.session.scalar.return_value = mocker.Mock()
|
||||
|
||||
pipeline = mocker.Mock(spec=Pipeline)
|
||||
pipeline.id = "p-1"
|
||||
@@ -976,7 +932,7 @@ def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None:
|
||||
workflow = mocker.Mock(spec=Workflow)
|
||||
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = workflow
|
||||
mock_db.session.scalar.return_value = workflow
|
||||
|
||||
# 2. Run test
|
||||
result = rag_pipeline_service.get_draft_workflow(pipeline)
|
||||
@@ -998,7 +954,7 @@ def test_get_published_workflow_success(mocker, rag_pipeline_service) -> None:
|
||||
workflow = mocker.Mock(spec=Workflow)
|
||||
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = workflow
|
||||
mock_db.session.scalar.return_value = workflow
|
||||
|
||||
# 2. Run test
|
||||
result = rag_pipeline_service.get_published_workflow(pipeline)
|
||||
@@ -1319,11 +1275,8 @@ def test_get_rag_pipeline_workflow_run_node_executions_returns_sorted_executions
|
||||
|
||||
|
||||
def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, rag_pipeline_service) -> None:
|
||||
query = mocker.Mock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value.all.return_value = []
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.query.return_value = query
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
result = rag_pipeline_service.get_recommended_plugins("all")
|
||||
|
||||
@@ -1336,11 +1289,8 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra
|
||||
def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_pipeline_service) -> None:
|
||||
plugin_a = SimpleNamespace(plugin_id="plugin-a")
|
||||
plugin_b = SimpleNamespace(plugin_id="plugin-b")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value.all.return_value = [plugin_a, plugin_b]
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.query.return_value = query
|
||||
mock_db.session.scalars.return_value.all.return_value = [plugin_a, plugin_b]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools",
|
||||
@@ -1568,9 +1518,7 @@ def test_get_second_step_parameters_filters_first_step_variables(mocker, rag_pip
|
||||
|
||||
|
||||
def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pipeline_service) -> None:
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Document pipeline execution log not found"):
|
||||
rag_pipeline_service.retry_error_document(
|
||||
@@ -1581,9 +1529,7 @@ def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pi
|
||||
def test_get_datasource_plugins_raises_when_workflow_not_found(mocker, rag_pipeline_service) -> None:
|
||||
dataset = SimpleNamespace(pipeline_id="p1")
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [dataset, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Pipeline or workflow not found"):
|
||||
@@ -1656,8 +1602,7 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker
|
||||
|
||||
document = SimpleNamespace(indexing_status="waiting", error=None)
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.return_value = document
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document)
|
||||
add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add")
|
||||
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
|
||||
@@ -1712,9 +1657,7 @@ def test_run_datasource_node_preview_raises_for_unsupported_provider(mocker, rag
|
||||
|
||||
|
||||
def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker, rag_pipeline_service) -> None:
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Pipeline not found"):
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
|
||||
@@ -1722,9 +1665,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker
|
||||
|
||||
def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(mocker, rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id=None)
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.return_value = pipeline
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
|
||||
|
||||
with pytest.raises(ValueError, match="Pipeline workflow not found"):
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", {"name": "template-name"})
|
||||
@@ -1732,8 +1673,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc
|
||||
|
||||
def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Dataset not found"):
|
||||
rag_pipeline_service.get_pipeline("t1", "d1")
|
||||
@@ -1742,8 +1682,7 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service)
|
||||
def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
|
||||
dataset = SimpleNamespace(pipeline_id="p1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [dataset, None]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None])
|
||||
|
||||
with pytest.raises(ValueError, match="Pipeline not found"):
|
||||
rag_pipeline_service.get_pipeline("t1", "d1")
|
||||
@@ -1783,8 +1722,7 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None:
|
||||
def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None:
|
||||
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.return_value = template
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
|
||||
commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
|
||||
@@ -2011,8 +1949,7 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_
|
||||
def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [pipeline, None]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None])
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
|
||||
@@ -2021,11 +1958,9 @@ def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocke
|
||||
def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
|
||||
workflow = SimpleNamespace(id="wf-1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [pipeline, workflow]
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.engine = mocker.Mock()
|
||||
mock_db.session.query.return_value = query
|
||||
mock_db.session.get.side_effect = [pipeline, workflow]
|
||||
session_ctx = mocker.MagicMock()
|
||||
session_ctx.__enter__.return_value = SimpleNamespace()
|
||||
session_ctx.__exit__.return_value = False
|
||||
@@ -2038,11 +1973,8 @@ def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker
|
||||
|
||||
def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipeline_service) -> None:
|
||||
plugin = SimpleNamespace(plugin_id="plugin-a")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value.all.return_value = [plugin]
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.query.return_value = query
|
||||
mock_db.session.scalars.return_value.all.return_value = [plugin]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", return_value=[])
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids", return_value=[])
|
||||
@@ -2056,8 +1988,8 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin
|
||||
def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
|
||||
exec_log = SimpleNamespace(pipeline_id="p1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [exec_log, None]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Pipeline not found"):
|
||||
rag_pipeline_service.retry_error_document(
|
||||
@@ -2069,8 +2001,8 @@ def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_
|
||||
exec_log = SimpleNamespace(pipeline_id="p1")
|
||||
pipeline = SimpleNamespace(id="p1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [exec_log, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
@@ -2086,8 +2018,7 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r
|
||||
graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[]
|
||||
)
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [dataset, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
|
||||
|
||||
assert rag_pipeline_service.get_datasource_plugins("t1", "d1", True) == []
|
||||
@@ -2250,8 +2181,7 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(
|
||||
rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}],
|
||||
)
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [dataset, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow)
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials", return_value=[]
|
||||
@@ -2291,8 +2221,7 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
|
||||
],
|
||||
)
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [dataset, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials",
|
||||
@@ -2310,8 +2239,7 @@ def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service)
|
||||
dataset = SimpleNamespace(pipeline_id="p1")
|
||||
pipeline = SimpleNamespace(id="p1")
|
||||
query = mocker.Mock()
|
||||
query.where.return_value.first.side_effect = [dataset, pipeline]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
|
||||
result = rag_pipeline_service.get_pipeline("t1", "d1")
|
||||
|
||||
|
||||
@@ -173,9 +173,7 @@ class TestAccountService:
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
|
||||
# Setup smart database query mock
|
||||
query_results = {("Account", "email", "test@example.com"): mock_account}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_account
|
||||
|
||||
mock_password_dependencies["compare_password"].return_value = True
|
||||
|
||||
@@ -188,9 +186,7 @@ class TestAccountService:
|
||||
|
||||
def test_authenticate_account_not_found(self, mock_db_dependencies):
|
||||
"""Test authentication when account does not exist."""
|
||||
# Setup smart database query mock - no matching results
|
||||
query_results = {("Account", "email", "notfound@example.com"): None}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.scalar.return_value = None
|
||||
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(
|
||||
@@ -202,9 +198,7 @@ class TestAccountService:
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
|
||||
|
||||
# Setup smart database query mock
|
||||
query_results = {("Account", "email", "banned@example.com"): mock_account}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_account
|
||||
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password")
|
||||
@@ -214,9 +208,7 @@ class TestAccountService:
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
|
||||
# Setup smart database query mock
|
||||
query_results = {("Account", "email", "test@example.com"): mock_account}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_account
|
||||
|
||||
mock_password_dependencies["compare_password"].return_value = False
|
||||
|
||||
@@ -230,9 +222,7 @@ class TestAccountService:
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending")
|
||||
|
||||
# Setup smart database query mock
|
||||
query_results = {("Account", "email", "pending@example.com"): mock_account}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_account
|
||||
|
||||
mock_password_dependencies["compare_password"].return_value = True
|
||||
|
||||
@@ -422,12 +412,8 @@ class TestAccountService:
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock()
|
||||
|
||||
# Setup smart database query mock
|
||||
query_results = {
|
||||
("Account", "id", "user-123"): mock_account,
|
||||
("TenantAccountJoin", "account_id", "user-123"): mock_tenant_join,
|
||||
}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.get.return_value = mock_account
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant_join
|
||||
|
||||
# Mock datetime
|
||||
with patch("services.account_service.datetime") as mock_datetime:
|
||||
@@ -444,9 +430,7 @@ class TestAccountService:
|
||||
|
||||
def test_load_user_not_found(self, mock_db_dependencies):
|
||||
"""Test user loading when user does not exist."""
|
||||
# Setup smart database query mock - no matching results
|
||||
query_results = {("Account", "id", "non-existent-user"): None}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.get.return_value = None
|
||||
|
||||
# Execute test
|
||||
result = AccountService.load_user("non-existent-user")
|
||||
@@ -459,9 +443,7 @@ class TestAccountService:
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
|
||||
|
||||
# Setup smart database query mock
|
||||
query_results = {("Account", "id", "user-123"): mock_account}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.get.return_value = mock_account
|
||||
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(
|
||||
@@ -476,13 +458,9 @@ class TestAccountService:
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False)
|
||||
|
||||
# Setup smart database query mock for complex scenario
|
||||
query_results = {
|
||||
("Account", "id", "user-123"): mock_account,
|
||||
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
|
||||
("TenantAccountJoin", "order_by", "first_available"): mock_available_tenant, # First available tenant
|
||||
}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.get.return_value = mock_account
|
||||
# First scalar: current tenant (None), second scalar: available tenant
|
||||
mock_db_dependencies["db"].session.scalar.side_effect = [None, mock_available_tenant]
|
||||
|
||||
# Mock datetime
|
||||
with patch("services.account_service.datetime") as mock_datetime:
|
||||
@@ -503,13 +481,9 @@ class TestAccountService:
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
|
||||
# Setup smart database query mock for no tenants scenario
|
||||
query_results = {
|
||||
("Account", "id", "user-123"): mock_account,
|
||||
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
|
||||
("TenantAccountJoin", "order_by", "first_available"): None, # No available tenants
|
||||
}
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
mock_db_dependencies["db"].session.get.return_value = mock_account
|
||||
# First scalar: current tenant (None), second scalar: available tenant (None)
|
||||
mock_db_dependencies["db"].session.scalar.side_effect = [None, None]
|
||||
|
||||
# Mock datetime
|
||||
with patch("services.account_service.datetime") as mock_datetime:
|
||||
@@ -1060,7 +1034,7 @@ class TestRegisterService:
|
||||
)
|
||||
|
||||
# Verify rollback operations were called
|
||||
mock_db_dependencies["db"].session.query.assert_called()
|
||||
mock_db_dependencies["db"].session.execute.assert_called()
|
||||
|
||||
# ==================== Registration Tests ====================
|
||||
|
||||
@@ -1625,10 +1599,8 @@ class TestRegisterService:
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = mock_existing_account
|
||||
|
||||
# Mock the db.session.query for TenantAccountJoin
|
||||
mock_db_query = MagicMock()
|
||||
mock_db_query.filter_by.return_value.first.return_value = None # No existing member
|
||||
mock_db_dependencies["db"].session.query.return_value = mock_db_query
|
||||
# Mock scalar for TenantAccountJoin lookup - no existing member
|
||||
mock_db_dependencies["db"].session.scalar.return_value = None
|
||||
|
||||
# Mock TenantService methods
|
||||
with (
|
||||
@@ -1803,14 +1775,9 @@ class TestRegisterService:
|
||||
}
|
||||
mock_get_invitation_by_token.return_value = invitation_data
|
||||
|
||||
# Mock database queries - complex query mocking
|
||||
mock_query1 = MagicMock()
|
||||
mock_query1.where.return_value.first.return_value = mock_tenant
|
||||
|
||||
mock_query2 = MagicMock()
|
||||
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
|
||||
# Mock scalar for tenant lookup, execute for account+role lookup
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
|
||||
mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
@@ -1842,10 +1809,8 @@ class TestRegisterService:
|
||||
}
|
||||
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
|
||||
|
||||
# Mock database queries - no tenant found
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = None
|
||||
mock_db_dependencies["db"].session.query.return_value = mock_query
|
||||
# Mock scalar for tenant lookup - not found
|
||||
mock_db_dependencies["db"].session.scalar.return_value = None
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
@@ -1868,14 +1833,9 @@ class TestRegisterService:
|
||||
}
|
||||
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
|
||||
|
||||
# Mock database queries
|
||||
mock_query1 = MagicMock()
|
||||
mock_query1.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
mock_query2 = MagicMock()
|
||||
mock_query2.join.return_value.where.return_value.first.return_value = None # No account found
|
||||
|
||||
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
|
||||
# Mock scalar for tenant, execute for account+role
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
|
||||
mock_db_dependencies["db"].session.execute.return_value.first.return_value = None # No account found
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
@@ -1901,14 +1861,9 @@ class TestRegisterService:
|
||||
}
|
||||
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
|
||||
|
||||
# Mock database queries
|
||||
mock_query1 = MagicMock()
|
||||
mock_query1.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
mock_query2 = MagicMock()
|
||||
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
|
||||
# Mock scalar for tenant, execute for account+role
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
|
||||
mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
|
||||
@@ -799,10 +799,7 @@ class TestExternalDatasetServiceGetAPI:
|
||||
api_id = "api-123"
|
||||
expected_api = factory.create_external_knowledge_api_mock(api_id=api_id)
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = expected_api
|
||||
mock_db.session.scalar.return_value = expected_api
|
||||
|
||||
# Act
|
||||
tenant_id = "tenant-123"
|
||||
@@ -810,16 +807,12 @@ class TestExternalDatasetServiceGetAPI:
|
||||
|
||||
# Assert
|
||||
assert result.id == api_id
|
||||
mock_query.filter_by.assert_called_once_with(id=api_id, tenant_id=tenant_id)
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_get_external_knowledge_api_not_found(self, mock_db, factory):
|
||||
"""Test error when API is not found."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
@@ -848,10 +841,7 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
"settings": {"endpoint": "https://new.example.com", "api_key": "new-key"},
|
||||
}
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_api
|
||||
mock_db.session.scalar.return_value = existing_api
|
||||
|
||||
# Act
|
||||
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args)
|
||||
@@ -881,10 +871,7 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
"settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
|
||||
}
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_api
|
||||
mock_db.session.scalar.return_value = existing_api
|
||||
|
||||
# Act
|
||||
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, "user-123", api_id, args)
|
||||
@@ -897,10 +884,7 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
def test_update_external_knowledge_api_not_found(self, mock_db, factory):
|
||||
"""Test error when API is not found."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
args = {"name": "Updated API"}
|
||||
|
||||
@@ -912,10 +896,7 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
def test_update_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
|
||||
"""Test error when tenant ID doesn't match."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
args = {"name": "Updated API"}
|
||||
|
||||
@@ -934,10 +915,7 @@ class TestExternalDatasetServiceUpdateAPI:
|
||||
|
||||
args = {"name": "New Name Only"}
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_api
|
||||
mock_db.session.scalar.return_value = existing_api
|
||||
|
||||
# Act
|
||||
result = ExternalDatasetService.update_external_knowledge_api("tenant-123", "user-123", "api-123", args)
|
||||
@@ -958,10 +936,7 @@ class TestExternalDatasetServiceDeleteAPI:
|
||||
|
||||
existing_api = factory.create_external_knowledge_api_mock(api_id=api_id, tenant_id=tenant_id)
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_api
|
||||
mock_db.session.scalar.return_value = existing_api
|
||||
|
||||
# Act
|
||||
ExternalDatasetService.delete_external_knowledge_api(tenant_id, api_id)
|
||||
@@ -974,10 +949,7 @@ class TestExternalDatasetServiceDeleteAPI:
|
||||
def test_delete_external_knowledge_api_not_found(self, mock_db, factory):
|
||||
"""Test error when API is not found."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
@@ -987,10 +959,7 @@ class TestExternalDatasetServiceDeleteAPI:
|
||||
def test_delete_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
|
||||
"""Test error when tenant ID doesn't match."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
@@ -1006,10 +975,7 @@ class TestExternalDatasetServiceAPIUseCheck:
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.count.return_value = 1
|
||||
mock_db.session.scalar.return_value = 1
|
||||
|
||||
# Act
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
||||
@@ -1024,10 +990,7 @@ class TestExternalDatasetServiceAPIUseCheck:
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.count.return_value = 10
|
||||
mock_db.session.scalar.return_value = 10
|
||||
|
||||
# Act
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
||||
@@ -1042,10 +1005,7 @@ class TestExternalDatasetServiceAPIUseCheck:
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.count.return_value = 0
|
||||
mock_db.session.scalar.return_value = 0
|
||||
|
||||
# Act
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
||||
@@ -1067,10 +1027,7 @@ class TestExternalDatasetServiceGetBinding:
|
||||
|
||||
expected_binding = factory.create_external_knowledge_binding_mock(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = expected_binding
|
||||
mock_db.session.scalar.return_value = expected_binding
|
||||
|
||||
# Act
|
||||
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id(tenant_id, dataset_id)
|
||||
@@ -1083,10 +1040,7 @@ class TestExternalDatasetServiceGetBinding:
|
||||
def test_get_external_knowledge_binding_not_found(self, mock_db, factory):
|
||||
"""Test error when binding is not found."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="external knowledge binding not found"):
|
||||
@@ -1113,10 +1067,7 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
|
||||
api = factory.create_external_knowledge_api_mock(api_id=api_id, settings=[settings])
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = api
|
||||
mock_db.session.scalar.return_value = api
|
||||
|
||||
process_parameter = {"param1": "value1", "param2": "value2"}
|
||||
|
||||
@@ -1134,10 +1085,7 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
|
||||
api = factory.create_external_knowledge_api_mock(api_id=api_id, settings=[settings])
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = api
|
||||
mock_db.session.scalar.return_value = api
|
||||
|
||||
process_parameter = {}
|
||||
|
||||
@@ -1149,10 +1097,7 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
def test_document_create_args_validate_api_not_found(self, mock_db, factory):
|
||||
"""Test validation fails when API is not found."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
@@ -1165,10 +1110,7 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
settings = {}
|
||||
api = factory.create_external_knowledge_api_mock(settings=[settings])
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = api
|
||||
mock_db.session.scalar.return_value = api
|
||||
|
||||
# Act & Assert - should not raise
|
||||
ExternalDatasetService.document_create_args_validate("tenant-123", "api-123", {})
|
||||
@@ -1186,10 +1128,7 @@ class TestExternalDatasetServiceDocumentValidate:
|
||||
|
||||
api = factory.create_external_knowledge_api_mock(settings=[settings])
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = api
|
||||
mock_db.session.scalar.return_value = api
|
||||
|
||||
process_parameter = {"required_param": "value"}
|
||||
|
||||
@@ -1498,24 +1437,7 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
|
||||
api = factory.create_external_knowledge_api_mock(api_id="api-123")
|
||||
|
||||
# Mock database queries
|
||||
mock_dataset_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == Dataset:
|
||||
return mock_dataset_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_dataset_query.filter_by.return_value = mock_dataset_query
|
||||
mock_dataset_query.first.return_value = None
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [None, api]
|
||||
|
||||
# Act
|
||||
result = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)
|
||||
@@ -1534,10 +1456,7 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
# Arrange
|
||||
existing_dataset = factory.create_dataset_mock(name="Duplicate Dataset")
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_dataset
|
||||
mock_db.session.scalar.return_value = existing_dataset
|
||||
|
||||
args = {"name": "Duplicate Dataset"}
|
||||
|
||||
@@ -1549,23 +1468,7 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
def test_create_external_dataset_api_not_found_error(self, mock_db, factory):
|
||||
"""Test error when external knowledge API is not found."""
|
||||
# Arrange
|
||||
mock_dataset_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == Dataset:
|
||||
return mock_dataset_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_dataset_query.filter_by.return_value = mock_dataset_query
|
||||
mock_dataset_query.first.return_value = None
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = None
|
||||
mock_db.session.scalar.side_effect = [None, None]
|
||||
|
||||
args = {"name": "Test Dataset", "external_knowledge_api_id": "nonexistent-api"}
|
||||
|
||||
@@ -1579,23 +1482,7 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
# Arrange
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_dataset_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == Dataset:
|
||||
return mock_dataset_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_dataset_query.filter_by.return_value = mock_dataset_query
|
||||
mock_dataset_query.first.return_value = None
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [None, api]
|
||||
|
||||
args = {"name": "Test Dataset", "external_knowledge_api_id": "api-123"}
|
||||
|
||||
@@ -1609,23 +1496,7 @@ class TestExternalDatasetServiceCreateDataset:
|
||||
# Arrange
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_dataset_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == Dataset:
|
||||
return mock_dataset_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_dataset_query.filter_by.return_value = mock_dataset_query
|
||||
mock_dataset_query.first.return_value = None
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [None, api]
|
||||
|
||||
args = {"name": "Test Dataset", "external_knowledge_id": "knowledge-123"}
|
||||
|
||||
@@ -1651,23 +1522,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
)
|
||||
api = factory.create_external_knowledge_api_mock(api_id="api-123")
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
@@ -1695,10 +1550,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
def test_fetch_external_knowledge_retrieval_binding_not_found_error(self, mock_db, factory):
|
||||
"""Test error when external knowledge binding is not found."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="external knowledge binding not found"):
|
||||
@@ -1712,23 +1564,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
@@ -1751,23 +1587,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
@@ -1799,23 +1619,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
@@ -1856,23 +1660,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
)
|
||||
api = factory.create_external_knowledge_api_mock(api_id="api-123")
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
@@ -1891,23 +1679,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_binding_query = MagicMock()
|
||||
mock_api_query = MagicMock()
|
||||
|
||||
def query_side_effect(model):
|
||||
if model == ExternalKnowledgeBindings:
|
||||
return mock_binding_query
|
||||
elif model == ExternalKnowledgeApis:
|
||||
return mock_api_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||
mock_binding_query.first.return_value = binding
|
||||
|
||||
mock_api_query.filter_by.return_value = mock_api_query
|
||||
mock_api_query.first.return_value = api
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 503
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
"prepare": "vp config"
|
||||
},
|
||||
"devDependencies": {
|
||||
"taze": "catalog:",
|
||||
"vite-plus": "catalog:"
|
||||
},
|
||||
"engines": {
|
||||
|
||||
3079
pnpm-lock.yaml
generated
3079
pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,9 @@
|
||||
catalogMode: prefer
|
||||
trustPolicy: no-downgrade
|
||||
minimumReleaseAge: 2880
|
||||
trustPolicyExclude:
|
||||
- chokidar@4.0.3
|
||||
- reselect@5.1.1
|
||||
- semver@6.3.1
|
||||
blockExoticSubdeps: true
|
||||
strictDepBuilds: true
|
||||
allowBuilds:
|
||||
@@ -23,7 +27,7 @@ overrides:
|
||||
array.prototype.flatmap: npm:@nolyfill/array.prototype.flatmap@^1.0.44
|
||||
array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1.0.44
|
||||
assert: npm:@nolyfill/assert@^1.0.26
|
||||
brace-expansion@<2.0.2: 2.0.2
|
||||
brace-expansion@>=2.0.0 <2.0.3: 2.0.3
|
||||
canvas: ^3.2.2
|
||||
devalue@<5.3.2: 5.3.2
|
||||
dompurify@>=3.1.3 <=3.3.1: 3.3.2
|
||||
@@ -37,6 +41,8 @@ overrides:
|
||||
is-generator-function: npm:@nolyfill/is-generator-function@^1.0.44
|
||||
is-typed-array: npm:@nolyfill/is-typed-array@^1.0.44
|
||||
isarray: npm:@nolyfill/isarray@^1.0.44
|
||||
lodash@>=4.0.0 <= 4.17.23: 4.18.0
|
||||
lodash-es@>=4.0.0 <= 4.17.23: 4.18.0
|
||||
object.assign: npm:@nolyfill/object.assign@^1.0.44
|
||||
object.entries: npm:@nolyfill/object.entries@^1.0.44
|
||||
object.fromentries: npm:@nolyfill/object.fromentries@^1.0.44
|
||||
@@ -64,15 +70,15 @@ overrides:
|
||||
tar@<=7.5.10: 7.5.11
|
||||
typed-array-buffer: npm:@nolyfill/typed-array-buffer@^1.0.44
|
||||
undici@>=7.0.0 <7.24.0: 7.24.0
|
||||
vite: npm:@voidzero-dev/vite-plus-core@0.1.14
|
||||
vitest: npm:@voidzero-dev/vite-plus-test@0.1.14
|
||||
vite: npm:@voidzero-dev/vite-plus-core@0.1.15
|
||||
vitest: npm:@voidzero-dev/vite-plus-test@0.1.15
|
||||
which-typed-array: npm:@nolyfill/which-typed-array@^1.0.44
|
||||
yaml@>=2.0.0 <2.8.3: 2.8.3
|
||||
yauzl@<3.2.1: 3.2.1
|
||||
catalog:
|
||||
"@amplitude/analytics-browser": 2.38.0
|
||||
"@amplitude/plugin-session-replay-browser": 1.27.5
|
||||
"@antfu/eslint-config": 7.7.3
|
||||
"@amplitude/analytics-browser": 2.38.1
|
||||
"@amplitude/plugin-session-replay-browser": 1.27.6
|
||||
"@antfu/eslint-config": 8.0.0
|
||||
"@base-ui/react": 1.3.0
|
||||
"@chromatic-com/storybook": 5.1.1
|
||||
"@cucumber/cucumber": 12.7.0
|
||||
@@ -84,7 +90,7 @@ catalog:
|
||||
"@formatjs/intl-localematcher": 0.8.2
|
||||
"@headlessui/react": 2.2.9
|
||||
"@heroicons/react": 2.2.0
|
||||
"@hono/node-server": 1.19.11
|
||||
"@hono/node-server": 1.19.12
|
||||
"@iconify-json/heroicons": 1.2.3
|
||||
"@iconify-json/ri": 1.2.10
|
||||
"@lexical/code": 0.42.0
|
||||
@@ -98,34 +104,35 @@ catalog:
|
||||
"@mdx-js/react": 3.1.1
|
||||
"@mdx-js/rollup": 3.1.1
|
||||
"@monaco-editor/react": 4.7.0
|
||||
"@next/eslint-plugin-next": 16.2.1
|
||||
"@next/mdx": 16.2.1
|
||||
"@next/eslint-plugin-next": 16.2.2
|
||||
"@next/mdx": 16.2.2
|
||||
"@orpc/client": 1.13.13
|
||||
"@orpc/contract": 1.13.13
|
||||
"@orpc/openapi-client": 1.13.13
|
||||
"@orpc/tanstack-query": 1.13.13
|
||||
"@playwright/test": 1.58.2
|
||||
"@playwright/test": 1.59.1
|
||||
"@remixicon/react": 4.9.0
|
||||
"@rgrove/parse-xml": 4.2.0
|
||||
"@sentry/react": 10.46.0
|
||||
"@storybook/addon-docs": 10.3.3
|
||||
"@storybook/addon-links": 10.3.3
|
||||
"@storybook/addon-onboarding": 10.3.3
|
||||
"@storybook/addon-themes": 10.3.3
|
||||
"@storybook/nextjs-vite": 10.3.3
|
||||
"@storybook/react": 10.3.3
|
||||
"@sentry/react": 10.47.0
|
||||
"@storybook/addon-docs": 10.3.4
|
||||
"@storybook/addon-links": 10.3.4
|
||||
"@storybook/addon-onboarding": 10.3.4
|
||||
"@storybook/addon-themes": 10.3.4
|
||||
"@storybook/nextjs-vite": 10.3.4
|
||||
"@storybook/react": 10.3.4
|
||||
"@streamdown/math": 1.0.2
|
||||
"@svgdotjs/svg.js": 3.2.5
|
||||
"@t3-oss/env-nextjs": 0.13.11
|
||||
"@tailwindcss/postcss": 4.2.2
|
||||
"@tailwindcss/typography": 0.5.19
|
||||
"@tailwindcss/vite": 4.2.2
|
||||
"@tanstack/eslint-plugin-query": 5.95.2
|
||||
"@tanstack/react-devtools": 0.10.0
|
||||
"@tanstack/react-form": 1.28.5
|
||||
"@tanstack/react-form-devtools": 0.2.19
|
||||
"@tanstack/react-query": 5.95.2
|
||||
"@tanstack/react-query-devtools": 5.95.2
|
||||
"@tanstack/eslint-plugin-query": 5.96.1
|
||||
"@tanstack/react-devtools": 0.10.1
|
||||
"@tanstack/react-form": 1.28.6
|
||||
"@tanstack/react-form-devtools": 0.2.20
|
||||
"@tanstack/react-query": 5.96.1
|
||||
"@tanstack/react-query-devtools": 5.96.1
|
||||
"@tanstack/react-virtual": 3.13.23
|
||||
"@testing-library/dom": 10.4.1
|
||||
"@testing-library/jest-dom": 6.9.1
|
||||
"@testing-library/react": 16.3.2
|
||||
@@ -141,15 +148,13 @@ catalog:
|
||||
"@types/qs": 6.15.0
|
||||
"@types/react": 19.2.14
|
||||
"@types/react-dom": 19.2.3
|
||||
"@types/react-syntax-highlighter": 15.5.13
|
||||
"@types/react-window": 1.8.8
|
||||
"@types/sortablejs": 1.15.9
|
||||
"@typescript-eslint/eslint-plugin": 8.57.2
|
||||
"@typescript-eslint/parser": 8.57.2
|
||||
"@typescript/native-preview": 7.0.0-dev.20260329.1
|
||||
"@typescript-eslint/eslint-plugin": 8.58.0
|
||||
"@typescript-eslint/parser": 8.58.0
|
||||
"@typescript/native-preview": 7.0.0-dev.20260401.1
|
||||
"@vitejs/plugin-react": 6.0.1
|
||||
"@vitejs/plugin-rsc": 0.5.21
|
||||
"@vitest/coverage-v8": 4.1.1
|
||||
"@vitest/coverage-v8": 4.1.2
|
||||
abcjs: 6.6.2
|
||||
agentation: 3.0.2
|
||||
ahooks: 3.9.7
|
||||
@@ -157,7 +162,7 @@ catalog:
|
||||
class-variance-authority: 0.7.1
|
||||
clsx: 2.1.1
|
||||
cmdk: 1.1.1
|
||||
code-inspector-plugin: 1.4.5
|
||||
code-inspector-plugin: 1.5.1
|
||||
copy-to-clipboard: 3.3.3
|
||||
cron-parser: 5.5.0
|
||||
dayjs: 1.11.20
|
||||
@@ -174,19 +179,19 @@ catalog:
|
||||
eslint-markdown: 0.6.0
|
||||
eslint-plugin-better-tailwindcss: 4.3.2
|
||||
eslint-plugin-hyoban: 0.14.1
|
||||
eslint-plugin-markdown-preferences: 0.40.3
|
||||
eslint-plugin-markdown-preferences: 0.41.0
|
||||
eslint-plugin-no-barrel-files: 1.2.2
|
||||
eslint-plugin-react-hooks: 7.0.1
|
||||
eslint-plugin-react-refresh: 0.5.2
|
||||
eslint-plugin-sonarjs: 4.0.2
|
||||
eslint-plugin-storybook: 10.3.3
|
||||
eslint-plugin-storybook: 10.3.4
|
||||
fast-deep-equal: 3.1.3
|
||||
foxact: 0.3.0
|
||||
happy-dom: 20.8.9
|
||||
hono: 4.12.9
|
||||
hast-util-to-jsx-runtime: 2.3.6
|
||||
hono: 4.12.10
|
||||
html-entities: 2.6.0
|
||||
html-to-image: 1.11.13
|
||||
i18next: 25.10.10
|
||||
i18next: 26.0.3
|
||||
i18next-resources-to-backend: 1.2.1
|
||||
iconify-import-svg: 0.1.2
|
||||
immer: 11.1.4
|
||||
@@ -196,15 +201,15 @@ catalog:
|
||||
js-yaml: 4.1.1
|
||||
jsonschema: 1.5.0
|
||||
katex: 0.16.44
|
||||
knip: 6.1.0
|
||||
knip: 6.2.0
|
||||
ky: 1.14.3
|
||||
lamejs: 1.2.1
|
||||
lexical: 0.42.0
|
||||
mermaid: 11.13.0
|
||||
mermaid: 11.14.0
|
||||
mime: 4.1.0
|
||||
mitt: 3.0.1
|
||||
negotiator: 1.0.0
|
||||
next: 16.2.1
|
||||
next: 16.2.2
|
||||
next-themes: 0.4.6
|
||||
nuqs: 2.8.9
|
||||
pinyin-pro: 3.28.0
|
||||
@@ -217,42 +222,39 @@ catalog:
|
||||
react-dom: 19.2.4
|
||||
react-easy-crop: 5.5.7
|
||||
react-hotkeys-hook: 5.2.4
|
||||
react-i18next: 16.6.6
|
||||
react-i18next: 17.0.2
|
||||
react-multi-email: 1.0.25
|
||||
react-papaparse: 4.4.0
|
||||
react-pdf-highlighter: 8.0.0-rc.0
|
||||
react-server-dom-webpack: 19.2.4
|
||||
react-sortablejs: 6.1.4
|
||||
react-syntax-highlighter: 15.6.6
|
||||
react-textarea-autosize: 8.5.9
|
||||
react-window: 1.8.11
|
||||
reactflow: 11.11.4
|
||||
remark-breaks: 4.0.0
|
||||
remark-directive: 4.0.0
|
||||
sass: 1.98.0
|
||||
scheduler: 0.27.0
|
||||
sharp: 0.34.5
|
||||
shiki: 4.0.2
|
||||
sortablejs: 1.15.7
|
||||
std-semver: 1.0.8
|
||||
storybook: 10.3.3
|
||||
storybook: 10.3.4
|
||||
streamdown: 2.5.0
|
||||
string-ts: 2.3.1
|
||||
tailwind-merge: 3.5.0
|
||||
tailwindcss: 4.2.2
|
||||
taze: 19.10.0
|
||||
tldts: 7.0.27
|
||||
tsup: ^8.5.1
|
||||
tsdown: 0.21.7
|
||||
tsx: 4.21.0
|
||||
typescript: 5.9.3
|
||||
typescript: 6.0.2
|
||||
uglify-js: 3.19.3
|
||||
unist-util-visit: 5.1.0
|
||||
use-context-selector: 2.0.0
|
||||
uuid: 13.0.0
|
||||
vinext: 0.0.38
|
||||
vite: npm:@voidzero-dev/vite-plus-core@0.1.14
|
||||
vinext: 0.0.39
|
||||
vite: npm:@voidzero-dev/vite-plus-core@0.1.15
|
||||
vite-plugin-inspect: 12.0.0-beta.1
|
||||
vite-plus: 0.1.14
|
||||
vitest: npm:@voidzero-dev/vite-plus-test@0.1.14
|
||||
vite-plus: 0.1.15
|
||||
vitest: npm:@voidzero-dev/vite-plus-test@0.1.15
|
||||
vitest-canvas-mock: 1.1.4
|
||||
zod: 4.3.6
|
||||
zundo: 2.3.0
|
||||
|
||||
@@ -45,12 +45,12 @@
|
||||
"homepage": "https://dify.ai",
|
||||
"license": "MIT",
|
||||
"scripts": {
|
||||
"build": "tsup",
|
||||
"build": "vp pack",
|
||||
"lint": "eslint",
|
||||
"lint:fix": "eslint --fix",
|
||||
"type-check": "tsc -p tsconfig.json --noEmit",
|
||||
"test": "vitest run",
|
||||
"test:coverage": "vitest run --coverage",
|
||||
"test": "vp test",
|
||||
"test:coverage": "vp test --coverage",
|
||||
"publish:check": "./scripts/publish.sh --dry-run",
|
||||
"publish:npm": "./scripts/publish.sh"
|
||||
},
|
||||
@@ -61,8 +61,8 @@
|
||||
"@typescript-eslint/parser": "catalog:",
|
||||
"@vitest/coverage-v8": "catalog:",
|
||||
"eslint": "catalog:",
|
||||
"tsup": "catalog:",
|
||||
"typescript": "catalog:",
|
||||
"vite-plus": "catalog:",
|
||||
"vitest": "catalog:"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,8 @@
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"skipLibCheck": true
|
||||
"skipLibCheck": true,
|
||||
"types": ["node"]
|
||||
},
|
||||
"include": ["src/**/*.ts", "tests/**/*.ts"]
|
||||
}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
import { defineConfig } from "tsup";
|
||||
|
||||
export default defineConfig({
|
||||
entry: ["src/index.ts"],
|
||||
format: ["esm"],
|
||||
dts: true,
|
||||
clean: true,
|
||||
sourcemap: true,
|
||||
splitting: false,
|
||||
treeshake: true,
|
||||
outDir: "dist",
|
||||
});
|
||||
@@ -1,6 +1,17 @@
|
||||
import { defineConfig } from "vitest/config";
|
||||
import { defineConfig } from "vite-plus";
|
||||
|
||||
export default defineConfig({
|
||||
pack: {
|
||||
entry: ["src/index.ts"],
|
||||
format: ["esm"],
|
||||
dts: true,
|
||||
clean: true,
|
||||
sourcemap: true,
|
||||
// splitting: false,
|
||||
treeshake: true,
|
||||
outDir: "dist",
|
||||
target: false,
|
||||
},
|
||||
test: {
|
||||
environment: "node",
|
||||
include: ["**/*.test.ts"],
|
||||
@@ -1,15 +0,0 @@
|
||||
import { defineConfig } from 'taze'
|
||||
|
||||
export default defineConfig({
|
||||
exclude: [
|
||||
// We are going to replace these
|
||||
'react-syntax-highlighter',
|
||||
'react-window',
|
||||
'@types/react-window',
|
||||
|
||||
// We can not upgrade these yet
|
||||
'typescript',
|
||||
],
|
||||
|
||||
maturityPeriod: 2,
|
||||
})
|
||||
36
web/__mocks__/@tanstack/react-virtual.ts
Normal file
36
web/__mocks__/@tanstack/react-virtual.ts
Normal file
@@ -0,0 +1,36 @@
|
||||
import { vi } from 'vitest'
|
||||
|
||||
const mockVirtualizer = ({
|
||||
count,
|
||||
estimateSize,
|
||||
}: {
|
||||
count: number
|
||||
estimateSize?: (index: number) => number
|
||||
}) => {
|
||||
const getSize = (index: number) => estimateSize?.(index) ?? 0
|
||||
|
||||
return {
|
||||
getTotalSize: () => Array.from({ length: count }).reduce<number>((total, _, index) => total + getSize(index), 0),
|
||||
getVirtualItems: () => {
|
||||
let start = 0
|
||||
|
||||
return Array.from({ length: count }).map((_, index) => {
|
||||
const size = getSize(index)
|
||||
const virtualItem = {
|
||||
end: start + size,
|
||||
index,
|
||||
key: index,
|
||||
size,
|
||||
start,
|
||||
}
|
||||
|
||||
start += size
|
||||
return virtualItem
|
||||
})
|
||||
},
|
||||
measureElement: vi.fn(),
|
||||
scrollToIndex: vi.fn(),
|
||||
}
|
||||
}
|
||||
|
||||
export { mockVirtualizer as useVirtualizer }
|
||||
@@ -1,10 +1,10 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { useContext } from 'react'
|
||||
import { use } from 'react'
|
||||
import { FeaturesContext, FeaturesProvider } from '../context'
|
||||
|
||||
const TestConsumer = () => {
|
||||
const store = useContext(FeaturesContext)
|
||||
const store = use(FeaturesContext)
|
||||
if (!store)
|
||||
return <div>no store</div>
|
||||
|
||||
@@ -34,10 +34,10 @@ describe('FeaturesProvider', () => {
|
||||
})
|
||||
|
||||
it('should maintain the same store reference across re-renders', () => {
|
||||
const storeRefs: Array<ReturnType<typeof useContext>> = []
|
||||
const storeRefs: Array<React.ContextType<typeof FeaturesContext>> = []
|
||||
|
||||
const StoreRefCollector = () => {
|
||||
const store = useContext(FeaturesContext)
|
||||
const store = use(FeaturesContext)
|
||||
storeRefs.push(store)
|
||||
return null
|
||||
}
|
||||
|
||||
@@ -5,6 +5,10 @@ import { Theme } from '@/types/app'
|
||||
|
||||
import CodeBlock from '../code-block'
|
||||
|
||||
const { mockHighlightCode } = vi.hoisted(() => ({
|
||||
mockHighlightCode: vi.fn(),
|
||||
}))
|
||||
|
||||
type UseThemeReturn = {
|
||||
theme: Theme
|
||||
}
|
||||
@@ -70,6 +74,10 @@ vi.mock('@/hooks/use-theme', () => ({
|
||||
default: () => mockUseTheme(),
|
||||
}))
|
||||
|
||||
vi.mock('../shiki-highlight', () => ({
|
||||
highlightCode: mockHighlightCode,
|
||||
}))
|
||||
|
||||
vi.mock('echarts', () => ({
|
||||
getInstanceByDom: mockEcharts.getInstanceByDom,
|
||||
}))
|
||||
@@ -130,6 +138,11 @@ describe('CodeBlock', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseTheme.mockReturnValue({ theme: Theme.light })
|
||||
mockHighlightCode.mockImplementation(async ({ code, language }) => (
|
||||
<pre className="shiki">
|
||||
<code className={`language-${language}`}>{code}</code>
|
||||
</pre>
|
||||
))
|
||||
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
clientWidthSpy = vi.spyOn(HTMLElement.prototype, 'clientWidth', 'get').mockReturnValue(900)
|
||||
@@ -198,11 +211,13 @@ describe('CodeBlock', () => {
|
||||
expect(container.querySelector('code')?.textContent).toBe('plain text')
|
||||
})
|
||||
|
||||
it('should render syntax-highlighted output when language is standard', () => {
|
||||
it('should render syntax-highlighted output when language is standard', async () => {
|
||||
render(<CodeBlock className="language-javascript">const x = 1;</CodeBlock>)
|
||||
|
||||
expect(screen.getByText('JavaScript')).toBeInTheDocument()
|
||||
expect(document.querySelector('code.language-javascript')?.textContent).toContain('const x = 1;')
|
||||
await waitFor(() => {
|
||||
expect(document.querySelector('code.language-javascript')?.textContent).toContain('const x = 1;')
|
||||
})
|
||||
})
|
||||
|
||||
it('should format unknown language labels with capitalized fallback when language is not in map', () => {
|
||||
@@ -242,13 +257,26 @@ describe('CodeBlock', () => {
|
||||
expect(screen.queryByText(/Error rendering SVG/i)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render syntax-highlighted output when language is standard and app theme is dark', () => {
|
||||
it('should render syntax-highlighted output when language is standard and app theme is dark', async () => {
|
||||
mockUseTheme.mockReturnValue({ theme: Theme.dark })
|
||||
|
||||
render(<CodeBlock className="language-javascript">const y = 2;</CodeBlock>)
|
||||
|
||||
expect(screen.getByText('JavaScript')).toBeInTheDocument()
|
||||
expect(document.querySelector('code.language-javascript')?.textContent).toContain('const y = 2;')
|
||||
await waitFor(() => {
|
||||
expect(document.querySelector('code.language-javascript')?.textContent).toContain('const y = 2;')
|
||||
})
|
||||
})
|
||||
|
||||
it('should fall back to plain code block when shiki highlighting fails', async () => {
|
||||
mockHighlightCode.mockRejectedValueOnce(new Error('highlight failed'))
|
||||
|
||||
render(<CodeBlock className="language-javascript">const z = 3;</CodeBlock>)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('const z = 3;')).toBeInTheDocument()
|
||||
})
|
||||
expect(document.querySelector('code.language-javascript')).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import type { JSX } from 'react'
|
||||
import type { BundledLanguage, BundledTheme } from 'shiki/bundle/web'
|
||||
import ReactEcharts from 'echarts-for-react'
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import SyntaxHighlighter from 'react-syntax-highlighter'
|
||||
import {
|
||||
atelierHeathDark,
|
||||
atelierHeathLight,
|
||||
} from 'react-syntax-highlighter/dist/esm/styles/hljs'
|
||||
import { memo, useCallback, useEffect, useLayoutEffect, useMemo, useRef, useState } from 'react'
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
import CopyIcon from '@/app/components/base/copy-icon'
|
||||
import MarkdownMusic from '@/app/components/base/markdown-blocks/music'
|
||||
@@ -14,10 +11,10 @@ import useTheme from '@/hooks/use-theme'
|
||||
import dynamic from '@/next/dynamic'
|
||||
import { Theme } from '@/types/app'
|
||||
import SVGRenderer from '../svg-gallery' // Assumes svg-gallery.tsx is in /base directory
|
||||
import { highlightCode } from './shiki-highlight'
|
||||
|
||||
const Flowchart = dynamic(() => import('@/app/components/base/mermaid'), { ssr: false })
|
||||
|
||||
// Available language https://github.com/react-syntax-highlighter/react-syntax-highlighter/blob/master/AVAILABLE_LANGUAGES_HLJS.MD
|
||||
const capitalizationLanguageNameMap: Record<string, string> = {
|
||||
sql: 'SQL',
|
||||
javascript: 'JavaScript',
|
||||
@@ -64,6 +61,61 @@ const getCorrectCapitalizationLanguageName = (language: string) => {
|
||||
// visit https://reactjs.org/docs/error-decoder.html?invariant=185 for the full message
|
||||
// or use the non-minified dev environment for full errors and additional helpful warnings.
|
||||
|
||||
const ShikiCodeBlock = memo(({ code, language, theme, initial }: { code: string, language: string, theme: BundledTheme, initial?: JSX.Element }) => {
|
||||
const [nodes, setNodes] = useState(initial)
|
||||
|
||||
useLayoutEffect(() => {
|
||||
let cancelled = false
|
||||
|
||||
void highlightCode({
|
||||
code,
|
||||
language: language as BundledLanguage,
|
||||
theme,
|
||||
}).then((result) => {
|
||||
if (!cancelled)
|
||||
setNodes(result)
|
||||
}).catch((error) => {
|
||||
console.error('Shiki highlighting failed:', error)
|
||||
if (!cancelled)
|
||||
setNodes(undefined)
|
||||
})
|
||||
|
||||
return () => {
|
||||
cancelled = true
|
||||
}
|
||||
}, [code, language, theme])
|
||||
|
||||
if (!nodes) {
|
||||
return (
|
||||
<pre style={{
|
||||
paddingLeft: 12,
|
||||
borderBottomLeftRadius: '10px',
|
||||
borderBottomRightRadius: '10px',
|
||||
backgroundColor: 'var(--color-components-input-bg-normal)',
|
||||
margin: 0,
|
||||
overflow: 'auto',
|
||||
}}
|
||||
>
|
||||
<code>{code}</code>
|
||||
</pre>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
style={{
|
||||
borderBottomLeftRadius: '10px',
|
||||
borderBottomRightRadius: '10px',
|
||||
overflow: 'auto',
|
||||
}}
|
||||
className="shiki-line-numbers [&_pre]:m-0! [&_pre]:rounded-t-none! [&_pre]:rounded-b-[10px]! [&_pre]:bg-components-input-bg-normal! [&_pre]:py-2!"
|
||||
>
|
||||
{nodes}
|
||||
</div>
|
||||
)
|
||||
})
|
||||
ShikiCodeBlock.displayName = 'ShikiCodeBlock'
|
||||
|
||||
// Define ECharts event parameter types
|
||||
type EChartsEventParams = {
|
||||
type: string
|
||||
@@ -416,20 +468,11 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any
|
||||
)
|
||||
default:
|
||||
return (
|
||||
<SyntaxHighlighter
|
||||
{...props}
|
||||
style={theme === Theme.light ? atelierHeathLight : atelierHeathDark}
|
||||
customStyle={{
|
||||
paddingLeft: 12,
|
||||
borderBottomLeftRadius: '10px',
|
||||
borderBottomRightRadius: '10px',
|
||||
backgroundColor: 'var(--color-components-input-bg-normal)',
|
||||
}}
|
||||
language={match?.[1]}
|
||||
showLineNumbers
|
||||
>
|
||||
{content}
|
||||
</SyntaxHighlighter>
|
||||
<ShikiCodeBlock
|
||||
code={content}
|
||||
language={match?.[1] || 'text'}
|
||||
theme={isDarkMode ? 'github-dark' : 'github-light'}
|
||||
/>
|
||||
)
|
||||
}
|
||||
}, [children, language, isSVG, finalChartOption, props, theme, match, chartState, isDarkMode, echartsStyle, echartsOpts, handleChartReady, echartsEvents])
|
||||
@@ -440,7 +483,7 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any
|
||||
return (
|
||||
<div className="relative">
|
||||
<div className="flex h-8 items-center justify-between rounded-t-[10px] border-b border-divider-subtle bg-components-input-bg-normal p-1 pl-3">
|
||||
<div className="text-text-secondary system-xs-semibold-uppercase">{languageShowName}</div>
|
||||
<div className="system-xs-semibold-uppercase text-text-secondary">{languageShowName}</div>
|
||||
<div className="flex items-center gap-1">
|
||||
{language === 'svg' && <SVGBtn isSVG={isSVG} setIsSVG={setIsSVG} />}
|
||||
<ActionButton>
|
||||
|
||||
29
web/app/components/base/markdown-blocks/shiki-highlight.tsx
Normal file
29
web/app/components/base/markdown-blocks/shiki-highlight.tsx
Normal file
@@ -0,0 +1,29 @@
|
||||
import type { JSX } from 'react'
|
||||
import type { BundledLanguage, BundledTheme } from 'shiki/bundle/web'
|
||||
import { toJsxRuntime } from 'hast-util-to-jsx-runtime'
|
||||
import { Fragment } from 'react'
|
||||
import { jsx, jsxs } from 'react/jsx-runtime'
|
||||
import { codeToHast } from 'shiki/bundle/web'
|
||||
|
||||
type HighlightCodeOptions = {
|
||||
code: string
|
||||
language: BundledLanguage
|
||||
theme: BundledTheme
|
||||
}
|
||||
|
||||
export const highlightCode = async ({
|
||||
code,
|
||||
language,
|
||||
theme,
|
||||
}: HighlightCodeOptions): Promise<JSX.Element> => {
|
||||
const hast = await codeToHast(code, {
|
||||
lang: language,
|
||||
theme,
|
||||
})
|
||||
|
||||
return toJsxRuntime(hast, {
|
||||
Fragment,
|
||||
jsx,
|
||||
jsxs,
|
||||
}) as JSX.Element
|
||||
}
|
||||
@@ -9,6 +9,8 @@ import { useModalContextSelector } from '@/context/modal-context'
|
||||
import { useInvalidPreImportNotionPages, usePreImportNotionPages } from '@/service/knowledge/use-import'
|
||||
import NotionPageSelector from '../base'
|
||||
|
||||
vi.mock('@tanstack/react-virtual')
|
||||
|
||||
vi.mock('@/service/knowledge/use-import', () => ({
|
||||
usePreImportNotionPages: vi.fn(),
|
||||
useInvalidPreImportNotionPages: vi.fn(),
|
||||
@@ -183,7 +185,7 @@ describe('NotionPageSelector Base', () => {
|
||||
const user = userEvent.setup()
|
||||
render(<NotionPageSelector credentialList={mockCredentialList} onSelect={vi.fn()} />)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'Configure Notion' }))
|
||||
await user.click(screen.getByRole('button', { name: 'common.dataSource.notion.selector.configure' }))
|
||||
expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ payload: ACCOUNT_SETTING_TAB.DATA_SOURCE })
|
||||
})
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import type { DataSourceCredential } from '../../header/account-setting/data-sou
|
||||
import type { NotionCredential } from './credential-selector'
|
||||
import type { DataSourceNotionPageMap, DataSourceNotionWorkspace, NotionPage } from '@/models/common'
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
|
||||
import { useModalContextSelector } from '@/context/modal-context'
|
||||
import { useInvalidPreImportNotionPages, usePreImportNotionPages } from '@/service/knowledge/use-import'
|
||||
@@ -33,6 +34,7 @@ const NotionPageSelector = ({
|
||||
credentialList,
|
||||
onSelectCredential,
|
||||
}: NotionPageSelectorProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [searchValue, setSearchValue] = useState('')
|
||||
const setShowAccountSettingModal = useModalContextSelector(s => s.setShowAccountSettingModal)
|
||||
|
||||
@@ -48,27 +50,34 @@ const NotionPageSelector = ({
|
||||
}
|
||||
})
|
||||
}, [credentialList])
|
||||
const [currentCredential, setCurrentCredential] = useState(notionCredentials[0])
|
||||
const [selectedCredentialId, setSelectedCredentialId] = useState(() => notionCredentials[0]?.credentialId ?? '')
|
||||
const currentCredential = useMemo(() => {
|
||||
return notionCredentials.find(item => item.credentialId === selectedCredentialId) ?? notionCredentials[0] ?? null
|
||||
}, [notionCredentials, selectedCredentialId])
|
||||
const currentCredentialId = currentCredential?.credentialId ?? ''
|
||||
|
||||
useEffect(() => {
|
||||
const credential = notionCredentials.find(item => item.credentialId === currentCredential?.credentialId)
|
||||
if (!credential) {
|
||||
const firstCredential = notionCredentials[0]
|
||||
invalidPreImportNotionPages({ datasetId, credentialId: firstCredential.credentialId })
|
||||
setCurrentCredential(notionCredentials[0])
|
||||
onSelect([]) // Clear selected pages when changing credential
|
||||
onSelectCredential?.(firstCredential.credentialId)
|
||||
onSelectCredential?.(currentCredentialId)
|
||||
}, [currentCredentialId, onSelectCredential])
|
||||
|
||||
useEffect(() => {
|
||||
if (!notionCredentials.length) {
|
||||
onSelect([])
|
||||
return
|
||||
}
|
||||
else {
|
||||
onSelectCredential?.(credential?.credentialId || '')
|
||||
}
|
||||
}, [notionCredentials])
|
||||
|
||||
if (!selectedCredentialId || selectedCredentialId === currentCredentialId)
|
||||
return
|
||||
|
||||
invalidPreImportNotionPages({ datasetId, credentialId: currentCredentialId })
|
||||
onSelect([])
|
||||
}, [currentCredentialId, datasetId, invalidPreImportNotionPages, notionCredentials.length, onSelect, selectedCredentialId])
|
||||
|
||||
const {
|
||||
data: notionsPages,
|
||||
isFetching: isFetchingNotionPages,
|
||||
isError: isFetchingNotionPagesError,
|
||||
} = usePreImportNotionPages({ datasetId, credentialId: currentCredential.credentialId || '' })
|
||||
} = usePreImportNotionPages({ datasetId, credentialId: currentCredentialId })
|
||||
|
||||
const pagesMapAndSelectedPagesId: [DataSourceNotionPageMap, Set<string>, Set<string>] = useMemo(() => {
|
||||
const selectedPagesId = new Set<string>()
|
||||
@@ -94,28 +103,24 @@ const NotionPageSelector = ({
|
||||
const defaultSelectedPagesId = useMemo(() => {
|
||||
return [...Array.from(pagesMapAndSelectedPagesId[1]), ...(value || [])]
|
||||
}, [pagesMapAndSelectedPagesId, value])
|
||||
const [selectedPagesId, setSelectedPagesId] = useState<Set<string>>(() => new Set(defaultSelectedPagesId))
|
||||
|
||||
useEffect(() => {
|
||||
setSelectedPagesId(new Set(defaultSelectedPagesId))
|
||||
}, [defaultSelectedPagesId])
|
||||
const selectedPagesId = useMemo(() => new Set(defaultSelectedPagesId), [defaultSelectedPagesId])
|
||||
|
||||
const handleSearchValueChange = useCallback((value: string) => {
|
||||
setSearchValue(value)
|
||||
}, [])
|
||||
|
||||
const handleSelectCredential = useCallback((credentialId: string) => {
|
||||
const credential = notionCredentials.find(item => item.credentialId === credentialId)!
|
||||
invalidPreImportNotionPages({ datasetId, credentialId: credential.credentialId })
|
||||
setCurrentCredential(credential)
|
||||
if (credentialId === currentCredentialId)
|
||||
return
|
||||
|
||||
invalidPreImportNotionPages({ datasetId, credentialId })
|
||||
setSelectedCredentialId(credentialId)
|
||||
onSelect([]) // Clear selected pages when changing credential
|
||||
onSelectCredential?.(credential.credentialId)
|
||||
}, [datasetId, invalidPreImportNotionPages, notionCredentials, onSelect, onSelectCredential])
|
||||
}, [currentCredentialId, datasetId, invalidPreImportNotionPages, onSelect])
|
||||
|
||||
const handleSelectPages = useCallback((newSelectedPagesId: Set<string>) => {
|
||||
const selectedPages = Array.from(newSelectedPagesId).map(pageId => pagesMapAndSelectedPagesId[0][pageId])
|
||||
|
||||
setSelectedPagesId(new Set(Array.from(newSelectedPagesId)))
|
||||
onSelect(selectedPages)
|
||||
}, [pagesMapAndSelectedPagesId, onSelect])
|
||||
|
||||
@@ -140,16 +145,16 @@ const NotionPageSelector = ({
|
||||
<div className="flex flex-col gap-y-2" data-testid="notion-page-selector-base">
|
||||
<Header
|
||||
onClickConfiguration={handleConfigureNotion}
|
||||
title="Choose notion pages"
|
||||
buttonText="Configure Notion"
|
||||
docTitle="Notion docs"
|
||||
title={t('dataSource.notion.selector.headerTitle', { ns: 'common' })}
|
||||
buttonText={t('dataSource.notion.selector.configure', { ns: 'common' })}
|
||||
docTitle={t('dataSource.notion.selector.docs', { ns: 'common' })}
|
||||
docLink="https://www.notion.so/docs"
|
||||
/>
|
||||
<div className="rounded-xl border border-components-panel-border bg-background-default-subtle">
|
||||
<div className="flex h-12 items-center gap-x-2 rounded-t-xl border-b border-b-divider-regular bg-components-panel-bg p-2">
|
||||
<div className="flex grow items-center gap-x-1">
|
||||
<WorkspaceSelector
|
||||
value={currentCredential.credentialId}
|
||||
value={currentCredentialId}
|
||||
items={notionCredentials}
|
||||
onSelect={handleSelectCredential}
|
||||
/>
|
||||
@@ -168,6 +173,7 @@ const NotionPageSelector = ({
|
||||
)
|
||||
: (
|
||||
<PageSelector
|
||||
key={currentCredentialId || 'default'}
|
||||
value={selectedPagesId}
|
||||
disabledValue={pagesMapAndSelectedPagesId[2]}
|
||||
searchValue={searchValue}
|
||||
|
||||
@@ -3,6 +3,8 @@ import { render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import PageSelector from '../index'
|
||||
|
||||
vi.mock('@tanstack/react-virtual')
|
||||
|
||||
const buildPage = (overrides: Partial<DataSourceNotionPage>): DataSourceNotionPage => ({
|
||||
page_id: 'page-id',
|
||||
page_name: 'Page name',
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import type { ListChildComponentProps } from 'react-window'
|
||||
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
|
||||
import { memo, useEffect, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { areEqual, FixedSizeList as List } from 'react-window'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Checkbox from '../../checkbox'
|
||||
import NotionIcon from '../../notion-icon'
|
||||
import { usePageSelectorModel } from './use-page-selector-model'
|
||||
import VirtualPageList from './virtual-page-list'
|
||||
|
||||
type PageSelectorProps = {
|
||||
value: Set<string>
|
||||
@@ -17,173 +13,7 @@ type PageSelectorProps = {
|
||||
canPreview?: boolean
|
||||
previewPageId?: string
|
||||
onPreview?: (selectedPageId: string) => void
|
||||
isMultipleChoice?: boolean
|
||||
}
|
||||
type NotionPageTreeItem = {
|
||||
children: Set<string>
|
||||
descendants: Set<string>
|
||||
depth: number
|
||||
ancestors: string[]
|
||||
} & DataSourceNotionPage
|
||||
type NotionPageTreeMap = Record<string, NotionPageTreeItem>
|
||||
type NotionPageItem = {
|
||||
expand: boolean
|
||||
depth: number
|
||||
} & DataSourceNotionPage
|
||||
|
||||
const recursivePushInParentDescendants = (
|
||||
pagesMap: DataSourceNotionPageMap,
|
||||
listTreeMap: NotionPageTreeMap,
|
||||
current: NotionPageTreeItem,
|
||||
leafItem: NotionPageTreeItem,
|
||||
) => {
|
||||
const parentId = current.parent_id
|
||||
const pageId = current.page_id
|
||||
|
||||
if (!parentId || !pageId)
|
||||
return
|
||||
|
||||
if (parentId !== 'root' && pagesMap[parentId]) {
|
||||
if (!listTreeMap[parentId]) {
|
||||
const children = new Set([pageId])
|
||||
const descendants = new Set([pageId, leafItem.page_id])
|
||||
listTreeMap[parentId] = {
|
||||
...pagesMap[parentId],
|
||||
children,
|
||||
descendants,
|
||||
depth: 0,
|
||||
ancestors: [],
|
||||
}
|
||||
}
|
||||
else {
|
||||
listTreeMap[parentId].children.add(pageId)
|
||||
listTreeMap[parentId].descendants.add(pageId)
|
||||
listTreeMap[parentId].descendants.add(leafItem.page_id)
|
||||
}
|
||||
leafItem.depth++
|
||||
leafItem.ancestors.unshift(listTreeMap[parentId].page_name)
|
||||
|
||||
if (listTreeMap[parentId].parent_id !== 'root')
|
||||
recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap[parentId], leafItem)
|
||||
}
|
||||
}
|
||||
|
||||
const ItemComponent = ({ index, style, data }: ListChildComponentProps<{
|
||||
dataList: NotionPageItem[]
|
||||
handleToggle: (index: number) => void
|
||||
checkedIds: Set<string>
|
||||
disabledCheckedIds: Set<string>
|
||||
handleCheck: (index: number) => void
|
||||
canPreview?: boolean
|
||||
handlePreview: (index: number) => void
|
||||
listMapWithChildrenAndDescendants: NotionPageTreeMap
|
||||
searchValue: string
|
||||
previewPageId: string
|
||||
pagesMap: DataSourceNotionPageMap
|
||||
}>) => {
|
||||
const { t } = useTranslation()
|
||||
const {
|
||||
dataList,
|
||||
handleToggle,
|
||||
checkedIds,
|
||||
disabledCheckedIds,
|
||||
handleCheck,
|
||||
canPreview,
|
||||
handlePreview,
|
||||
listMapWithChildrenAndDescendants,
|
||||
searchValue,
|
||||
previewPageId,
|
||||
pagesMap,
|
||||
} = data
|
||||
const current = dataList[index]
|
||||
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[current.page_id]
|
||||
const hasChild = currentWithChildrenAndDescendants.descendants.size > 0
|
||||
const ancestors = currentWithChildrenAndDescendants.ancestors
|
||||
const breadCrumbs = ancestors.length ? [...ancestors, current.page_name] : [current.page_name]
|
||||
const disabled = disabledCheckedIds.has(current.page_id)
|
||||
|
||||
const renderArrow = () => {
|
||||
if (hasChild) {
|
||||
return (
|
||||
<div
|
||||
className="mr-1 flex h-5 w-5 shrink-0 items-center justify-center rounded-md hover:bg-components-button-ghost-bg-hover"
|
||||
style={{ marginLeft: current.depth * 8 }}
|
||||
onClick={() => handleToggle(index)}
|
||||
data-testid={`notion-page-toggle-${current.page_id}`}
|
||||
>
|
||||
{
|
||||
current.expand
|
||||
? <div className="i-ri-arrow-down-s-line h-4 w-4 text-text-tertiary" />
|
||||
: <div className="i-ri-arrow-right-s-line h-4 w-4 text-text-tertiary" />
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
if (current.parent_id === 'root' || !pagesMap[current.parent_id]) {
|
||||
return (
|
||||
<div></div>
|
||||
)
|
||||
}
|
||||
return (
|
||||
<div className="mr-1 h-5 w-5 shrink-0" style={{ marginLeft: current.depth * 8 }} />
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn('group flex cursor-pointer items-center rounded-md pl-2 pr-[2px] hover:bg-state-base-hover', previewPageId === current.page_id && 'bg-state-base-hover')}
|
||||
style={{ ...style, top: style.top as number + 8, left: 8, right: 8, width: 'calc(100% - 16px)' }}
|
||||
data-testid={`notion-page-row-${current.page_id}`}
|
||||
>
|
||||
<Checkbox
|
||||
className="mr-2 shrink-0"
|
||||
checked={checkedIds.has(current.page_id)}
|
||||
disabled={disabled}
|
||||
onCheck={() => {
|
||||
handleCheck(index)
|
||||
}}
|
||||
id={`notion-page-checkbox-${current.page_id}`}
|
||||
/>
|
||||
{!searchValue && renderArrow()}
|
||||
<NotionIcon
|
||||
className="mr-1 shrink-0"
|
||||
type="page"
|
||||
src={current.page_icon}
|
||||
/>
|
||||
<div
|
||||
className="grow truncate text-[13px] font-medium leading-4 text-text-secondary"
|
||||
title={current.page_name}
|
||||
data-testid={`notion-page-name-${current.page_id}`}
|
||||
>
|
||||
{current.page_name}
|
||||
</div>
|
||||
{
|
||||
canPreview && (
|
||||
<div
|
||||
className="ml-1 hidden h-6 shrink-0 cursor-pointer items-center rounded-md border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-2 text-xs
|
||||
font-medium leading-4 text-components-button-secondary-text shadow-xs shadow-shadow-shadow-3 backdrop-blur-[10px]
|
||||
hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover group-hover:flex"
|
||||
onClick={() => handlePreview(index)}
|
||||
data-testid={`notion-page-preview-${current.page_id}`}
|
||||
>
|
||||
{t('dataSource.notion.selector.preview', { ns: 'common' })}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
searchValue && (
|
||||
<div
|
||||
className="ml-1 max-w-[120px] shrink-0 truncate text-xs text-text-quaternary"
|
||||
title={breadCrumbs.join(' / ')}
|
||||
>
|
||||
{breadCrumbs.join(' / ')}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
const Item = memo(ItemComponent, areEqual)
|
||||
|
||||
const PageSelector = ({
|
||||
value,
|
||||
@@ -197,108 +27,25 @@ const PageSelector = ({
|
||||
onPreview,
|
||||
}: PageSelectorProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [dataList, setDataList] = useState<NotionPageItem[]>([])
|
||||
const [localPreviewPageId, setLocalPreviewPageId] = useState('')
|
||||
|
||||
useEffect(() => {
|
||||
setDataList(list.filter(item => item.parent_id === 'root' || !pagesMap[item.parent_id]).map((item) => {
|
||||
return {
|
||||
...item,
|
||||
expand: false,
|
||||
depth: 0,
|
||||
}
|
||||
}))
|
||||
}, [list])
|
||||
|
||||
const searchDataList = list.filter((item) => {
|
||||
return item.page_name.includes(searchValue)
|
||||
}).map((item) => {
|
||||
return {
|
||||
...item,
|
||||
expand: false,
|
||||
depth: 0,
|
||||
}
|
||||
const {
|
||||
currentPreviewPageId,
|
||||
effectiveSearchValue,
|
||||
rows,
|
||||
handlePreview,
|
||||
handleSelect,
|
||||
handleToggle,
|
||||
} = usePageSelectorModel({
|
||||
checkedIds: value,
|
||||
list,
|
||||
onPreview,
|
||||
onSelect,
|
||||
pagesMap,
|
||||
previewPageId,
|
||||
searchValue,
|
||||
selectionMode: 'multiple',
|
||||
})
|
||||
const currentDataList = searchValue ? searchDataList : dataList
|
||||
const currentPreviewPageId = previewPageId === undefined ? localPreviewPageId : previewPageId
|
||||
|
||||
const listMapWithChildrenAndDescendants = useMemo(() => {
|
||||
return list.reduce((prev: NotionPageTreeMap, next: DataSourceNotionPage) => {
|
||||
const pageId = next.page_id
|
||||
if (!prev[pageId])
|
||||
prev[pageId] = { ...next, children: new Set(), descendants: new Set(), depth: 0, ancestors: [] }
|
||||
|
||||
recursivePushInParentDescendants(pagesMap, prev, prev[pageId], prev[pageId])
|
||||
return prev
|
||||
}, {})
|
||||
}, [list, pagesMap])
|
||||
|
||||
const handleToggle = (index: number) => {
|
||||
const current = dataList[index]
|
||||
const pageId = current.page_id
|
||||
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
|
||||
const descendantsIds = Array.from(currentWithChildrenAndDescendants.descendants)
|
||||
const childrenIds = Array.from(currentWithChildrenAndDescendants.children)
|
||||
let newDataList = []
|
||||
|
||||
if (current.expand) {
|
||||
current.expand = false
|
||||
|
||||
newDataList = dataList.filter(item => !descendantsIds.includes(item.page_id))
|
||||
}
|
||||
else {
|
||||
current.expand = true
|
||||
|
||||
newDataList = [
|
||||
...dataList.slice(0, index + 1),
|
||||
...childrenIds.map(item => ({
|
||||
...pagesMap[item],
|
||||
expand: false,
|
||||
depth: listMapWithChildrenAndDescendants[item].depth,
|
||||
})),
|
||||
...dataList.slice(index + 1),
|
||||
]
|
||||
}
|
||||
setDataList(newDataList)
|
||||
}
|
||||
|
||||
const copyValue = new Set(value)
|
||||
const handleCheck = (index: number) => {
|
||||
const current = currentDataList[index]
|
||||
const pageId = current.page_id
|
||||
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
|
||||
|
||||
if (copyValue.has(pageId)) {
|
||||
if (!searchValue) {
|
||||
for (const item of currentWithChildrenAndDescendants.descendants)
|
||||
copyValue.delete(item)
|
||||
}
|
||||
|
||||
copyValue.delete(pageId)
|
||||
}
|
||||
else {
|
||||
if (!searchValue) {
|
||||
for (const item of currentWithChildrenAndDescendants.descendants)
|
||||
copyValue.add(item)
|
||||
}
|
||||
|
||||
copyValue.add(pageId)
|
||||
}
|
||||
|
||||
onSelect(new Set(copyValue))
|
||||
}
|
||||
|
||||
const handlePreview = (index: number) => {
|
||||
const current = currentDataList[index]
|
||||
const pageId = current.page_id
|
||||
|
||||
setLocalPreviewPageId(pageId)
|
||||
|
||||
if (onPreview)
|
||||
onPreview(pageId)
|
||||
}
|
||||
|
||||
if (!currentDataList.length) {
|
||||
if (!rows.length) {
|
||||
return (
|
||||
<div className="flex h-[296px] items-center justify-center text-[13px] text-text-tertiary">
|
||||
{t('dataSource.notion.selector.noSearchResult', { ns: 'common' })}
|
||||
@@ -307,29 +54,18 @@ const PageSelector = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<List
|
||||
className="py-2"
|
||||
height={296}
|
||||
itemCount={currentDataList.length}
|
||||
itemSize={28}
|
||||
width="100%"
|
||||
itemKey={(index, data) => data.dataList[index].page_id}
|
||||
itemData={{
|
||||
dataList: currentDataList,
|
||||
handleToggle,
|
||||
checkedIds: value,
|
||||
disabledCheckedIds: disabledValue,
|
||||
handleCheck,
|
||||
canPreview,
|
||||
handlePreview,
|
||||
listMapWithChildrenAndDescendants,
|
||||
searchValue,
|
||||
previewPageId: currentPreviewPageId,
|
||||
pagesMap,
|
||||
}}
|
||||
>
|
||||
{Item}
|
||||
</List>
|
||||
<VirtualPageList
|
||||
checkedIds={value}
|
||||
disabledValue={disabledValue}
|
||||
onPreview={handlePreview}
|
||||
onSelect={handleSelect}
|
||||
onToggle={handleToggle}
|
||||
previewPageId={currentPreviewPageId}
|
||||
rows={rows}
|
||||
searchValue={effectiveSearchValue}
|
||||
selectionMode="multiple"
|
||||
showPreview={canPreview}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
import type { CSSProperties } from 'react'
|
||||
import type { NotionPageRow as NotionPageRowData, NotionPageSelectionMode } from './types'
|
||||
import { RiArrowDownSLine, RiArrowRightSLine } from '@remixicon/react'
|
||||
import { memo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Checkbox from '@/app/components/base/checkbox'
|
||||
import NotionIcon from '@/app/components/base/notion-icon'
|
||||
import Radio from '@/app/components/base/radio/ui'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type NotionPageRowProps = {
|
||||
checked: boolean
|
||||
disabled: boolean
|
||||
isPreviewed: boolean
|
||||
onPreview: (pageId: string) => void
|
||||
onSelect: (pageId: string) => void
|
||||
onToggle: (pageId: string) => void
|
||||
row: NotionPageRowData
|
||||
searchValue: string
|
||||
selectionMode: NotionPageSelectionMode
|
||||
showPreview: boolean
|
||||
style: CSSProperties
|
||||
}
|
||||
|
||||
const NotionPageRow = ({
|
||||
checked,
|
||||
disabled,
|
||||
isPreviewed,
|
||||
onPreview,
|
||||
onSelect,
|
||||
onToggle,
|
||||
row,
|
||||
searchValue,
|
||||
selectionMode,
|
||||
showPreview,
|
||||
style,
|
||||
}: NotionPageRowProps) => {
|
||||
const { t } = useTranslation()
|
||||
const pageId = row.page.page_id
|
||||
const breadcrumbs = row.ancestors.length ? [...row.ancestors, row.page.page_name] : [row.page.page_name]
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn('group flex cursor-pointer items-center rounded-md pr-[2px] pl-2 hover:bg-state-base-hover', isPreviewed && 'bg-state-base-hover')}
|
||||
style={style}
|
||||
data-testid={`notion-page-row-${pageId}`}
|
||||
>
|
||||
{selectionMode === 'multiple'
|
||||
? (
|
||||
<Checkbox
|
||||
className="mr-2 shrink-0"
|
||||
checked={checked}
|
||||
disabled={disabled}
|
||||
onCheck={() => onSelect(pageId)}
|
||||
id={`notion-page-checkbox-${pageId}`}
|
||||
/>
|
||||
)
|
||||
: (
|
||||
<Radio
|
||||
className="mr-2 shrink-0"
|
||||
isChecked={checked}
|
||||
disabled={disabled}
|
||||
onCheck={() => onSelect(pageId)}
|
||||
/>
|
||||
)}
|
||||
{!searchValue && row.hasChild && (
|
||||
<div
|
||||
className="mr-1 flex h-5 w-5 shrink-0 items-center justify-center rounded-md hover:bg-components-button-ghost-bg-hover"
|
||||
style={{ marginLeft: row.depth * 8 }}
|
||||
onClick={() => onToggle(pageId)}
|
||||
data-testid={`notion-page-toggle-${pageId}`}
|
||||
>
|
||||
{row.expand
|
||||
? <RiArrowDownSLine className="h-4 w-4 text-text-tertiary" />
|
||||
: <RiArrowRightSLine className="h-4 w-4 text-text-tertiary" />}
|
||||
</div>
|
||||
)}
|
||||
{!searchValue && !row.hasChild && row.parentExists && (
|
||||
<div className="mr-1 h-5 w-5 shrink-0" style={{ marginLeft: row.depth * 8 }} />
|
||||
)}
|
||||
<NotionIcon
|
||||
className="mr-1 shrink-0"
|
||||
type="page"
|
||||
src={row.page.page_icon}
|
||||
/>
|
||||
<div
|
||||
className="grow truncate text-[13px] leading-4 font-medium text-text-secondary"
|
||||
title={row.page.page_name}
|
||||
data-testid={`notion-page-name-${pageId}`}
|
||||
>
|
||||
{row.page.page_name}
|
||||
</div>
|
||||
{showPreview && (
|
||||
<div
|
||||
className="ml-1 hidden h-6 shrink-0 cursor-pointer items-center rounded-md border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-2 text-xs
|
||||
leading-4 font-medium text-components-button-secondary-text shadow-xs shadow-shadow-shadow-3 backdrop-blur-[10px]
|
||||
group-hover:flex hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover"
|
||||
onClick={() => onPreview(pageId)}
|
||||
data-testid={`notion-page-preview-${pageId}`}
|
||||
>
|
||||
{t('dataSource.notion.selector.preview', { ns: 'common' })}
|
||||
</div>
|
||||
)}
|
||||
{searchValue && (
|
||||
<div
|
||||
className="ml-1 max-w-[120px] shrink-0 truncate text-xs text-text-quaternary"
|
||||
title={breadcrumbs.join(' / ')}
|
||||
>
|
||||
{breadcrumbs.join(' / ')}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(NotionPageRow)
|
||||
@@ -0,0 +1,21 @@
|
||||
import type { DataSourceNotionPage } from '@/models/common'
|
||||
|
||||
export type NotionPageSelectionMode = 'multiple' | 'single'
|
||||
|
||||
export type NotionPageTreeItem = {
|
||||
children: Set<string>
|
||||
descendants: Set<string>
|
||||
depth: number
|
||||
ancestors: string[]
|
||||
} & DataSourceNotionPage
|
||||
|
||||
export type NotionPageTreeMap = Record<string, NotionPageTreeItem>
|
||||
|
||||
export type NotionPageRow = {
|
||||
page: DataSourceNotionPage
|
||||
parentExists: boolean
|
||||
depth: number
|
||||
expand: boolean
|
||||
hasChild: boolean
|
||||
ancestors: string[]
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
import type { NotionPageSelectionMode } from './types'
|
||||
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
|
||||
import { startTransition, useCallback, useDeferredValue, useMemo, useState } from 'react'
|
||||
import { buildNotionPageTree, getNextSelectedPageIds, getRootPageIds, getVisiblePageRows } from './utils'
|
||||
|
||||
type UsePageSelectorModelProps = {
|
||||
checkedIds: Set<string>
|
||||
searchValue: string
|
||||
pagesMap: DataSourceNotionPageMap
|
||||
list: DataSourceNotionPage[]
|
||||
onSelect: (selectedPagesId: Set<string>) => void
|
||||
previewPageId?: string
|
||||
onPreview?: (selectedPageId: string) => void
|
||||
selectionMode: NotionPageSelectionMode
|
||||
}
|
||||
|
||||
export const usePageSelectorModel = ({
|
||||
checkedIds,
|
||||
searchValue,
|
||||
pagesMap,
|
||||
list,
|
||||
onSelect,
|
||||
previewPageId,
|
||||
onPreview,
|
||||
selectionMode,
|
||||
}: UsePageSelectorModelProps) => {
|
||||
const deferredSearchValue = useDeferredValue(searchValue)
|
||||
const [expandedIds, setExpandedIds] = useState<Set<string>>(() => new Set())
|
||||
const [localPreviewPageId, setLocalPreviewPageId] = useState('')
|
||||
|
||||
const treeMap = useMemo(() => buildNotionPageTree(list, pagesMap), [list, pagesMap])
|
||||
const rootPageIds = useMemo(() => getRootPageIds(list, pagesMap), [list, pagesMap])
|
||||
|
||||
const rows = useMemo(() => {
|
||||
return getVisiblePageRows({
|
||||
list,
|
||||
pagesMap,
|
||||
searchValue: deferredSearchValue,
|
||||
treeMap,
|
||||
rootPageIds,
|
||||
expandedIds,
|
||||
})
|
||||
}, [deferredSearchValue, expandedIds, list, pagesMap, rootPageIds, treeMap])
|
||||
|
||||
const currentPreviewPageId = previewPageId ?? localPreviewPageId
|
||||
|
||||
const handleToggle = useCallback((pageId: string) => {
|
||||
startTransition(() => {
|
||||
setExpandedIds((currentExpandedIds) => {
|
||||
const nextExpandedIds = new Set(currentExpandedIds)
|
||||
|
||||
if (nextExpandedIds.has(pageId)) {
|
||||
nextExpandedIds.delete(pageId)
|
||||
treeMap[pageId]?.descendants.forEach(descendantId => nextExpandedIds.delete(descendantId))
|
||||
}
|
||||
else {
|
||||
nextExpandedIds.add(pageId)
|
||||
}
|
||||
|
||||
return nextExpandedIds
|
||||
})
|
||||
})
|
||||
}, [treeMap])
|
||||
|
||||
const handleSelect = useCallback((pageId: string) => {
|
||||
onSelect(getNextSelectedPageIds({
|
||||
checkedIds,
|
||||
pageId,
|
||||
searchValue: deferredSearchValue,
|
||||
selectionMode,
|
||||
treeMap,
|
||||
}))
|
||||
}, [checkedIds, deferredSearchValue, onSelect, selectionMode, treeMap])
|
||||
|
||||
const handlePreview = useCallback((pageId: string) => {
|
||||
setLocalPreviewPageId(pageId)
|
||||
onPreview?.(pageId)
|
||||
}, [onPreview])
|
||||
|
||||
return {
|
||||
currentPreviewPageId,
|
||||
effectiveSearchValue: deferredSearchValue,
|
||||
rows,
|
||||
handlePreview,
|
||||
handleSelect,
|
||||
handleToggle,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
import type { NotionPageRow, NotionPageSelectionMode, NotionPageTreeItem, NotionPageTreeMap } from './types'
|
||||
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
|
||||
|
||||
export const recursivePushInParentDescendants = (
|
||||
pagesMap: DataSourceNotionPageMap,
|
||||
listTreeMap: NotionPageTreeMap,
|
||||
current: NotionPageTreeItem,
|
||||
leafItem: NotionPageTreeItem,
|
||||
) => {
|
||||
const parentId = current.parent_id
|
||||
const pageId = current.page_id
|
||||
|
||||
if (!parentId || !pageId)
|
||||
return
|
||||
|
||||
if (parentId !== 'root' && pagesMap[parentId]) {
|
||||
if (!listTreeMap[parentId]) {
|
||||
const children = new Set([pageId])
|
||||
const descendants = new Set([pageId, leafItem.page_id])
|
||||
listTreeMap[parentId] = {
|
||||
...pagesMap[parentId],
|
||||
children,
|
||||
descendants,
|
||||
depth: 0,
|
||||
ancestors: [],
|
||||
}
|
||||
}
|
||||
else {
|
||||
listTreeMap[parentId].children.add(pageId)
|
||||
listTreeMap[parentId].descendants.add(pageId)
|
||||
listTreeMap[parentId].descendants.add(leafItem.page_id)
|
||||
}
|
||||
|
||||
leafItem.depth++
|
||||
leafItem.ancestors.unshift(listTreeMap[parentId].page_name)
|
||||
|
||||
if (listTreeMap[parentId].parent_id !== 'root')
|
||||
recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap[parentId], leafItem)
|
||||
}
|
||||
}
|
||||
|
||||
export const buildNotionPageTree = (
|
||||
list: DataSourceNotionPage[],
|
||||
pagesMap: DataSourceNotionPageMap,
|
||||
): NotionPageTreeMap => {
|
||||
return list.reduce((prev: NotionPageTreeMap, next) => {
|
||||
const pageId = next.page_id
|
||||
if (!prev[pageId])
|
||||
prev[pageId] = { ...next, children: new Set(), descendants: new Set(), depth: 0, ancestors: [] }
|
||||
|
||||
recursivePushInParentDescendants(pagesMap, prev, prev[pageId], prev[pageId])
|
||||
return prev
|
||||
}, {})
|
||||
}
|
||||
|
||||
export const getRootPageIds = (
|
||||
list: DataSourceNotionPage[],
|
||||
pagesMap: DataSourceNotionPageMap,
|
||||
) => {
|
||||
return list
|
||||
.filter(item => item.parent_id === 'root' || !pagesMap[item.parent_id])
|
||||
.map(item => item.page_id)
|
||||
}
|
||||
|
||||
export const getVisiblePageRows = ({
|
||||
list,
|
||||
pagesMap,
|
||||
searchValue,
|
||||
treeMap,
|
||||
rootPageIds,
|
||||
expandedIds,
|
||||
}: {
|
||||
list: DataSourceNotionPage[]
|
||||
pagesMap: DataSourceNotionPageMap
|
||||
searchValue: string
|
||||
treeMap: NotionPageTreeMap
|
||||
rootPageIds: string[]
|
||||
expandedIds: Set<string>
|
||||
}): NotionPageRow[] => {
|
||||
if (searchValue) {
|
||||
return list
|
||||
.filter(item => item.page_name.includes(searchValue))
|
||||
.map(item => ({
|
||||
page: item,
|
||||
parentExists: item.parent_id !== 'root' && Boolean(pagesMap[item.parent_id]),
|
||||
depth: treeMap[item.page_id]?.depth ?? 0,
|
||||
expand: false,
|
||||
hasChild: (treeMap[item.page_id]?.children.size ?? 0) > 0,
|
||||
ancestors: treeMap[item.page_id]?.ancestors ?? [],
|
||||
}))
|
||||
}
|
||||
|
||||
const rows: NotionPageRow[] = []
|
||||
|
||||
const visit = (pageId: string) => {
|
||||
const current = treeMap[pageId]
|
||||
if (!current)
|
||||
return
|
||||
|
||||
const expand = expandedIds.has(pageId)
|
||||
rows.push({
|
||||
page: current,
|
||||
parentExists: current.parent_id !== 'root' && Boolean(pagesMap[current.parent_id]),
|
||||
depth: current.depth,
|
||||
expand,
|
||||
hasChild: current.children.size > 0,
|
||||
ancestors: current.ancestors,
|
||||
})
|
||||
|
||||
if (!expand)
|
||||
return
|
||||
|
||||
current.children.forEach(visit)
|
||||
}
|
||||
|
||||
rootPageIds.forEach(visit)
|
||||
|
||||
return rows
|
||||
}
|
||||
|
||||
export const getNextSelectedPageIds = ({
|
||||
checkedIds,
|
||||
pageId,
|
||||
searchValue,
|
||||
selectionMode,
|
||||
treeMap,
|
||||
}: {
|
||||
checkedIds: Set<string>
|
||||
pageId: string
|
||||
searchValue: string
|
||||
selectionMode: NotionPageSelectionMode
|
||||
treeMap: NotionPageTreeMap
|
||||
}) => {
|
||||
const nextCheckedIds = new Set(checkedIds)
|
||||
const descendants = treeMap[pageId]?.descendants ?? new Set<string>()
|
||||
|
||||
if (selectionMode === 'single') {
|
||||
if (nextCheckedIds.has(pageId)) {
|
||||
nextCheckedIds.delete(pageId)
|
||||
}
|
||||
else {
|
||||
nextCheckedIds.clear()
|
||||
nextCheckedIds.add(pageId)
|
||||
}
|
||||
|
||||
return nextCheckedIds
|
||||
}
|
||||
|
||||
if (nextCheckedIds.has(pageId)) {
|
||||
if (!searchValue)
|
||||
descendants.forEach(item => nextCheckedIds.delete(item))
|
||||
|
||||
nextCheckedIds.delete(pageId)
|
||||
return nextCheckedIds
|
||||
}
|
||||
|
||||
if (!searchValue)
|
||||
descendants.forEach(item => nextCheckedIds.add(item))
|
||||
|
||||
nextCheckedIds.add(pageId)
|
||||
|
||||
return nextCheckedIds
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
'use client'
|
||||
|
||||
import type { NotionPageRow, NotionPageSelectionMode } from './types'
|
||||
import { useVirtualizer } from '@tanstack/react-virtual'
|
||||
import { useRef } from 'react'
|
||||
import PageRow from './page-row'
|
||||
|
||||
type VirtualPageListProps = {
|
||||
checkedIds: Set<string>
|
||||
disabledValue: Set<string>
|
||||
onPreview: (pageId: string) => void
|
||||
onSelect: (pageId: string) => void
|
||||
onToggle: (pageId: string) => void
|
||||
previewPageId: string
|
||||
rows: NotionPageRow[]
|
||||
searchValue: string
|
||||
selectionMode: NotionPageSelectionMode
|
||||
showPreview: boolean
|
||||
}
|
||||
|
||||
const rowHeight = 28
|
||||
|
||||
const VirtualPageList = ({
|
||||
checkedIds,
|
||||
disabledValue,
|
||||
onPreview,
|
||||
onSelect,
|
||||
onToggle,
|
||||
previewPageId,
|
||||
rows,
|
||||
searchValue,
|
||||
selectionMode,
|
||||
showPreview,
|
||||
}: VirtualPageListProps) => {
|
||||
const scrollRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
const rowVirtualizer = useVirtualizer({
|
||||
count: rows.length,
|
||||
estimateSize: () => rowHeight,
|
||||
getScrollElement: () => scrollRef.current,
|
||||
overscan: 6,
|
||||
paddingEnd: 8,
|
||||
paddingStart: 8,
|
||||
})
|
||||
|
||||
const virtualRows = rowVirtualizer.getVirtualItems()
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={scrollRef}
|
||||
className="h-[296px] overflow-auto"
|
||||
data-testid="virtual-list"
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
height: `${rowVirtualizer.getTotalSize()}px`,
|
||||
position: 'relative',
|
||||
}}
|
||||
>
|
||||
{virtualRows.map((virtualRow) => {
|
||||
const row = rows[virtualRow.index]
|
||||
const pageId = row.page.page_id
|
||||
|
||||
return (
|
||||
<PageRow
|
||||
key={pageId}
|
||||
checked={checkedIds.has(pageId)}
|
||||
disabled={disabledValue.has(pageId)}
|
||||
isPreviewed={previewPageId === pageId}
|
||||
onPreview={onPreview}
|
||||
onSelect={onSelect}
|
||||
onToggle={onToggle}
|
||||
row={row}
|
||||
searchValue={searchValue}
|
||||
selectionMode={selectionMode}
|
||||
showPreview={showPreview}
|
||||
style={{
|
||||
height: `${virtualRow.size}px`,
|
||||
left: 8,
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
transform: `translateY(${virtualRow.start}px)`,
|
||||
width: 'calc(100% - 16px)',
|
||||
}}
|
||||
/>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default VirtualPageList
|
||||
@@ -12,6 +12,12 @@ export const Select = BaseSelect.Root
|
||||
export const SelectGroup = BaseSelect.Group
|
||||
export const SelectGroupLabel = BaseSelect.GroupLabel
|
||||
export const SelectValue = BaseSelect.Value
|
||||
/** @public */
|
||||
export const SelectGroup = BaseSelect.Group
|
||||
/** @public */
|
||||
export const SelectGroupLabel = BaseSelect.GroupLabel
|
||||
/** @public */
|
||||
export const SelectSeparator = BaseSelect.Separator
|
||||
|
||||
const selectTriggerVariants = cva(
|
||||
'',
|
||||
|
||||
@@ -106,7 +106,7 @@ const OnlineDocuments = ({
|
||||
if (!currentCredentialId)
|
||||
return
|
||||
getOnlineDocuments()
|
||||
}, [currentCredentialId])
|
||||
}, [currentCredentialId, getOnlineDocuments])
|
||||
|
||||
const handleSearchValueChange = useCallback((value: string) => {
|
||||
const { setSearchValue } = dataSourceStore.getState()
|
||||
@@ -156,6 +156,7 @@ const OnlineDocuments = ({
|
||||
{documentsData?.length
|
||||
? (
|
||||
<PageSelector
|
||||
key={`${currentCredentialId}:${supportBatchUpload ? 'multiple' : 'single'}`}
|
||||
checkedIds={selectedPagesId}
|
||||
disabledValue={new Set()}
|
||||
searchValue={searchValue}
|
||||
@@ -165,7 +166,6 @@ const OnlineDocuments = ({
|
||||
canPreview={!isInPipeline}
|
||||
onPreview={handlePreviewPage}
|
||||
isMultipleChoice={supportBatchUpload}
|
||||
currentCredentialId={currentCredentialId}
|
||||
/>
|
||||
)
|
||||
: (
|
||||
|
||||
@@ -1,26 +1,11 @@
|
||||
import type { NotionPageTreeItem, NotionPageTreeMap } from '../index'
|
||||
import type { NotionPageTreeItem, NotionPageTreeMap } from '@/app/components/base/notion-page-selector/page-selector/types'
|
||||
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { recursivePushInParentDescendants } from '@/app/components/base/notion-page-selector/page-selector/utils'
|
||||
import PageSelector from '../index'
|
||||
import { recursivePushInParentDescendants } from '../utils'
|
||||
|
||||
// Mock react-window FixedSizeList - renders items directly for testing
|
||||
vi.mock('react-window', () => ({
|
||||
FixedSizeList: ({ children: ItemComponent, itemCount, itemData, itemKey }: { children: React.ComponentType<{ index: number, style: React.CSSProperties, data: unknown }>, itemCount: number, itemData: unknown, itemKey?: (index: number, data: unknown) => string | number }) => (
|
||||
<div data-testid="virtual-list">
|
||||
{Array.from({ length: itemCount }).map((_, index) => (
|
||||
<ItemComponent
|
||||
key={itemKey?.(index, itemData) || index}
|
||||
index={index}
|
||||
style={{ top: index * 28, left: 0, right: 0, width: '100%', position: 'absolute' as const }}
|
||||
data={itemData}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
),
|
||||
areEqual: (prevProps: Record<string, unknown>, nextProps: Record<string, unknown>) => prevProps === nextProps,
|
||||
}))
|
||||
vi.mock('@tanstack/react-virtual')
|
||||
|
||||
// Note: NotionIcon from @/app/components/base/ is NOT mocked - using real component per testing guidelines
|
||||
|
||||
@@ -70,7 +55,6 @@ const createDefaultProps = (overrides?: Partial<PageSelectorProps>): PageSelecto
|
||||
canPreview: true,
|
||||
onPreview: vi.fn(),
|
||||
isMultipleChoice: true,
|
||||
currentCredentialId: 'cred-1',
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
@@ -114,7 +98,7 @@ describe('PageSelector', () => {
|
||||
expect(screen.queryByTestId('virtual-list')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render items using FixedSizeList', () => {
|
||||
it('should render items using VirtualList', () => {
|
||||
const pages = [
|
||||
createMockPage({ page_id: 'page-1', page_name: 'Page 1' }),
|
||||
createMockPage({ page_id: 'page-2', page_name: 'Page 2' }),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { NotionPageTreeItem, NotionPageTreeMap } from '../index'
|
||||
import type { NotionPageTreeItem, NotionPageTreeMap } from '@/app/components/base/notion-page-selector/page-selector/types'
|
||||
import type { DataSourceNotionPageMap } from '@/models/common'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { recursivePushInParentDescendants } from '../utils'
|
||||
import { recursivePushInParentDescendants } from '@/app/components/base/notion-page-selector/page-selector/utils'
|
||||
|
||||
const makePageEntry = (overrides: Partial<NotionPageTreeItem>): NotionPageTreeItem => ({
|
||||
page_icon: null,
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { FixedSizeList as List } from 'react-window'
|
||||
import Item from './item'
|
||||
import { recursivePushInParentDescendants } from './utils'
|
||||
import { usePageSelectorModel } from '@/app/components/base/notion-page-selector/page-selector/use-page-selector-model'
|
||||
import VirtualPageList from '@/app/components/base/notion-page-selector/page-selector/virtual-page-list'
|
||||
|
||||
type PageSelectorProps = {
|
||||
checkedIds: Set<string>
|
||||
@@ -15,23 +13,9 @@ type PageSelectorProps = {
|
||||
canPreview?: boolean
|
||||
onPreview?: (selectedPageId: string) => void
|
||||
isMultipleChoice?: boolean
|
||||
currentCredentialId: string
|
||||
currentCredentialId?: string
|
||||
}
|
||||
|
||||
export type NotionPageTreeItem = {
|
||||
children: Set<string>
|
||||
descendants: Set<string>
|
||||
depth: number
|
||||
ancestors: string[]
|
||||
} & DataSourceNotionPage
|
||||
|
||||
export type NotionPageTreeMap = Record<string, NotionPageTreeItem>
|
||||
|
||||
type NotionPageItem = {
|
||||
expand: boolean
|
||||
depth: number
|
||||
} & DataSourceNotionPage
|
||||
|
||||
const PageSelector = ({
|
||||
checkedIds,
|
||||
disabledValue,
|
||||
@@ -42,116 +26,28 @@ const PageSelector = ({
|
||||
canPreview = true,
|
||||
onPreview,
|
||||
isMultipleChoice = true,
|
||||
currentCredentialId,
|
||||
currentCredentialId: _currentCredentialId,
|
||||
}: PageSelectorProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [dataList, setDataList] = useState<NotionPageItem[]>([])
|
||||
const [currentPreviewPageId, setCurrentPreviewPageId] = useState('')
|
||||
|
||||
useEffect(() => {
|
||||
setDataList(list.filter(item => item.parent_id === 'root' || !pagesMap[item.parent_id]).map((item) => {
|
||||
return {
|
||||
...item,
|
||||
expand: false,
|
||||
depth: 0,
|
||||
}
|
||||
}))
|
||||
}, [currentCredentialId])
|
||||
|
||||
const searchDataList = list.filter((item) => {
|
||||
return item.page_name.includes(searchValue)
|
||||
}).map((item) => {
|
||||
return {
|
||||
...item,
|
||||
expand: false,
|
||||
depth: 0,
|
||||
}
|
||||
const selectionMode = isMultipleChoice ? 'multiple' : 'single'
|
||||
const {
|
||||
currentPreviewPageId,
|
||||
effectiveSearchValue,
|
||||
rows,
|
||||
handlePreview,
|
||||
handleSelect,
|
||||
handleToggle,
|
||||
} = usePageSelectorModel({
|
||||
checkedIds,
|
||||
list,
|
||||
onPreview,
|
||||
onSelect,
|
||||
pagesMap,
|
||||
searchValue,
|
||||
selectionMode,
|
||||
})
|
||||
const currentDataList = searchValue ? searchDataList : dataList
|
||||
|
||||
const listMapWithChildrenAndDescendants = useMemo(() => {
|
||||
return list.reduce((prev: NotionPageTreeMap, next: DataSourceNotionPage) => {
|
||||
const pageId = next.page_id
|
||||
if (!prev[pageId])
|
||||
prev[pageId] = { ...next, children: new Set(), descendants: new Set(), depth: 0, ancestors: [] }
|
||||
|
||||
recursivePushInParentDescendants(pagesMap, prev, prev[pageId], prev[pageId])
|
||||
return prev
|
||||
}, {})
|
||||
}, [list, pagesMap])
|
||||
|
||||
const handleToggle = useCallback((index: number) => {
|
||||
const current = dataList[index]
|
||||
const pageId = current.page_id
|
||||
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
|
||||
const descendantsIds = Array.from(currentWithChildrenAndDescendants.descendants)
|
||||
const childrenIds = Array.from(currentWithChildrenAndDescendants.children)
|
||||
let newDataList = []
|
||||
|
||||
if (current.expand) {
|
||||
current.expand = false
|
||||
|
||||
newDataList = dataList.filter(item => !descendantsIds.includes(item.page_id))
|
||||
}
|
||||
else {
|
||||
current.expand = true
|
||||
|
||||
newDataList = [
|
||||
...dataList.slice(0, index + 1),
|
||||
...childrenIds.map(item => ({
|
||||
...pagesMap[item],
|
||||
expand: false,
|
||||
depth: listMapWithChildrenAndDescendants[item].depth,
|
||||
})),
|
||||
...dataList.slice(index + 1),
|
||||
]
|
||||
}
|
||||
setDataList(newDataList)
|
||||
}, [dataList, listMapWithChildrenAndDescendants, pagesMap])
|
||||
|
||||
const handleCheck = useCallback((index: number) => {
|
||||
const copyValue = new Set(checkedIds)
|
||||
const current = currentDataList[index]
|
||||
const pageId = current.page_id
|
||||
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
|
||||
|
||||
if (copyValue.has(pageId)) {
|
||||
if (!searchValue && isMultipleChoice) {
|
||||
for (const item of currentWithChildrenAndDescendants.descendants)
|
||||
copyValue.delete(item)
|
||||
}
|
||||
|
||||
copyValue.delete(pageId)
|
||||
}
|
||||
else {
|
||||
if (!searchValue && isMultipleChoice) {
|
||||
for (const item of currentWithChildrenAndDescendants.descendants)
|
||||
copyValue.add(item)
|
||||
}
|
||||
// Single choice mode, clear previous selection
|
||||
if (!isMultipleChoice && copyValue.size > 0) {
|
||||
copyValue.clear()
|
||||
copyValue.add(pageId)
|
||||
}
|
||||
else {
|
||||
copyValue.add(pageId)
|
||||
}
|
||||
}
|
||||
|
||||
onSelect(new Set(copyValue))
|
||||
}, [currentDataList, isMultipleChoice, listMapWithChildrenAndDescendants, onSelect, searchValue, checkedIds])
|
||||
|
||||
const handlePreview = useCallback((index: number) => {
|
||||
const current = currentDataList[index]
|
||||
const pageId = current.page_id
|
||||
|
||||
setCurrentPreviewPageId(pageId)
|
||||
|
||||
if (onPreview)
|
||||
onPreview(pageId)
|
||||
}, [currentDataList, onPreview])
|
||||
|
||||
if (!currentDataList.length) {
|
||||
if (!rows.length) {
|
||||
return (
|
||||
<div className="flex h-[296px] items-center justify-center text-[13px] text-text-tertiary">
|
||||
{t('dataSource.notion.selector.noSearchResult', { ns: 'common' })}
|
||||
@@ -160,30 +56,18 @@ const PageSelector = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<List
|
||||
className="py-2"
|
||||
height={296}
|
||||
itemCount={currentDataList.length}
|
||||
itemSize={28}
|
||||
width="100%"
|
||||
itemKey={(index, data) => data.dataList[index].page_id}
|
||||
itemData={{
|
||||
dataList: currentDataList,
|
||||
handleToggle,
|
||||
checkedIds,
|
||||
disabledCheckedIds: disabledValue,
|
||||
handleCheck,
|
||||
canPreview,
|
||||
handlePreview,
|
||||
listMapWithChildrenAndDescendants,
|
||||
searchValue,
|
||||
previewPageId: currentPreviewPageId,
|
||||
pagesMap,
|
||||
isMultipleChoice,
|
||||
}}
|
||||
>
|
||||
{Item}
|
||||
</List>
|
||||
<VirtualPageList
|
||||
checkedIds={checkedIds}
|
||||
disabledValue={disabledValue}
|
||||
onPreview={handlePreview}
|
||||
onSelect={handleSelect}
|
||||
onToggle={handleToggle}
|
||||
previewPageId={currentPreviewPageId}
|
||||
rows={rows}
|
||||
searchValue={effectiveSearchValue}
|
||||
selectionMode={selectionMode}
|
||||
showPreview={canPreview}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
import type { ListChildComponentProps } from 'react-window'
|
||||
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
|
||||
import { RiArrowDownSLine, RiArrowRightSLine } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { areEqual } from 'react-window'
|
||||
import Checkbox from '@/app/components/base/checkbox'
|
||||
import NotionIcon from '@/app/components/base/notion-icon'
|
||||
import Radio from '@/app/components/base/radio/ui'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type NotionPageTreeItem = {
|
||||
children: Set<string>
|
||||
descendants: Set<string>
|
||||
depth: number
|
||||
ancestors: string[]
|
||||
} & DataSourceNotionPage
|
||||
|
||||
type NotionPageTreeMap = Record<string, NotionPageTreeItem>
|
||||
|
||||
type NotionPageItem = {
|
||||
expand: boolean
|
||||
depth: number
|
||||
} & DataSourceNotionPage
|
||||
|
||||
const Item = ({ index, style, data }: ListChildComponentProps<{
|
||||
dataList: NotionPageItem[]
|
||||
handleToggle: (index: number) => void
|
||||
checkedIds: Set<string>
|
||||
disabledCheckedIds: Set<string>
|
||||
handleCheck: (index: number) => void
|
||||
canPreview?: boolean
|
||||
handlePreview: (index: number) => void
|
||||
listMapWithChildrenAndDescendants: NotionPageTreeMap
|
||||
searchValue: string
|
||||
previewPageId: string
|
||||
pagesMap: DataSourceNotionPageMap
|
||||
isMultipleChoice?: boolean
|
||||
}>) => {
|
||||
const { t } = useTranslation()
|
||||
const {
|
||||
dataList,
|
||||
handleToggle,
|
||||
checkedIds,
|
||||
disabledCheckedIds,
|
||||
handleCheck,
|
||||
canPreview,
|
||||
handlePreview,
|
||||
listMapWithChildrenAndDescendants,
|
||||
searchValue,
|
||||
previewPageId,
|
||||
pagesMap,
|
||||
isMultipleChoice,
|
||||
} = data
|
||||
const current = dataList[index]
|
||||
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[current.page_id]
|
||||
const hasChild = currentWithChildrenAndDescendants.descendants.size > 0
|
||||
const ancestors = currentWithChildrenAndDescendants.ancestors
|
||||
const breadCrumbs = ancestors.length ? [...ancestors, current.page_name] : [current.page_name]
|
||||
const disabled = disabledCheckedIds.has(current.page_id)
|
||||
|
||||
const renderArrow = () => {
|
||||
if (hasChild) {
|
||||
return (
|
||||
<div
|
||||
className="mr-1 flex h-5 w-5 shrink-0 items-center justify-center rounded-md hover:bg-components-button-ghost-bg-hover"
|
||||
style={{ marginLeft: current.depth * 8 }}
|
||||
onClick={() => handleToggle(index)}
|
||||
>
|
||||
{
|
||||
current.expand
|
||||
? <RiArrowDownSLine className="h-4 w-4 text-text-tertiary" />
|
||||
: <RiArrowRightSLine className="h-4 w-4 text-text-tertiary" />
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
if (current.parent_id === 'root' || !pagesMap[current.parent_id]) {
|
||||
return (
|
||||
<div></div>
|
||||
)
|
||||
}
|
||||
return (
|
||||
<div className="mr-1 h-5 w-5 shrink-0" style={{ marginLeft: current.depth * 8 }} />
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn('group flex cursor-pointer items-center rounded-md pl-2 pr-[2px] hover:bg-state-base-hover', previewPageId === current.page_id && 'bg-state-base-hover')}
|
||||
style={{ ...style, top: style.top as number + 8, left: 8, right: 8, width: 'calc(100% - 16px)' }}
|
||||
>
|
||||
{isMultipleChoice
|
||||
? (
|
||||
<Checkbox
|
||||
className="mr-2 shrink-0"
|
||||
checked={checkedIds.has(current.page_id)}
|
||||
disabled={disabled}
|
||||
onCheck={() => {
|
||||
handleCheck(index)
|
||||
}}
|
||||
/>
|
||||
)
|
||||
: (
|
||||
<Radio
|
||||
className="mr-2 shrink-0"
|
||||
isChecked={checkedIds.has(current.page_id)}
|
||||
disabled={disabled}
|
||||
onCheck={() => {
|
||||
handleCheck(index)
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{!searchValue && renderArrow()}
|
||||
<NotionIcon
|
||||
className="mr-1 shrink-0"
|
||||
type="page"
|
||||
src={current.page_icon}
|
||||
/>
|
||||
<div
|
||||
className="grow truncate text-[13px] font-medium leading-4 text-text-secondary"
|
||||
title={current.page_name}
|
||||
>
|
||||
{current.page_name}
|
||||
</div>
|
||||
{
|
||||
canPreview && (
|
||||
<div
|
||||
className="ml-1 hidden h-6 shrink-0 cursor-pointer items-center rounded-md border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-2 text-xs
|
||||
font-medium leading-4 text-components-button-secondary-text shadow-xs shadow-shadow-shadow-3 backdrop-blur-[10px]
|
||||
hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover group-hover:flex"
|
||||
onClick={() => handlePreview(index)}
|
||||
>
|
||||
{t('dataSource.notion.selector.preview', { ns: 'common' })}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
searchValue && (
|
||||
<div
|
||||
className="ml-1 max-w-[120px] shrink-0 truncate text-xs text-text-quaternary"
|
||||
title={breadCrumbs.join(' / ')}
|
||||
>
|
||||
{breadCrumbs.join(' / ')}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(Item, areEqual)
|
||||
@@ -1,39 +0,0 @@
|
||||
import type { NotionPageTreeItem, NotionPageTreeMap } from './index'
|
||||
import type { DataSourceNotionPageMap } from '@/models/common'
|
||||
|
||||
export const recursivePushInParentDescendants = (
|
||||
pagesMap: DataSourceNotionPageMap,
|
||||
listTreeMap: NotionPageTreeMap,
|
||||
current: NotionPageTreeItem,
|
||||
leafItem: NotionPageTreeItem,
|
||||
) => {
|
||||
const parentId = current.parent_id
|
||||
const pageId = current.page_id
|
||||
|
||||
if (!parentId || !pageId)
|
||||
return
|
||||
|
||||
if (parentId !== 'root' && pagesMap[parentId]) {
|
||||
if (!listTreeMap[parentId]) {
|
||||
const children = new Set([pageId])
|
||||
const descendants = new Set([pageId, leafItem.page_id])
|
||||
listTreeMap[parentId] = {
|
||||
...pagesMap[parentId],
|
||||
children,
|
||||
descendants,
|
||||
depth: 0,
|
||||
ancestors: [],
|
||||
}
|
||||
}
|
||||
else {
|
||||
listTreeMap[parentId].children.add(pageId)
|
||||
listTreeMap[parentId].descendants.add(pageId)
|
||||
listTreeMap[parentId].descendants.add(leafItem.page_id)
|
||||
}
|
||||
leafItem.depth++
|
||||
leafItem.ancestors.unshift(listTreeMap[parentId].page_name)
|
||||
|
||||
if (listTreeMap[parentId].parent_id !== 'root')
|
||||
recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap[parentId], leafItem)
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { useContext } from 'react'
|
||||
import { use } from 'react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import DataSourceProvider, { DataSourceContext } from '../provider'
|
||||
|
||||
@@ -11,7 +11,7 @@ vi.mock('../', () => ({
|
||||
|
||||
// Test consumer component that reads from context
|
||||
function ContextConsumer() {
|
||||
const store = useContext(DataSourceContext)
|
||||
const store = use(DataSourceContext)
|
||||
return (
|
||||
<div data-testid="context-value" data-has-store={store !== null}>
|
||||
{store ? 'has-store' : 'no-store'}
|
||||
@@ -65,7 +65,7 @@ describe('DataSourceProvider', () => {
|
||||
const storeValues: Array<typeof mockStore | null> = []
|
||||
|
||||
function StoreCapture() {
|
||||
const store = useContext(DataSourceContext)
|
||||
const store = use(DataSourceContext)
|
||||
storeValues.push(store as typeof mockStore | null)
|
||||
return null
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import type { LocalFileSliceShape } from './slices/local-file'
|
||||
import type { OnlineDocumentSliceShape } from './slices/online-document'
|
||||
import type { OnlineDriveSliceShape } from './slices/online-drive'
|
||||
import type { WebsiteCrawlSliceShape } from './slices/website-crawl'
|
||||
import { useContext } from 'react'
|
||||
import { use } from 'react'
|
||||
import { createStore, useStore } from 'zustand'
|
||||
import { DataSourceContext } from './provider'
|
||||
import { createCommonSlice } from './slices/common'
|
||||
@@ -29,7 +29,7 @@ export const createDataSourceStore = () => {
|
||||
}
|
||||
|
||||
export const useDataSourceStoreWithSelector = <T>(selector: (state: DataSourceShape) => T): T => {
|
||||
const store = useContext(DataSourceContext)
|
||||
const store = use(DataSourceContext)
|
||||
if (!store)
|
||||
throw new Error('Missing DataSourceContext.Provider in the tree')
|
||||
|
||||
@@ -37,7 +37,7 @@ export const useDataSourceStoreWithSelector = <T>(selector: (state: DataSourceSh
|
||||
}
|
||||
|
||||
export const useDataSourceStore = () => {
|
||||
const store = useContext(DataSourceContext)
|
||||
const store = use(DataSourceContext)
|
||||
if (!store)
|
||||
throw new Error('Missing DataSourceContext.Provider in the tree')
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ vi.mock('@/app/components/datasets/common/image-list', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
// Markdown uses next/dynamic and react-syntax-highlighter (ESM)
|
||||
// Markdown uses next/dynamic and shiki (ESM)
|
||||
vi.mock('@/app/components/base/markdown', () => ({
|
||||
Markdown: ({ content, className }: { content: string, className?: string }) => (
|
||||
<div data-testid="markdown" className={`markdown-body ${className || ''}`}>{content}</div>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { createContext, useContext } from 'react'
|
||||
import { createContext, use } from 'react'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import styles from './quota-panel.module.css'
|
||||
|
||||
@@ -41,7 +41,7 @@ const SystemQuotaCard = ({
|
||||
}
|
||||
|
||||
const Label = ({ children, className }: { children: ReactNode, className?: string }) => {
|
||||
const variant = useContext(VariantContext)
|
||||
const variant = use(VariantContext)
|
||||
return (
|
||||
<div className={cn(
|
||||
'relative z-1 flex items-center gap-1 truncate px-1.5 pt-1 system-xs-medium',
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import { useContext } from 'react'
|
||||
import { use } from 'react'
|
||||
import { HooksStoreContext, HooksStoreContextProvider } from '../provider'
|
||||
|
||||
const mockRefreshAll = vi.fn()
|
||||
@@ -27,7 +27,7 @@ vi.mock('../store', async () => {
|
||||
})
|
||||
|
||||
const Consumer = () => {
|
||||
const store = useContext(HooksStoreContext)
|
||||
const store = use(HooksStoreContext)
|
||||
return <div>{store ? 'has-hooks-store' : 'missing-hooks-store'}</div>
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
'use client'
|
||||
import type { ReactNode } from 'react'
|
||||
import { createContext, useContext } from 'react'
|
||||
import { createContext, use } from 'react'
|
||||
|
||||
type MCPToolAvailabilityContextValue = {
|
||||
versionSupported?: boolean
|
||||
@@ -26,7 +26,7 @@ export const MCPToolAvailabilityProvider = ({
|
||||
)
|
||||
|
||||
export const useMCPToolAvailability = (): MCPToolAvailability => {
|
||||
const context = useContext(MCPToolAvailabilityContext)
|
||||
const context = use(MCPToolAvailabilityContext)
|
||||
if (context === undefined)
|
||||
return { allowed: true }
|
||||
|
||||
|
||||
@@ -909,4 +909,19 @@
|
||||
[data-theme='light'] [data-hide-on-theme='light'] {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* Shiki code block line numbers */
|
||||
.shiki-line-numbers code {
|
||||
counter-reset: line;
|
||||
}
|
||||
.shiki-line-numbers .line::before {
|
||||
counter-increment: line;
|
||||
content: counter(line);
|
||||
display: inline-block;
|
||||
width: 1rem;
|
||||
margin-right: 0.75rem;
|
||||
text-align: right;
|
||||
color: var(--color-text-quaternary);
|
||||
user-select: none;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1070,9 +1070,6 @@
|
||||
filter: invert(50%);
|
||||
}
|
||||
|
||||
.markdown-body .react-syntax-highlighter-line-number {
|
||||
color: var(--color-text-quaternary);
|
||||
}
|
||||
.markdown-body .abcjs-inline-audio .abcjs-btn {
|
||||
display: flex !important;
|
||||
}
|
||||
|
||||
@@ -748,9 +748,6 @@
|
||||
"no-restricted-imports": {
|
||||
"count": 2
|
||||
},
|
||||
"react-refresh/only-export-components": {
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
}
|
||||
@@ -3150,9 +3147,6 @@
|
||||
"react/set-state-in-effect": {
|
||||
"count": 7
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 9
|
||||
}
|
||||
@@ -3339,11 +3333,6 @@
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/base/notion-page-selector/base.tsx": {
|
||||
"react/set-state-in-effect": {
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/base/notion-page-selector/credential-selector/index.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 2
|
||||
@@ -3359,14 +3348,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/base/notion-page-selector/page-selector/index.tsx": {
|
||||
"react/set-state-in-effect": {
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 3
|
||||
}
|
||||
},
|
||||
"app/components/base/notion-page-selector/search-input/index.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
@@ -3635,17 +3616,11 @@
|
||||
}
|
||||
},
|
||||
"app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx": {
|
||||
"react-refresh/only-export-components": {
|
||||
"count": 1
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/base/prompt-editor/plugins/update-block.tsx": {
|
||||
"react-refresh/only-export-components": {
|
||||
"count": 2
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 2
|
||||
}
|
||||
@@ -4848,16 +4823,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/index.tsx": {
|
||||
"react/set-state-in-effect": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/item.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 3
|
||||
}
|
||||
},
|
||||
"app/components/datasets/documents/create-from-pipeline/data-source/online-documents/title.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
@@ -7927,9 +7892,6 @@
|
||||
}
|
||||
},
|
||||
"app/components/signin/countdown.tsx": {
|
||||
"react-refresh/only-export-components": {
|
||||
"count": 2
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
}
|
||||
@@ -8332,7 +8294,7 @@
|
||||
},
|
||||
"app/components/workflow/block-selector/index-bar.tsx": {
|
||||
"react-refresh/only-export-components": {
|
||||
"count": 5
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
// @ts-check
|
||||
|
||||
import antfu, { GLOB_MARKDOWN, GLOB_MARKDOWN_CODE, GLOB_TESTS, GLOB_TS, GLOB_TSX, isInEditorEnv, isInGitHooksOrLintStaged } from '@antfu/eslint-config'
|
||||
import pluginReact from '@eslint-react/eslint-plugin'
|
||||
import pluginQuery from '@tanstack/eslint-plugin-query'
|
||||
import md from 'eslint-markdown'
|
||||
import tailwindcss from 'eslint-plugin-better-tailwindcss'
|
||||
import hyoban from 'eslint-plugin-hyoban'
|
||||
import markdownPreferences from 'eslint-plugin-markdown-preferences'
|
||||
import noBarrelFiles from 'eslint-plugin-no-barrel-files'
|
||||
import { reactRefresh } from 'eslint-plugin-react-refresh'
|
||||
import sonar from 'eslint-plugin-sonarjs'
|
||||
import storybook from 'eslint-plugin-storybook'
|
||||
import {
|
||||
@@ -26,11 +24,14 @@ process.env.TAILWIND_MODE ??= 'ESLINT'
|
||||
|
||||
const disableRuleAutoFix = !(isInEditorEnv() || isInGitHooksOrLintStaged())
|
||||
|
||||
const plugins = pluginReact.configs.all.plugins
|
||||
|
||||
export default antfu(
|
||||
{
|
||||
react: false,
|
||||
react: {
|
||||
overrides: {
|
||||
'react/set-state-in-effect': 'error',
|
||||
'react/no-unnecessary-use-prefix': 'error',
|
||||
},
|
||||
},
|
||||
nextjs: {
|
||||
overrides: {
|
||||
'next/no-img-element': 'off',
|
||||
@@ -58,24 +59,6 @@ export default antfu(
|
||||
e18e: false,
|
||||
pnpm: false,
|
||||
},
|
||||
{
|
||||
plugins: {
|
||||
'react': plugins?.['@eslint-react'],
|
||||
'react-dom': plugins?.['@eslint-react/dom'],
|
||||
'react-naming-convention': plugins?.['@eslint-react/naming-convention'],
|
||||
'react-rsc': plugins?.['@eslint-react/rsc'],
|
||||
'react-web-api': plugins?.['@eslint-react/web-api'],
|
||||
},
|
||||
},
|
||||
{
|
||||
files: [GLOB_TS, GLOB_TSX],
|
||||
rules: {
|
||||
...pluginReact.configs['recommended-typescript'].rules,
|
||||
'react/prefer-namespace-import': 'error',
|
||||
'react/set-state-in-effect': 'error',
|
||||
'react/no-unnecessary-use-prefix': 'error',
|
||||
},
|
||||
},
|
||||
{
|
||||
files: [...GLOB_TESTS, GLOB_MARKDOWN_CODE, 'vitest.setup.ts', 'test/i18n-mock.ts'],
|
||||
rules: {
|
||||
@@ -92,7 +75,6 @@ export default antfu(
|
||||
'no-barrel-files/no-barrel-files': 'error',
|
||||
},
|
||||
},
|
||||
reactRefresh.configs.next(),
|
||||
markdownPreferences.configs.standard,
|
||||
{
|
||||
files: [GLOB_MARKDOWN],
|
||||
@@ -231,10 +213,3 @@ export default antfu(
|
||||
'tailwindcss/no-unnecessary-whitespace',
|
||||
]
|
||||
: [])
|
||||
.renamePlugins({
|
||||
'@eslint-react': 'react',
|
||||
'@eslint-react/dom': 'react-dom',
|
||||
'@eslint-react/naming-convention': 'react-naming-convention',
|
||||
'@eslint-react/rsc': 'react-rsc',
|
||||
'@eslint-react/web-api': 'react-web-api',
|
||||
})
|
||||
|
||||
@@ -6,7 +6,6 @@ export function getInitOptions(): InitOptions {
|
||||
// We do not have en for fallback
|
||||
load: 'currentOnly',
|
||||
fallbackLng: 'en-US',
|
||||
showSupportNotice: false,
|
||||
partialBundledLanguages: true,
|
||||
keySeparator: false,
|
||||
ns: namespaces,
|
||||
|
||||
@@ -136,6 +136,9 @@
|
||||
"dataSource.notion.pagesAuthorized": "Pages authorized",
|
||||
"dataSource.notion.remove": "Remove",
|
||||
"dataSource.notion.selector.addPages": "Add pages",
|
||||
"dataSource.notion.selector.configure": "Configure Notion",
|
||||
"dataSource.notion.selector.docs": "Notion docs",
|
||||
"dataSource.notion.selector.headerTitle": "Choose Notion pages",
|
||||
"dataSource.notion.selector.noSearchResult": "No search results",
|
||||
"dataSource.notion.selector.pageSelected": "Pages Selected",
|
||||
"dataSource.notion.selector.preview": "PREVIEW",
|
||||
|
||||
@@ -81,6 +81,7 @@
|
||||
"@tailwindcss/typography": "catalog:",
|
||||
"@tanstack/react-form": "catalog:",
|
||||
"@tanstack/react-query": "catalog:",
|
||||
"@tanstack/react-virtual": "catalog:",
|
||||
"abcjs": "catalog:",
|
||||
"ahooks": "catalog:",
|
||||
"class-variance-authority": "catalog:",
|
||||
@@ -100,6 +101,7 @@
|
||||
"es-toolkit": "catalog:",
|
||||
"fast-deep-equal": "catalog:",
|
||||
"foxact": "catalog:",
|
||||
"hast-util-to-jsx-runtime": "catalog:",
|
||||
"html-entities": "catalog:",
|
||||
"html-to-image": "catalog:",
|
||||
"i18next": "catalog:",
|
||||
@@ -134,14 +136,13 @@
|
||||
"react-papaparse": "catalog:",
|
||||
"react-pdf-highlighter": "catalog:",
|
||||
"react-sortablejs": "catalog:",
|
||||
"react-syntax-highlighter": "catalog:",
|
||||
"react-textarea-autosize": "catalog:",
|
||||
"react-window": "catalog:",
|
||||
"reactflow": "catalog:",
|
||||
"remark-breaks": "catalog:",
|
||||
"remark-directive": "catalog:",
|
||||
"scheduler": "catalog:",
|
||||
"sharp": "catalog:",
|
||||
"shiki": "catalog:",
|
||||
"sortablejs": "catalog:",
|
||||
"std-semver": "catalog:",
|
||||
"streamdown": "catalog:",
|
||||
@@ -196,8 +197,6 @@
|
||||
"@types/qs": "catalog:",
|
||||
"@types/react": "catalog:",
|
||||
"@types/react-dom": "catalog:",
|
||||
"@types/react-syntax-highlighter": "catalog:",
|
||||
"@types/react-window": "catalog:",
|
||||
"@types/sortablejs": "catalog:",
|
||||
"@typescript-eslint/parser": "catalog:",
|
||||
"@typescript/native-preview": "catalog:",
|
||||
@@ -212,7 +211,6 @@
|
||||
"eslint-plugin-hyoban": "catalog:",
|
||||
"eslint-plugin-markdown-preferences": "catalog:",
|
||||
"eslint-plugin-no-barrel-files": "catalog:",
|
||||
"eslint-plugin-react-hooks": "catalog:",
|
||||
"eslint-plugin-react-refresh": "catalog:",
|
||||
"eslint-plugin-sonarjs": "catalog:",
|
||||
"eslint-plugin-storybook": "catalog:",
|
||||
@@ -221,7 +219,6 @@
|
||||
"knip": "catalog:",
|
||||
"postcss": "catalog:",
|
||||
"react-server-dom-webpack": "catalog:",
|
||||
"sass": "catalog:",
|
||||
"storybook": "catalog:",
|
||||
"tailwindcss": "catalog:",
|
||||
"tsx": "catalog:",
|
||||
|
||||
Reference in New Issue
Block a user