Merge branch 'main' into jzh

This commit is contained in:
JzoNg
2026-04-03 15:15:11 +08:00
50 changed files with 2578 additions and 3141 deletions

View File

@@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import func, select
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
@@ -144,22 +144,26 @@ class AccountService:
@staticmethod
def load_user(user_id: str) -> None | Account:
account = db.session.query(Account).filter_by(id=user_id).first()
account = db.session.get(Account, user_id)
if not account:
return None
if account.status == AccountStatus.BANNED:
raise Unauthorized("Account is banned.")
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
current_tenant = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
.limit(1)
)
if current_tenant:
account.set_tenant_id(current_tenant.tenant_id)
else:
available_ta = (
db.session.query(TenantAccountJoin)
.filter_by(account_id=account.id)
available_ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
.limit(1)
)
if not available_ta:
return None
@@ -195,7 +199,7 @@ class AccountService:
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
"""authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first()
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
raise AccountPasswordError("Invalid email or password.")
@@ -371,8 +375,10 @@ class AccountService:
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: AccountIntegrate | None = (
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
account_integrate: AccountIntegrate | None = db.session.scalar(
select(AccountIntegrate)
.where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
.limit(1)
)
if account_integrate:
@@ -416,7 +422,9 @@ class AccountService:
def update_account_email(account: Account, email: str) -> Account:
"""Update account email"""
account.email = email
account_integrate = db.session.query(AccountIntegrate).filter_by(account_id=account.id).first()
account_integrate = db.session.scalar(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
)
if account_integrate:
db.session.delete(account_integrate)
db.session.add(account)
@@ -818,7 +826,7 @@ class AccountService:
)
)
account = db.session.query(Account).where(Account.email == email).first()
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
return None
@@ -1018,7 +1026,7 @@ class AccountService:
@staticmethod
def check_email_unique(email: str) -> bool:
return db.session.query(Account).filter_by(email=email).first() is None
return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None
class TenantService:
@@ -1384,10 +1392,10 @@ class RegisterService:
db.session.add(dify_setup)
db.session.commit()
except Exception as e:
db.session.query(DifySetup).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Account).delete()
db.session.query(Tenant).delete()
db.session.execute(delete(DifySetup))
db.session.execute(delete(TenantAccountJoin))
db.session.execute(delete(Account))
db.session.execute(delete(Tenant))
db.session.commit()
logger.exception("Setup account failed, email: %s, name: %s", email, name)
@@ -1488,7 +1496,11 @@ class RegisterService:
TenantService.switch_tenant(account, tenant.id)
else:
TenantService.check_member_permission(tenant, inviter, account, "add")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if not ta:
TenantService.create_tenant_member(tenant, account, role)
@@ -1545,21 +1557,18 @@ class RegisterService:
if not invitation_data:
return None
tenant = (
db.session.query(Tenant)
.where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
.first()
tenant = db.session.scalar(
select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1)
)
if not tenant:
return None
tenant_account = (
db.session.query(Account, TenantAccountJoin.role)
tenant_account = db.session.execute(
select(Account, TenantAccountJoin.role)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
.first()
)
).first()
if not tenant_account:
return None

View File

@@ -4,6 +4,8 @@ import uuid
import pandas as pd
logger = logging.getLogger(__name__)
from typing import TypedDict
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@@ -23,6 +25,27 @@ from tasks.annotation.enable_annotation_reply_task import enable_annotation_repl
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
class AnnotationJobStatusDict(TypedDict):
job_id: str
job_status: str
class EmbeddingModelDict(TypedDict):
embedding_provider_name: str
embedding_model_name: str
class AnnotationSettingDict(TypedDict):
id: str
enabled: bool
score_threshold: float
embedding_model: EmbeddingModelDict | dict
class AnnotationSettingDisabledDict(TypedDict):
enabled: bool
class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
@@ -85,7 +108,7 @@ class AppAnnotationService:
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str):
def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
@@ -109,7 +132,7 @@ class AppAnnotationService:
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def disable_app_annotation(cls, app_id: str):
def disable_app_annotation(cls, app_id: str) -> AnnotationJobStatusDict:
_, current_tenant_id = current_account_with_tenant()
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(disable_app_annotation_key)
@@ -567,7 +590,7 @@ class AppAnnotationService:
db.session.commit()
@classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str):
def get_app_annotation_setting_by_app_id(cls, app_id: str) -> AnnotationSettingDict | AnnotationSettingDisabledDict:
_, current_tenant_id = current_account_with_tenant()
# get app info
app = (
@@ -602,7 +625,9 @@ class AppAnnotationService:
return {"enabled": False}
@classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
def update_app_annotation_setting(
cls, app_id: str, annotation_setting_id: str, args: dict
) -> AnnotationSettingDict:
current_user, current_tenant_id = current_account_with_tenant()
# get app info
app = (

View File

@@ -5,7 +5,7 @@ from urllib.parse import urlparse
import httpx
from graphon.nodes.http_request.exc import InvalidHttpMethodError
from sqlalchemy import select
from sqlalchemy import func, select
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
@@ -103,8 +103,10 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str, tenant_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -112,8 +114,10 @@ class ExternalDatasetService:
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -132,8 +136,10 @@ class ExternalDatasetService:
@staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -144,9 +150,12 @@ class ExternalDatasetService:
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = (
db.session.query(ExternalKnowledgeBindings)
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
.count()
db.session.scalar(
select(func.count(ExternalKnowledgeBindings.id)).where(
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
)
)
or 0
)
if count > 0:
return True, count
@@ -154,8 +163,10 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: ExternalKnowledgeBindings | None = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
external_knowledge_binding: ExternalKnowledgeBindings | None = db.session.scalar(
select(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
.limit(1)
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
@@ -163,8 +174,10 @@ class ExternalDatasetService:
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("api template not found")
@@ -238,12 +251,17 @@ class ExternalDatasetService:
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
if db.session.scalar(
select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1)
):
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
.first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(
ExternalKnowledgeApis.id == args.get("external_knowledge_api_id"),
ExternalKnowledgeApis.tenant_id == tenant_id,
)
.limit(1)
)
if external_knowledge_api is None:
@@ -286,16 +304,18 @@ class ExternalDatasetService:
external_retrieval_parameters: dict,
metadata_condition: MetadataCondition | None = None,
):
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
.limit(1)
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=external_knowledge_binding.external_knowledge_api_id)
.first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("external api template not found")

View File

@@ -156,27 +156,27 @@ class RagPipelineService:
:param template_id: template id
:param template_info: template info
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
.limit(1)
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
# check template name is exist
template_name = template_info.name
if template_name:
template = (
db.session.query(PipelineCustomizedTemplate)
template = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
PipelineCustomizedTemplate.id != template_id,
)
.first()
.limit(1)
)
if template:
raise ValueError("Template name is already exists")
@@ -192,13 +192,13 @@ class RagPipelineService:
"""
Delete customized pipeline template.
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
.limit(1)
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
@@ -210,14 +210,14 @@ class RagPipelineService:
Get draft workflow
"""
# fetch draft workflow by rag pipeline
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
.limit(1)
)
# return draft workflow
@@ -232,28 +232,28 @@ class RagPipelineService:
return None
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == pipeline.workflow_id,
)
.first()
.limit(1)
)
return workflow
def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""Fetch a published workflow snapshot by ID for restore operations."""
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == workflow_id,
)
.first()
.limit(1)
)
if workflow and workflow.version == Workflow.VERSION_DRAFT:
raise IsDraftWorkflowError("source workflow must be published")
@@ -974,7 +974,7 @@ class RagPipelineService:
if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID)
if document_id:
document = db.session.query(Document).where(Document.id == document_id.value).first()
document = db.session.get(Document, document_id.value)
if document:
document.indexing_status = IndexingStatus.ERROR
document.error = error
@@ -1178,12 +1178,12 @@ class RagPipelineService:
"""
Publish customized pipeline template
"""
pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
pipeline = db.session.get(Pipeline, pipeline_id)
if not pipeline:
raise ValueError("Pipeline not found")
if not pipeline.workflow_id:
raise ValueError("Pipeline workflow not found")
workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
workflow = db.session.get(Workflow, pipeline.workflow_id)
if not workflow:
raise ValueError("Workflow not found")
with Session(db.engine) as session:
@@ -1194,21 +1194,21 @@ class RagPipelineService:
# check template name is exist
template_name = args.get("name")
if template_name:
template = (
db.session.query(PipelineCustomizedTemplate)
template = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
)
.first()
.limit(1)
)
if template:
raise ValueError("Template name is already exists")
max_position = (
db.session.query(func.max(PipelineCustomizedTemplate.position))
.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
.scalar()
max_position = db.session.scalar(
select(func.max(PipelineCustomizedTemplate.position)).where(
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id
)
)
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@@ -1239,13 +1239,14 @@ class RagPipelineService:
def is_workflow_exist(self, pipeline: Pipeline) -> bool:
return (
db.session.query(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT,
db.session.scalar(
select(func.count(Workflow.id)).where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
)
.count()
or 0
) > 0
def get_node_last_run(
@@ -1353,11 +1354,11 @@ class RagPipelineService:
def get_recommended_plugins(self, type: str) -> dict:
# Query active recommended plugins
query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
if type and type != "all":
query = query.where(PipelineRecommendedPlugin.type == type)
stmt = stmt.where(PipelineRecommendedPlugin.type == type)
pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all()
pipeline_recommended_plugins = db.session.scalars(stmt.order_by(PipelineRecommendedPlugin.position.asc())).all()
if not pipeline_recommended_plugins:
return {
@@ -1396,14 +1397,12 @@ class RagPipelineService:
"""
Retry error document
"""
document_pipeline_execution_log = (
db.session.query(DocumentPipelineExecutionLog)
.where(DocumentPipelineExecutionLog.document_id == document.id)
.first()
document_pipeline_execution_log = db.session.scalar(
select(DocumentPipelineExecutionLog).where(DocumentPipelineExecutionLog.document_id == document.id).limit(1)
)
if not document_pipeline_execution_log:
raise ValueError("Document pipeline execution log not found")
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
pipeline = db.session.get(Pipeline, document_pipeline_execution_log.pipeline_id)
if not pipeline:
raise ValueError("Pipeline not found")
# convert to app config
@@ -1432,23 +1431,23 @@ class RagPipelineService:
"""
Get datasource plugins
"""
dataset: Dataset | None = (
db.session.query(Dataset)
dataset: Dataset | None = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
pipeline: Pipeline | None = db.session.scalar(
select(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not pipeline:
raise ValueError("Pipeline not found")
@@ -1530,23 +1529,23 @@ class RagPipelineService:
"""
Get pipeline
"""
dataset: Dataset | None = (
db.session.query(Dataset)
dataset: Dataset | None = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
pipeline: Pipeline | None = db.session.scalar(
select(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not pipeline:
raise ValueError("Pipeline not found")

View File

@@ -292,7 +292,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
"""
api = Mock(spec=ExternalKnowledgeApis)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
mock_db_session.scalar.return_value = api
result = ExternalDatasetService.get_external_knowledge_api("api-id", "tenant-id")
assert result is api
@@ -302,7 +302,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
When the record is absent, a ``ValueError`` is raised.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.get_external_knowledge_api("missing-id", "tenant-id")
@@ -320,7 +320,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
existing_api = Mock(spec=ExternalKnowledgeApis)
existing_api.settings_dict = {"api_key": "stored-key"}
existing_api.settings = '{"api_key":"stored-key"}'
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_api
mock_db_session.scalar.return_value = existing_api
args = {
"name": "New Name",
@@ -340,7 +340,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
Updating a nonexistent API template should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.update_external_knowledge_api(
@@ -356,7 +356,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
"""
api = Mock(spec=ExternalKnowledgeApis)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
mock_db_session.scalar.return_value = api
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")
@@ -368,7 +368,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
Deletion of a missing template should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")
@@ -394,7 +394,7 @@ class TestExternalDatasetServiceUsageAndBindings:
When there are bindings, ``external_knowledge_api_use_check`` returns True and count.
"""
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 3
mock_db_session.scalar.return_value = 3
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
@@ -406,7 +406,7 @@ class TestExternalDatasetServiceUsageAndBindings:
Zero bindings should return ``(False, 0)``.
"""
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 0
mock_db_session.scalar.return_value = 0
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
@@ -419,7 +419,7 @@ class TestExternalDatasetServiceUsageAndBindings:
"""
binding = Mock(spec=ExternalKnowledgeBindings)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = binding
mock_db_session.scalar.return_value = binding
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
assert result is binding
@@ -429,7 +429,7 @@ class TestExternalDatasetServiceUsageAndBindings:
Missing binding should result in a ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="external knowledge binding not found"):
ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
@@ -460,7 +460,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
)
# Raw string; the service itself calls json.loads on it
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
mock_db_session.scalar.return_value = external_api
process_parameter = {"foo": "value", "bar": "optional"}
@@ -474,7 +474,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
When the referenced API template is missing, a ``ValueError`` is raised.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})
@@ -488,7 +488,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
external_api.settings = (
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
mock_db_session.scalar.return_value = external_api
process_parameter = {"bar": "present"} # missing "foo"
@@ -702,7 +702,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
}
# No existing dataset with same name.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
None, # duplicatename check
Mock(spec=ExternalKnowledgeApis), # external knowledge api
]
@@ -724,7 +724,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
"""
existing_dataset = Mock(spec=Dataset)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_dataset
mock_db_session.scalar.return_value = existing_dataset
args = {
"name": "Existing",
@@ -744,7 +744,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
"""
# First call: duplicate name check not found.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
None,
None, # external knowledge api lookup
]
@@ -763,8 +763,10 @@ class TestExternalDatasetServiceCreateExternalDataset:
``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
"""
# duplicate name check
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
# duplicate name check — two calls to create_external_dataset, each does 2 scalar calls
mock_db_session.scalar.side_effect = [
None,
Mock(spec=ExternalKnowledgeApis),
None,
Mock(spec=ExternalKnowledgeApis),
]
@@ -826,7 +828,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
# First query: binding; second query: api.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
binding,
api,
]
@@ -861,7 +863,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
Missing binding should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="external knowledge binding not found"):
ExternalDatasetService.fetch_external_knowledge_retrieval(
@@ -878,7 +880,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
"""
binding = ExternalDatasetTestDataFactory.create_external_binding()
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
binding,
None,
]
@@ -901,7 +903,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
api = Mock(spec=ExternalKnowledgeApis)
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
binding,
api,
]

View File

@@ -117,9 +117,7 @@ def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_serv
def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service) -> None:
first_query = mocker.Mock()
first_query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=first_query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Dataset not found"):
rag_pipeline_service.get_pipeline("tenant-1", "dataset-1")
@@ -131,12 +129,8 @@ def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service
def test_update_customized_pipeline_template_success(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
# First query finds the template, second query (duplicate check) returns None
query_mock_1 = mocker.Mock()
query_mock_1.where.return_value.first.return_value = template
query_mock_2 = mocker.Mock()
query_mock_2.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", side_effect=[query_mock_1, query_mock_2])
# First scalar finds the template, second scalar (duplicate check) returns None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, None])
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
@@ -152,9 +146,7 @@ def test_update_customized_pipeline_template_success(mocker) -> None:
def test_update_customized_pipeline_template_not_found(mocker) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(name="x", description="d", icon_info=IconInfo(icon="i"))
@@ -166,9 +158,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
duplicate = SimpleNamespace(name="dup")
query_mock = mocker.Mock()
query_mock.where.return_value.first.side_effect = [template, duplicate]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, duplicate])
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(name="dup", description="d", icon_info=IconInfo(icon="i"))
@@ -181,9 +171,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
def test_delete_customized_pipeline_template_success(mocker) -> None:
template = SimpleNamespace(id="tpl-1")
query_mock = mocker.Mock()
query_mock.where.return_value.first.return_value = template
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
delete_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.delete")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@@ -196,9 +184,7 @@ def test_delete_customized_pipeline_template_success(mocker) -> None:
def test_delete_customized_pipeline_template_not_found(mocker) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
with pytest.raises(ValueError, match="Customized pipeline template not found"):
@@ -397,18 +383,14 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) -
def test_is_workflow_exist_returns_true_when_draft_exists(mocker, rag_pipeline_service) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.count.return_value = 1
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=1)
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
assert rag_pipeline_service.is_workflow_exist(pipeline) is True
def test_is_workflow_exist_returns_false_when_no_draft(mocker, rag_pipeline_service) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.count.return_value = 0
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=0)
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
assert rag_pipeline_service.is_workflow_exist(pipeline) is False
@@ -738,8 +720,7 @@ def test_get_second_step_parameters_success(mocker, rag_pipeline_service) -> Non
def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_service) -> None:
from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate
from models.workflow import Workflow
from models.dataset import Pipeline
# 1. Setup mocks
pipeline = mocker.Mock(spec=Pipeline)
@@ -754,36 +735,15 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi
# Mock db itself to avoid app context errors
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
# Improved mocking for session.query
def mock_query_side_effect(model):
m = mocker.Mock()
if model == Pipeline:
m.where.return_value.first.return_value = pipeline
elif model == Workflow:
m.where.return_value.first.return_value = workflow
elif model == PipelineCustomizedTemplate:
m.where.return_value.first.return_value = None
elif model == Dataset:
m.where.return_value.first.return_value = mocker.Mock()
else:
# For func.max cases
m.where.return_value.scalar.return_value = 5
m.where.return_value.first.return_value = mocker.Mock()
return m
mock_db.session.query.side_effect = mock_query_side_effect
# Mock get() for Pipeline and Workflow PK lookups
mock_db.session.get.side_effect = [pipeline, workflow]
# Mock scalar() for template name check (None) and max position (5)
mock_db.session.scalar.side_effect = [None, 5]
# Mock retrieve_dataset
dataset = mocker.Mock()
pipeline.retrieve_dataset.return_value = dataset
# Mock max position
mocker.patch("services.rag_pipeline.rag_pipeline.func.max", return_value=1)
mocker.patch(
"services.rag_pipeline.rag_pipeline.db.session.query.return_value.where.return_value.scalar",
return_value=5,
)
# Mock RagPipelineDslService
mock_dsl_service = mocker.Mock()
mock_dsl_service.export_rag_pipeline_dsl.return_value = {"dsl": "content"}
@@ -839,9 +799,7 @@ def test_get_datasource_plugins_success(mocker, rag_pipeline_service) -> None:
workflow.rag_pipeline_variables = []
# Mock queries
mock_query = mocker.Mock()
mock_query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
@@ -881,11 +839,9 @@ def test_retry_error_document_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock()
# Mock queries
mock_query = mocker.Mock()
# Log lookup, then Pipeline lookup
mock_query.where.return_value.first.side_effect = [log, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
# Mock queries: Log lookup via scalar, Pipeline lookup via get
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
@@ -913,7 +869,7 @@ def test_set_datasource_variables_success(mocker, rag_pipeline_service) -> None:
# Mock db aggressively
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.engine = mocker.Mock()
mock_db.session.query.return_value.where.return_value.first.return_value = mocker.Mock()
mock_db.session.scalar.return_value = mocker.Mock()
pipeline = mocker.Mock(spec=Pipeline)
pipeline.id = "p-1"
@@ -976,7 +932,7 @@ def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock(spec=Workflow)
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value.where.return_value.first.return_value = workflow
mock_db.session.scalar.return_value = workflow
# 2. Run test
result = rag_pipeline_service.get_draft_workflow(pipeline)
@@ -998,7 +954,7 @@ def test_get_published_workflow_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock(spec=Workflow)
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value.where.return_value.first.return_value = workflow
mock_db.session.scalar.return_value = workflow
# 2. Run test
result = rag_pipeline_service.get_published_workflow(pipeline)
@@ -1319,11 +1275,8 @@ def test_get_rag_pipeline_workflow_run_node_executions_returns_sorted_executions
def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = []
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query
mock_db.session.scalars.return_value.all.return_value = []
result = rag_pipeline_service.get_recommended_plugins("all")
@@ -1336,11 +1289,8 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra
def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_pipeline_service) -> None:
plugin_a = SimpleNamespace(plugin_id="plugin-a")
plugin_b = SimpleNamespace(plugin_id="plugin-b")
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = [plugin_a, plugin_b]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query
mock_db.session.scalars.return_value.all.return_value = [plugin_a, plugin_b]
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
mocker.patch(
"services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools",
@@ -1568,9 +1518,7 @@ def test_get_second_step_parameters_filters_first_step_variables(mocker, rag_pip
def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Document pipeline execution log not found"):
rag_pipeline_service.retry_error_document(
@@ -1581,9 +1529,7 @@ def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pi
def test_get_datasource_plugins_raises_when_workflow_not_found(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
with pytest.raises(ValueError, match="Pipeline or workflow not found"):
@@ -1656,8 +1602,7 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker
document = SimpleNamespace(indexing_status="waiting", error=None)
query = mocker.Mock()
query.where.return_value.first.return_value = document
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document)
add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@@ -1712,9 +1657,7 @@ def test_run_datasource_node_preview_raises_for_unsupported_provider(mocker, rag
def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
@@ -1722,9 +1665,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker
def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id=None)
query = mocker.Mock()
query.where.return_value.first.return_value = pipeline
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
with pytest.raises(ValueError, match="Pipeline workflow not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {"name": "template-name"})
@@ -1732,8 +1673,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc
def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Dataset not found"):
rag_pipeline_service.get_pipeline("t1", "d1")
@@ -1742,8 +1682,7 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service)
def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, None]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None])
with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.get_pipeline("t1", "d1")
@@ -1783,8 +1722,7 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None:
def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
query = mocker.Mock()
query.where.return_value.first.return_value = template
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
@@ -2011,8 +1949,7 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_
def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [pipeline, None]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None])
with pytest.raises(ValueError, match="Workflow not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
@@ -2021,11 +1958,9 @@ def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocke
def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
workflow = SimpleNamespace(id="wf-1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [pipeline, workflow]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.engine = mocker.Mock()
mock_db.session.query.return_value = query
mock_db.session.get.side_effect = [pipeline, workflow]
session_ctx = mocker.MagicMock()
session_ctx.__enter__.return_value = SimpleNamespace()
session_ctx.__exit__.return_value = False
@@ -2038,11 +1973,8 @@ def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker
def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipeline_service) -> None:
plugin = SimpleNamespace(plugin_id="plugin-a")
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = [plugin]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query
mock_db.session.scalars.return_value.all.return_value = [plugin]
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
mocker.patch("services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", return_value=[])
mocker.patch("services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids", return_value=[])
@@ -2056,8 +1988,8 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin
def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
exec_log = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [exec_log, None]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.retry_error_document(
@@ -2069,8 +2001,8 @@ def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_
exec_log = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [exec_log, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
with pytest.raises(ValueError, match="Workflow not found"):
@@ -2086,8 +2018,7 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r
graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[]
)
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
assert rag_pipeline_service.get_datasource_plugins("t1", "d1", True) == []
@@ -2250,8 +2181,7 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(
rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}],
)
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow)
mocker.patch(
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials", return_value=[]
@@ -2291,8 +2221,7 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
],
)
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
mocker.patch(
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials",
@@ -2310,8 +2239,7 @@ def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service)
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
result = rag_pipeline_service.get_pipeline("t1", "d1")

View File

@@ -173,9 +173,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock
query_results = {("Account", "email", "test@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
mock_password_dependencies["compare_password"].return_value = True
@@ -188,9 +186,7 @@ class TestAccountService:
def test_authenticate_account_not_found(self, mock_db_dependencies):
"""Test authentication when account does not exist."""
# Setup smart database query mock - no matching results
query_results = {("Account", "email", "notfound@example.com"): None}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = None
# Execute test and verify exception
self._assert_exception_raised(
@@ -202,9 +198,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
# Setup smart database query mock
query_results = {("Account", "email", "banned@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
# Execute test and verify exception
self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password")
@@ -214,9 +208,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock
query_results = {("Account", "email", "test@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
mock_password_dependencies["compare_password"].return_value = False
@@ -230,9 +222,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending")
# Setup smart database query mock
query_results = {("Account", "email", "pending@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
mock_password_dependencies["compare_password"].return_value = True
@@ -422,12 +412,8 @@ class TestAccountService:
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock()
# Setup smart database query mock
query_results = {
("Account", "id", "user-123"): mock_account,
("TenantAccountJoin", "account_id", "user-123"): mock_tenant_join,
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant_join
# Mock datetime
with patch("services.account_service.datetime") as mock_datetime:
@@ -444,9 +430,7 @@ class TestAccountService:
def test_load_user_not_found(self, mock_db_dependencies):
"""Test user loading when user does not exist."""
# Setup smart database query mock - no matching results
query_results = {("Account", "id", "non-existent-user"): None}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = None
# Execute test
result = AccountService.load_user("non-existent-user")
@@ -459,9 +443,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
# Setup smart database query mock
query_results = {("Account", "id", "user-123"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
# Execute test and verify exception
self._assert_exception_raised(
@@ -476,13 +458,9 @@ class TestAccountService:
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False)
# Setup smart database query mock for complex scenario
query_results = {
("Account", "id", "user-123"): mock_account,
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
("TenantAccountJoin", "order_by", "first_available"): mock_available_tenant, # First available tenant
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
# First scalar: current tenant (None), second scalar: available tenant
mock_db_dependencies["db"].session.scalar.side_effect = [None, mock_available_tenant]
# Mock datetime
with patch("services.account_service.datetime") as mock_datetime:
@@ -503,13 +481,9 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock for no tenants scenario
query_results = {
("Account", "id", "user-123"): mock_account,
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
("TenantAccountJoin", "order_by", "first_available"): None, # No available tenants
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
# First scalar: current tenant (None), second scalar: available tenant (None)
mock_db_dependencies["db"].session.scalar.side_effect = [None, None]
# Mock datetime
with patch("services.account_service.datetime") as mock_datetime:
@@ -1060,7 +1034,7 @@ class TestRegisterService:
)
# Verify rollback operations were called
mock_db_dependencies["db"].session.query.assert_called()
mock_db_dependencies["db"].session.execute.assert_called()
# ==================== Registration Tests ====================
@@ -1625,10 +1599,8 @@ class TestRegisterService:
mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = mock_existing_account
# Mock the db.session.query for TenantAccountJoin
mock_db_query = MagicMock()
mock_db_query.filter_by.return_value.first.return_value = None # No existing member
mock_db_dependencies["db"].session.query.return_value = mock_db_query
# Mock scalar for TenantAccountJoin lookup - no existing member
mock_db_dependencies["db"].session.scalar.return_value = None
# Mock TenantService methods
with (
@@ -1803,14 +1775,9 @@ class TestRegisterService:
}
mock_get_invitation_by_token.return_value = invitation_data
# Mock database queries - complex query mocking
mock_query1 = MagicMock()
mock_query1.where.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock()
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
# Mock scalar for tenant lookup, execute for account+role lookup
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal")
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
@@ -1842,10 +1809,8 @@ class TestRegisterService:
}
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
# Mock database queries - no tenant found
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = None
mock_db_dependencies["db"].session.query.return_value = mock_query
# Mock scalar for tenant lookup - not found
mock_db_dependencies["db"].session.scalar.return_value = None
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
@@ -1868,14 +1833,9 @@ class TestRegisterService:
}
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
# Mock database queries
mock_query1 = MagicMock()
mock_query1.filter.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock()
mock_query2.join.return_value.where.return_value.first.return_value = None # No account found
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
# Mock scalar for tenant, execute for account+role
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
mock_db_dependencies["db"].session.execute.return_value.first.return_value = None # No account found
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
@@ -1901,14 +1861,9 @@ class TestRegisterService:
}
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
# Mock database queries
mock_query1 = MagicMock()
mock_query1.filter.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock()
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
# Mock scalar for tenant, execute for account+role
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal")
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")

View File

@@ -799,10 +799,7 @@ class TestExternalDatasetServiceGetAPI:
api_id = "api-123"
expected_api = factory.create_external_knowledge_api_mock(api_id=api_id)
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = expected_api
mock_db.session.scalar.return_value = expected_api
# Act
tenant_id = "tenant-123"
@@ -810,16 +807,12 @@ class TestExternalDatasetServiceGetAPI:
# Assert
assert result.id == api_id
mock_query.filter_by.assert_called_once_with(id=api_id, tenant_id=tenant_id)
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_api_not_found(self, mock_db, factory):
"""Test error when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@@ -848,10 +841,7 @@ class TestExternalDatasetServiceUpdateAPI:
"settings": {"endpoint": "https://new.example.com", "api_key": "new-key"},
}
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args)
@@ -881,10 +871,7 @@ class TestExternalDatasetServiceUpdateAPI:
"settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
}
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, "user-123", api_id, args)
@@ -897,10 +884,7 @@ class TestExternalDatasetServiceUpdateAPI:
def test_update_external_knowledge_api_not_found(self, mock_db, factory):
"""Test error when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
args = {"name": "Updated API"}
@@ -912,10 +896,7 @@ class TestExternalDatasetServiceUpdateAPI:
def test_update_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
"""Test error when tenant ID doesn't match."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
args = {"name": "Updated API"}
@@ -934,10 +915,7 @@ class TestExternalDatasetServiceUpdateAPI:
args = {"name": "New Name Only"}
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
result = ExternalDatasetService.update_external_knowledge_api("tenant-123", "user-123", "api-123", args)
@@ -958,10 +936,7 @@ class TestExternalDatasetServiceDeleteAPI:
existing_api = factory.create_external_knowledge_api_mock(api_id=api_id, tenant_id=tenant_id)
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
ExternalDatasetService.delete_external_knowledge_api(tenant_id, api_id)
@@ -974,10 +949,7 @@ class TestExternalDatasetServiceDeleteAPI:
def test_delete_external_knowledge_api_not_found(self, mock_db, factory):
"""Test error when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@@ -987,10 +959,7 @@ class TestExternalDatasetServiceDeleteAPI:
def test_delete_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
"""Test error when tenant ID doesn't match."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@@ -1006,10 +975,7 @@ class TestExternalDatasetServiceAPIUseCheck:
# Arrange
api_id = "api-123"
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 1
mock_db.session.scalar.return_value = 1
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
@@ -1024,10 +990,7 @@ class TestExternalDatasetServiceAPIUseCheck:
# Arrange
api_id = "api-123"
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 10
mock_db.session.scalar.return_value = 10
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
@@ -1042,10 +1005,7 @@ class TestExternalDatasetServiceAPIUseCheck:
# Arrange
api_id = "api-123"
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 0
mock_db.session.scalar.return_value = 0
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
@@ -1067,10 +1027,7 @@ class TestExternalDatasetServiceGetBinding:
expected_binding = factory.create_external_knowledge_binding_mock(tenant_id=tenant_id, dataset_id=dataset_id)
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = expected_binding
mock_db.session.scalar.return_value = expected_binding
# Act
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id(tenant_id, dataset_id)
@@ -1083,10 +1040,7 @@ class TestExternalDatasetServiceGetBinding:
def test_get_external_knowledge_binding_not_found(self, mock_db, factory):
"""Test error when binding is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="external knowledge binding not found"):
@@ -1113,10 +1067,7 @@ class TestExternalDatasetServiceDocumentValidate:
api = factory.create_external_knowledge_api_mock(api_id=api_id, settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
process_parameter = {"param1": "value1", "param2": "value2"}
@@ -1134,10 +1085,7 @@ class TestExternalDatasetServiceDocumentValidate:
api = factory.create_external_knowledge_api_mock(api_id=api_id, settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
process_parameter = {}
@@ -1149,10 +1097,7 @@ class TestExternalDatasetServiceDocumentValidate:
def test_document_create_args_validate_api_not_found(self, mock_db, factory):
"""Test validation fails when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@@ -1165,10 +1110,7 @@ class TestExternalDatasetServiceDocumentValidate:
settings = {}
api = factory.create_external_knowledge_api_mock(settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
# Act & Assert - should not raise
ExternalDatasetService.document_create_args_validate("tenant-123", "api-123", {})
@@ -1186,10 +1128,7 @@ class TestExternalDatasetServiceDocumentValidate:
api = factory.create_external_knowledge_api_mock(settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
process_parameter = {"required_param": "value"}
@@ -1498,24 +1437,7 @@ class TestExternalDatasetServiceCreateDataset:
api = factory.create_external_knowledge_api_mock(api_id="api-123")
# Mock database queries
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [None, api]
# Act
result = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)
@@ -1534,10 +1456,7 @@ class TestExternalDatasetServiceCreateDataset:
# Arrange
existing_dataset = factory.create_dataset_mock(name="Duplicate Dataset")
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_dataset
mock_db.session.scalar.return_value = existing_dataset
args = {"name": "Duplicate Dataset"}
@@ -1549,23 +1468,7 @@ class TestExternalDatasetServiceCreateDataset:
def test_create_external_dataset_api_not_found_error(self, mock_db, factory):
"""Test error when external knowledge API is not found."""
# Arrange
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = None
mock_db.session.scalar.side_effect = [None, None]
args = {"name": "Test Dataset", "external_knowledge_api_id": "nonexistent-api"}
@@ -1579,23 +1482,7 @@ class TestExternalDatasetServiceCreateDataset:
# Arrange
api = factory.create_external_knowledge_api_mock()
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [None, api]
args = {"name": "Test Dataset", "external_knowledge_api_id": "api-123"}
@@ -1609,23 +1496,7 @@ class TestExternalDatasetServiceCreateDataset:
# Arrange
api = factory.create_external_knowledge_api_mock()
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [None, api]
args = {"name": "Test Dataset", "external_knowledge_id": "knowledge-123"}
@@ -1651,23 +1522,7 @@ class TestExternalDatasetServiceFetchRetrieval:
)
api = factory.create_external_knowledge_api_mock(api_id="api-123")
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 200
@@ -1695,10 +1550,7 @@ class TestExternalDatasetServiceFetchRetrieval:
def test_fetch_external_knowledge_retrieval_binding_not_found_error(self, mock_db, factory):
"""Test error when external knowledge binding is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="external knowledge binding not found"):
@@ -1712,23 +1564,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 200
@@ -1751,23 +1587,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 200
@@ -1799,23 +1619,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 500
@@ -1856,23 +1660,7 @@ class TestExternalDatasetServiceFetchRetrieval:
)
api = factory.create_external_knowledge_api_mock(api_id="api-123")
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = status_code
@@ -1891,23 +1679,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 503

View File

@@ -5,7 +5,6 @@
"prepare": "vp config"
},
"devDependencies": {
"taze": "catalog:",
"vite-plus": "catalog:"
},
"engines": {

3079
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,9 @@
catalogMode: prefer
trustPolicy: no-downgrade
minimumReleaseAge: 2880
trustPolicyExclude:
- chokidar@4.0.3
- reselect@5.1.1
- semver@6.3.1
blockExoticSubdeps: true
strictDepBuilds: true
allowBuilds:
@@ -23,7 +27,7 @@ overrides:
array.prototype.flatmap: npm:@nolyfill/array.prototype.flatmap@^1.0.44
array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1.0.44
assert: npm:@nolyfill/assert@^1.0.26
brace-expansion@<2.0.2: 2.0.2
brace-expansion@>=2.0.0 <2.0.3: 2.0.3
canvas: ^3.2.2
devalue@<5.3.2: 5.3.2
dompurify@>=3.1.3 <=3.3.1: 3.3.2
@@ -37,6 +41,8 @@ overrides:
is-generator-function: npm:@nolyfill/is-generator-function@^1.0.44
is-typed-array: npm:@nolyfill/is-typed-array@^1.0.44
isarray: npm:@nolyfill/isarray@^1.0.44
lodash@>=4.0.0 <= 4.17.23: 4.18.0
lodash-es@>=4.0.0 <= 4.17.23: 4.18.0
object.assign: npm:@nolyfill/object.assign@^1.0.44
object.entries: npm:@nolyfill/object.entries@^1.0.44
object.fromentries: npm:@nolyfill/object.fromentries@^1.0.44
@@ -64,15 +70,15 @@ overrides:
tar@<=7.5.10: 7.5.11
typed-array-buffer: npm:@nolyfill/typed-array-buffer@^1.0.44
undici@>=7.0.0 <7.24.0: 7.24.0
vite: npm:@voidzero-dev/vite-plus-core@0.1.14
vitest: npm:@voidzero-dev/vite-plus-test@0.1.14
vite: npm:@voidzero-dev/vite-plus-core@0.1.15
vitest: npm:@voidzero-dev/vite-plus-test@0.1.15
which-typed-array: npm:@nolyfill/which-typed-array@^1.0.44
yaml@>=2.0.0 <2.8.3: 2.8.3
yauzl@<3.2.1: 3.2.1
catalog:
"@amplitude/analytics-browser": 2.38.0
"@amplitude/plugin-session-replay-browser": 1.27.5
"@antfu/eslint-config": 7.7.3
"@amplitude/analytics-browser": 2.38.1
"@amplitude/plugin-session-replay-browser": 1.27.6
"@antfu/eslint-config": 8.0.0
"@base-ui/react": 1.3.0
"@chromatic-com/storybook": 5.1.1
"@cucumber/cucumber": 12.7.0
@@ -84,7 +90,7 @@ catalog:
"@formatjs/intl-localematcher": 0.8.2
"@headlessui/react": 2.2.9
"@heroicons/react": 2.2.0
"@hono/node-server": 1.19.11
"@hono/node-server": 1.19.12
"@iconify-json/heroicons": 1.2.3
"@iconify-json/ri": 1.2.10
"@lexical/code": 0.42.0
@@ -98,34 +104,35 @@ catalog:
"@mdx-js/react": 3.1.1
"@mdx-js/rollup": 3.1.1
"@monaco-editor/react": 4.7.0
"@next/eslint-plugin-next": 16.2.1
"@next/mdx": 16.2.1
"@next/eslint-plugin-next": 16.2.2
"@next/mdx": 16.2.2
"@orpc/client": 1.13.13
"@orpc/contract": 1.13.13
"@orpc/openapi-client": 1.13.13
"@orpc/tanstack-query": 1.13.13
"@playwright/test": 1.58.2
"@playwright/test": 1.59.1
"@remixicon/react": 4.9.0
"@rgrove/parse-xml": 4.2.0
"@sentry/react": 10.46.0
"@storybook/addon-docs": 10.3.3
"@storybook/addon-links": 10.3.3
"@storybook/addon-onboarding": 10.3.3
"@storybook/addon-themes": 10.3.3
"@storybook/nextjs-vite": 10.3.3
"@storybook/react": 10.3.3
"@sentry/react": 10.47.0
"@storybook/addon-docs": 10.3.4
"@storybook/addon-links": 10.3.4
"@storybook/addon-onboarding": 10.3.4
"@storybook/addon-themes": 10.3.4
"@storybook/nextjs-vite": 10.3.4
"@storybook/react": 10.3.4
"@streamdown/math": 1.0.2
"@svgdotjs/svg.js": 3.2.5
"@t3-oss/env-nextjs": 0.13.11
"@tailwindcss/postcss": 4.2.2
"@tailwindcss/typography": 0.5.19
"@tailwindcss/vite": 4.2.2
"@tanstack/eslint-plugin-query": 5.95.2
"@tanstack/react-devtools": 0.10.0
"@tanstack/react-form": 1.28.5
"@tanstack/react-form-devtools": 0.2.19
"@tanstack/react-query": 5.95.2
"@tanstack/react-query-devtools": 5.95.2
"@tanstack/eslint-plugin-query": 5.96.1
"@tanstack/react-devtools": 0.10.1
"@tanstack/react-form": 1.28.6
"@tanstack/react-form-devtools": 0.2.20
"@tanstack/react-query": 5.96.1
"@tanstack/react-query-devtools": 5.96.1
"@tanstack/react-virtual": 3.13.23
"@testing-library/dom": 10.4.1
"@testing-library/jest-dom": 6.9.1
"@testing-library/react": 16.3.2
@@ -141,15 +148,13 @@ catalog:
"@types/qs": 6.15.0
"@types/react": 19.2.14
"@types/react-dom": 19.2.3
"@types/react-syntax-highlighter": 15.5.13
"@types/react-window": 1.8.8
"@types/sortablejs": 1.15.9
"@typescript-eslint/eslint-plugin": 8.57.2
"@typescript-eslint/parser": 8.57.2
"@typescript/native-preview": 7.0.0-dev.20260329.1
"@typescript-eslint/eslint-plugin": 8.58.0
"@typescript-eslint/parser": 8.58.0
"@typescript/native-preview": 7.0.0-dev.20260401.1
"@vitejs/plugin-react": 6.0.1
"@vitejs/plugin-rsc": 0.5.21
"@vitest/coverage-v8": 4.1.1
"@vitest/coverage-v8": 4.1.2
abcjs: 6.6.2
agentation: 3.0.2
ahooks: 3.9.7
@@ -157,7 +162,7 @@ catalog:
class-variance-authority: 0.7.1
clsx: 2.1.1
cmdk: 1.1.1
code-inspector-plugin: 1.4.5
code-inspector-plugin: 1.5.1
copy-to-clipboard: 3.3.3
cron-parser: 5.5.0
dayjs: 1.11.20
@@ -174,19 +179,19 @@ catalog:
eslint-markdown: 0.6.0
eslint-plugin-better-tailwindcss: 4.3.2
eslint-plugin-hyoban: 0.14.1
eslint-plugin-markdown-preferences: 0.40.3
eslint-plugin-markdown-preferences: 0.41.0
eslint-plugin-no-barrel-files: 1.2.2
eslint-plugin-react-hooks: 7.0.1
eslint-plugin-react-refresh: 0.5.2
eslint-plugin-sonarjs: 4.0.2
eslint-plugin-storybook: 10.3.3
eslint-plugin-storybook: 10.3.4
fast-deep-equal: 3.1.3
foxact: 0.3.0
happy-dom: 20.8.9
hono: 4.12.9
hast-util-to-jsx-runtime: 2.3.6
hono: 4.12.10
html-entities: 2.6.0
html-to-image: 1.11.13
i18next: 25.10.10
i18next: 26.0.3
i18next-resources-to-backend: 1.2.1
iconify-import-svg: 0.1.2
immer: 11.1.4
@@ -196,15 +201,15 @@ catalog:
js-yaml: 4.1.1
jsonschema: 1.5.0
katex: 0.16.44
knip: 6.1.0
knip: 6.2.0
ky: 1.14.3
lamejs: 1.2.1
lexical: 0.42.0
mermaid: 11.13.0
mermaid: 11.14.0
mime: 4.1.0
mitt: 3.0.1
negotiator: 1.0.0
next: 16.2.1
next: 16.2.2
next-themes: 0.4.6
nuqs: 2.8.9
pinyin-pro: 3.28.0
@@ -217,42 +222,39 @@ catalog:
react-dom: 19.2.4
react-easy-crop: 5.5.7
react-hotkeys-hook: 5.2.4
react-i18next: 16.6.6
react-i18next: 17.0.2
react-multi-email: 1.0.25
react-papaparse: 4.4.0
react-pdf-highlighter: 8.0.0-rc.0
react-server-dom-webpack: 19.2.4
react-sortablejs: 6.1.4
react-syntax-highlighter: 15.6.6
react-textarea-autosize: 8.5.9
react-window: 1.8.11
reactflow: 11.11.4
remark-breaks: 4.0.0
remark-directive: 4.0.0
sass: 1.98.0
scheduler: 0.27.0
sharp: 0.34.5
shiki: 4.0.2
sortablejs: 1.15.7
std-semver: 1.0.8
storybook: 10.3.3
storybook: 10.3.4
streamdown: 2.5.0
string-ts: 2.3.1
tailwind-merge: 3.5.0
tailwindcss: 4.2.2
taze: 19.10.0
tldts: 7.0.27
tsup: ^8.5.1
tsdown: 0.21.7
tsx: 4.21.0
typescript: 5.9.3
typescript: 6.0.2
uglify-js: 3.19.3
unist-util-visit: 5.1.0
use-context-selector: 2.0.0
uuid: 13.0.0
vinext: 0.0.38
vite: npm:@voidzero-dev/vite-plus-core@0.1.14
vinext: 0.0.39
vite: npm:@voidzero-dev/vite-plus-core@0.1.15
vite-plugin-inspect: 12.0.0-beta.1
vite-plus: 0.1.14
vitest: npm:@voidzero-dev/vite-plus-test@0.1.14
vite-plus: 0.1.15
vitest: npm:@voidzero-dev/vite-plus-test@0.1.15
vitest-canvas-mock: 1.1.4
zod: 4.3.6
zundo: 2.3.0

View File

@@ -45,12 +45,12 @@
"homepage": "https://dify.ai",
"license": "MIT",
"scripts": {
"build": "tsup",
"build": "vp pack",
"lint": "eslint",
"lint:fix": "eslint --fix",
"type-check": "tsc -p tsconfig.json --noEmit",
"test": "vitest run",
"test:coverage": "vitest run --coverage",
"test": "vp test",
"test:coverage": "vp test --coverage",
"publish:check": "./scripts/publish.sh --dry-run",
"publish:npm": "./scripts/publish.sh"
},
@@ -61,8 +61,8 @@
"@typescript-eslint/parser": "catalog:",
"@vitest/coverage-v8": "catalog:",
"eslint": "catalog:",
"tsup": "catalog:",
"typescript": "catalog:",
"vite-plus": "catalog:",
"vitest": "catalog:"
}
}

View File

@@ -11,7 +11,8 @@
"strict": true,
"esModuleInterop": true,
"forceConsistentCasingInFileNames": true,
"skipLibCheck": true
"skipLibCheck": true,
"types": ["node"]
},
"include": ["src/**/*.ts", "tests/**/*.ts"]
}

View File

@@ -1,12 +0,0 @@
import { defineConfig } from "tsup";
export default defineConfig({
entry: ["src/index.ts"],
format: ["esm"],
dts: true,
clean: true,
sourcemap: true,
splitting: false,
treeshake: true,
outDir: "dist",
});

View File

@@ -1,6 +1,17 @@
import { defineConfig } from "vitest/config";
import { defineConfig } from "vite-plus";
export default defineConfig({
pack: {
entry: ["src/index.ts"],
format: ["esm"],
dts: true,
clean: true,
sourcemap: true,
// splitting: false,
treeshake: true,
outDir: "dist",
target: false,
},
test: {
environment: "node",
include: ["**/*.test.ts"],

View File

@@ -1,15 +0,0 @@
import { defineConfig } from 'taze'
export default defineConfig({
exclude: [
// We are going to replace these
'react-syntax-highlighter',
'react-window',
'@types/react-window',
// We can not upgrade these yet
'typescript',
],
maturityPeriod: 2,
})

View File

@@ -0,0 +1,36 @@
import { vi } from 'vitest'
const mockVirtualizer = ({
count,
estimateSize,
}: {
count: number
estimateSize?: (index: number) => number
}) => {
const getSize = (index: number) => estimateSize?.(index) ?? 0
return {
getTotalSize: () => Array.from({ length: count }).reduce<number>((total, _, index) => total + getSize(index), 0),
getVirtualItems: () => {
let start = 0
return Array.from({ length: count }).map((_, index) => {
const size = getSize(index)
const virtualItem = {
end: start + size,
index,
key: index,
size,
start,
}
start += size
return virtualItem
})
},
measureElement: vi.fn(),
scrollToIndex: vi.fn(),
}
}
export { mockVirtualizer as useVirtualizer }

View File

@@ -1,10 +1,10 @@
import { render, screen } from '@testing-library/react'
import * as React from 'react'
import { useContext } from 'react'
import { use } from 'react'
import { FeaturesContext, FeaturesProvider } from '../context'
const TestConsumer = () => {
const store = useContext(FeaturesContext)
const store = use(FeaturesContext)
if (!store)
return <div>no store</div>
@@ -34,10 +34,10 @@ describe('FeaturesProvider', () => {
})
it('should maintain the same store reference across re-renders', () => {
const storeRefs: Array<ReturnType<typeof useContext>> = []
const storeRefs: Array<React.ContextType<typeof FeaturesContext>> = []
const StoreRefCollector = () => {
const store = useContext(FeaturesContext)
const store = use(FeaturesContext)
storeRefs.push(store)
return null
}

View File

@@ -5,6 +5,10 @@ import { Theme } from '@/types/app'
import CodeBlock from '../code-block'
const { mockHighlightCode } = vi.hoisted(() => ({
mockHighlightCode: vi.fn(),
}))
type UseThemeReturn = {
theme: Theme
}
@@ -70,6 +74,10 @@ vi.mock('@/hooks/use-theme', () => ({
default: () => mockUseTheme(),
}))
vi.mock('../shiki-highlight', () => ({
highlightCode: mockHighlightCode,
}))
vi.mock('echarts', () => ({
getInstanceByDom: mockEcharts.getInstanceByDom,
}))
@@ -130,6 +138,11 @@ describe('CodeBlock', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseTheme.mockReturnValue({ theme: Theme.light })
mockHighlightCode.mockImplementation(async ({ code, language }) => (
<pre className="shiki">
<code className={`language-${language}`}>{code}</code>
</pre>
))
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
clientWidthSpy = vi.spyOn(HTMLElement.prototype, 'clientWidth', 'get').mockReturnValue(900)
@@ -198,11 +211,13 @@ describe('CodeBlock', () => {
expect(container.querySelector('code')?.textContent).toBe('plain text')
})
it('should render syntax-highlighted output when language is standard', () => {
it('should render syntax-highlighted output when language is standard', async () => {
render(<CodeBlock className="language-javascript">const x = 1;</CodeBlock>)
expect(screen.getByText('JavaScript')).toBeInTheDocument()
expect(document.querySelector('code.language-javascript')?.textContent).toContain('const x = 1;')
await waitFor(() => {
expect(document.querySelector('code.language-javascript')?.textContent).toContain('const x = 1;')
})
})
it('should format unknown language labels with capitalized fallback when language is not in map', () => {
@@ -242,13 +257,26 @@ describe('CodeBlock', () => {
expect(screen.queryByText(/Error rendering SVG/i)).not.toBeInTheDocument()
})
it('should render syntax-highlighted output when language is standard and app theme is dark', () => {
it('should render syntax-highlighted output when language is standard and app theme is dark', async () => {
mockUseTheme.mockReturnValue({ theme: Theme.dark })
render(<CodeBlock className="language-javascript">const y = 2;</CodeBlock>)
expect(screen.getByText('JavaScript')).toBeInTheDocument()
expect(document.querySelector('code.language-javascript')?.textContent).toContain('const y = 2;')
await waitFor(() => {
expect(document.querySelector('code.language-javascript')?.textContent).toContain('const y = 2;')
})
})
it('should fall back to plain code block when shiki highlighting fails', async () => {
mockHighlightCode.mockRejectedValueOnce(new Error('highlight failed'))
render(<CodeBlock className="language-javascript">const z = 3;</CodeBlock>)
await waitFor(() => {
expect(screen.getByText('const z = 3;')).toBeInTheDocument()
})
expect(document.querySelector('code.language-javascript')).toBeNull()
})
})

View File

@@ -1,10 +1,7 @@
import type { JSX } from 'react'
import type { BundledLanguage, BundledTheme } from 'shiki/bundle/web'
import ReactEcharts from 'echarts-for-react'
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'
import SyntaxHighlighter from 'react-syntax-highlighter'
import {
atelierHeathDark,
atelierHeathLight,
} from 'react-syntax-highlighter/dist/esm/styles/hljs'
import { memo, useCallback, useEffect, useLayoutEffect, useMemo, useRef, useState } from 'react'
import ActionButton from '@/app/components/base/action-button'
import CopyIcon from '@/app/components/base/copy-icon'
import MarkdownMusic from '@/app/components/base/markdown-blocks/music'
@@ -14,10 +11,10 @@ import useTheme from '@/hooks/use-theme'
import dynamic from '@/next/dynamic'
import { Theme } from '@/types/app'
import SVGRenderer from '../svg-gallery' // Assumes svg-gallery.tsx is in /base directory
import { highlightCode } from './shiki-highlight'
const Flowchart = dynamic(() => import('@/app/components/base/mermaid'), { ssr: false })
// Available language https://github.com/react-syntax-highlighter/react-syntax-highlighter/blob/master/AVAILABLE_LANGUAGES_HLJS.MD
const capitalizationLanguageNameMap: Record<string, string> = {
sql: 'SQL',
javascript: 'JavaScript',
@@ -64,6 +61,61 @@ const getCorrectCapitalizationLanguageName = (language: string) => {
// visit https://reactjs.org/docs/error-decoder.html?invariant=185 for the full message
// or use the non-minified dev environment for full errors and additional helpful warnings.
const ShikiCodeBlock = memo(({ code, language, theme, initial }: { code: string, language: string, theme: BundledTheme, initial?: JSX.Element }) => {
const [nodes, setNodes] = useState(initial)
useLayoutEffect(() => {
let cancelled = false
void highlightCode({
code,
language: language as BundledLanguage,
theme,
}).then((result) => {
if (!cancelled)
setNodes(result)
}).catch((error) => {
console.error('Shiki highlighting failed:', error)
if (!cancelled)
setNodes(undefined)
})
return () => {
cancelled = true
}
}, [code, language, theme])
if (!nodes) {
return (
<pre style={{
paddingLeft: 12,
borderBottomLeftRadius: '10px',
borderBottomRightRadius: '10px',
backgroundColor: 'var(--color-components-input-bg-normal)',
margin: 0,
overflow: 'auto',
}}
>
<code>{code}</code>
</pre>
)
}
return (
<div
style={{
borderBottomLeftRadius: '10px',
borderBottomRightRadius: '10px',
overflow: 'auto',
}}
className="shiki-line-numbers [&_pre]:m-0! [&_pre]:rounded-t-none! [&_pre]:rounded-b-[10px]! [&_pre]:bg-components-input-bg-normal! [&_pre]:py-2!"
>
{nodes}
</div>
)
})
ShikiCodeBlock.displayName = 'ShikiCodeBlock'
// Define ECharts event parameter types
type EChartsEventParams = {
type: string
@@ -416,20 +468,11 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any
)
default:
return (
<SyntaxHighlighter
{...props}
style={theme === Theme.light ? atelierHeathLight : atelierHeathDark}
customStyle={{
paddingLeft: 12,
borderBottomLeftRadius: '10px',
borderBottomRightRadius: '10px',
backgroundColor: 'var(--color-components-input-bg-normal)',
}}
language={match?.[1]}
showLineNumbers
>
{content}
</SyntaxHighlighter>
<ShikiCodeBlock
code={content}
language={match?.[1] || 'text'}
theme={isDarkMode ? 'github-dark' : 'github-light'}
/>
)
}
}, [children, language, isSVG, finalChartOption, props, theme, match, chartState, isDarkMode, echartsStyle, echartsOpts, handleChartReady, echartsEvents])
@@ -440,7 +483,7 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any
return (
<div className="relative">
<div className="flex h-8 items-center justify-between rounded-t-[10px] border-b border-divider-subtle bg-components-input-bg-normal p-1 pl-3">
<div className="text-text-secondary system-xs-semibold-uppercase">{languageShowName}</div>
<div className="system-xs-semibold-uppercase text-text-secondary">{languageShowName}</div>
<div className="flex items-center gap-1">
{language === 'svg' && <SVGBtn isSVG={isSVG} setIsSVG={setIsSVG} />}
<ActionButton>

View File

@@ -0,0 +1,29 @@
import type { JSX } from 'react'
import type { BundledLanguage, BundledTheme } from 'shiki/bundle/web'
import { toJsxRuntime } from 'hast-util-to-jsx-runtime'
import { Fragment } from 'react'
import { jsx, jsxs } from 'react/jsx-runtime'
import { codeToHast } from 'shiki/bundle/web'
type HighlightCodeOptions = {
code: string
language: BundledLanguage
theme: BundledTheme
}
export const highlightCode = async ({
code,
language,
theme,
}: HighlightCodeOptions): Promise<JSX.Element> => {
const hast = await codeToHast(code, {
lang: language,
theme,
})
return toJsxRuntime(hast, {
Fragment,
jsx,
jsxs,
}) as JSX.Element
}

View File

@@ -9,6 +9,8 @@ import { useModalContextSelector } from '@/context/modal-context'
import { useInvalidPreImportNotionPages, usePreImportNotionPages } from '@/service/knowledge/use-import'
import NotionPageSelector from '../base'
vi.mock('@tanstack/react-virtual')
vi.mock('@/service/knowledge/use-import', () => ({
usePreImportNotionPages: vi.fn(),
useInvalidPreImportNotionPages: vi.fn(),
@@ -183,7 +185,7 @@ describe('NotionPageSelector Base', () => {
const user = userEvent.setup()
render(<NotionPageSelector credentialList={mockCredentialList} onSelect={vi.fn()} />)
await user.click(screen.getByRole('button', { name: 'Configure Notion' }))
await user.click(screen.getByRole('button', { name: 'common.dataSource.notion.selector.configure' }))
expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ payload: ACCOUNT_SETTING_TAB.DATA_SOURCE })
})

View File

@@ -2,6 +2,7 @@ import type { DataSourceCredential } from '../../header/account-setting/data-sou
import type { NotionCredential } from './credential-selector'
import type { DataSourceNotionPageMap, DataSourceNotionWorkspace, NotionPage } from '@/models/common'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
import { useModalContextSelector } from '@/context/modal-context'
import { useInvalidPreImportNotionPages, usePreImportNotionPages } from '@/service/knowledge/use-import'
@@ -33,6 +34,7 @@ const NotionPageSelector = ({
credentialList,
onSelectCredential,
}: NotionPageSelectorProps) => {
const { t } = useTranslation()
const [searchValue, setSearchValue] = useState('')
const setShowAccountSettingModal = useModalContextSelector(s => s.setShowAccountSettingModal)
@@ -48,27 +50,34 @@ const NotionPageSelector = ({
}
})
}, [credentialList])
const [currentCredential, setCurrentCredential] = useState(notionCredentials[0])
const [selectedCredentialId, setSelectedCredentialId] = useState(() => notionCredentials[0]?.credentialId ?? '')
const currentCredential = useMemo(() => {
return notionCredentials.find(item => item.credentialId === selectedCredentialId) ?? notionCredentials[0] ?? null
}, [notionCredentials, selectedCredentialId])
const currentCredentialId = currentCredential?.credentialId ?? ''
useEffect(() => {
const credential = notionCredentials.find(item => item.credentialId === currentCredential?.credentialId)
if (!credential) {
const firstCredential = notionCredentials[0]
invalidPreImportNotionPages({ datasetId, credentialId: firstCredential.credentialId })
setCurrentCredential(notionCredentials[0])
onSelect([]) // Clear selected pages when changing credential
onSelectCredential?.(firstCredential.credentialId)
onSelectCredential?.(currentCredentialId)
}, [currentCredentialId, onSelectCredential])
useEffect(() => {
if (!notionCredentials.length) {
onSelect([])
return
}
else {
onSelectCredential?.(credential?.credentialId || '')
}
}, [notionCredentials])
if (!selectedCredentialId || selectedCredentialId === currentCredentialId)
return
invalidPreImportNotionPages({ datasetId, credentialId: currentCredentialId })
onSelect([])
}, [currentCredentialId, datasetId, invalidPreImportNotionPages, notionCredentials.length, onSelect, selectedCredentialId])
const {
data: notionsPages,
isFetching: isFetchingNotionPages,
isError: isFetchingNotionPagesError,
} = usePreImportNotionPages({ datasetId, credentialId: currentCredential.credentialId || '' })
} = usePreImportNotionPages({ datasetId, credentialId: currentCredentialId })
const pagesMapAndSelectedPagesId: [DataSourceNotionPageMap, Set<string>, Set<string>] = useMemo(() => {
const selectedPagesId = new Set<string>()
@@ -94,28 +103,24 @@ const NotionPageSelector = ({
const defaultSelectedPagesId = useMemo(() => {
return [...Array.from(pagesMapAndSelectedPagesId[1]), ...(value || [])]
}, [pagesMapAndSelectedPagesId, value])
const [selectedPagesId, setSelectedPagesId] = useState<Set<string>>(() => new Set(defaultSelectedPagesId))
useEffect(() => {
setSelectedPagesId(new Set(defaultSelectedPagesId))
}, [defaultSelectedPagesId])
const selectedPagesId = useMemo(() => new Set(defaultSelectedPagesId), [defaultSelectedPagesId])
const handleSearchValueChange = useCallback((value: string) => {
setSearchValue(value)
}, [])
const handleSelectCredential = useCallback((credentialId: string) => {
const credential = notionCredentials.find(item => item.credentialId === credentialId)!
invalidPreImportNotionPages({ datasetId, credentialId: credential.credentialId })
setCurrentCredential(credential)
if (credentialId === currentCredentialId)
return
invalidPreImportNotionPages({ datasetId, credentialId })
setSelectedCredentialId(credentialId)
onSelect([]) // Clear selected pages when changing credential
onSelectCredential?.(credential.credentialId)
}, [datasetId, invalidPreImportNotionPages, notionCredentials, onSelect, onSelectCredential])
}, [currentCredentialId, datasetId, invalidPreImportNotionPages, onSelect])
const handleSelectPages = useCallback((newSelectedPagesId: Set<string>) => {
const selectedPages = Array.from(newSelectedPagesId).map(pageId => pagesMapAndSelectedPagesId[0][pageId])
setSelectedPagesId(new Set(Array.from(newSelectedPagesId)))
onSelect(selectedPages)
}, [pagesMapAndSelectedPagesId, onSelect])
@@ -140,16 +145,16 @@ const NotionPageSelector = ({
<div className="flex flex-col gap-y-2" data-testid="notion-page-selector-base">
<Header
onClickConfiguration={handleConfigureNotion}
title="Choose notion pages"
buttonText="Configure Notion"
docTitle="Notion docs"
title={t('dataSource.notion.selector.headerTitle', { ns: 'common' })}
buttonText={t('dataSource.notion.selector.configure', { ns: 'common' })}
docTitle={t('dataSource.notion.selector.docs', { ns: 'common' })}
docLink="https://www.notion.so/docs"
/>
<div className="rounded-xl border border-components-panel-border bg-background-default-subtle">
<div className="flex h-12 items-center gap-x-2 rounded-t-xl border-b border-b-divider-regular bg-components-panel-bg p-2">
<div className="flex grow items-center gap-x-1">
<WorkspaceSelector
value={currentCredential.credentialId}
value={currentCredentialId}
items={notionCredentials}
onSelect={handleSelectCredential}
/>
@@ -168,6 +173,7 @@ const NotionPageSelector = ({
)
: (
<PageSelector
key={currentCredentialId || 'default'}
value={selectedPagesId}
disabledValue={pagesMapAndSelectedPagesId[2]}
searchValue={searchValue}

View File

@@ -3,6 +3,8 @@ import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import PageSelector from '../index'
vi.mock('@tanstack/react-virtual')
const buildPage = (overrides: Partial<DataSourceNotionPage>): DataSourceNotionPage => ({
page_id: 'page-id',
page_name: 'Page name',

View File

@@ -1,11 +1,7 @@
import type { ListChildComponentProps } from 'react-window'
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
import { memo, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { areEqual, FixedSizeList as List } from 'react-window'
import { cn } from '@/utils/classnames'
import Checkbox from '../../checkbox'
import NotionIcon from '../../notion-icon'
import { usePageSelectorModel } from './use-page-selector-model'
import VirtualPageList from './virtual-page-list'
type PageSelectorProps = {
value: Set<string>
@@ -17,173 +13,7 @@ type PageSelectorProps = {
canPreview?: boolean
previewPageId?: string
onPreview?: (selectedPageId: string) => void
isMultipleChoice?: boolean
}
type NotionPageTreeItem = {
children: Set<string>
descendants: Set<string>
depth: number
ancestors: string[]
} & DataSourceNotionPage
type NotionPageTreeMap = Record<string, NotionPageTreeItem>
type NotionPageItem = {
expand: boolean
depth: number
} & DataSourceNotionPage
const recursivePushInParentDescendants = (
pagesMap: DataSourceNotionPageMap,
listTreeMap: NotionPageTreeMap,
current: NotionPageTreeItem,
leafItem: NotionPageTreeItem,
) => {
const parentId = current.parent_id
const pageId = current.page_id
if (!parentId || !pageId)
return
if (parentId !== 'root' && pagesMap[parentId]) {
if (!listTreeMap[parentId]) {
const children = new Set([pageId])
const descendants = new Set([pageId, leafItem.page_id])
listTreeMap[parentId] = {
...pagesMap[parentId],
children,
descendants,
depth: 0,
ancestors: [],
}
}
else {
listTreeMap[parentId].children.add(pageId)
listTreeMap[parentId].descendants.add(pageId)
listTreeMap[parentId].descendants.add(leafItem.page_id)
}
leafItem.depth++
leafItem.ancestors.unshift(listTreeMap[parentId].page_name)
if (listTreeMap[parentId].parent_id !== 'root')
recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap[parentId], leafItem)
}
}
const ItemComponent = ({ index, style, data }: ListChildComponentProps<{
dataList: NotionPageItem[]
handleToggle: (index: number) => void
checkedIds: Set<string>
disabledCheckedIds: Set<string>
handleCheck: (index: number) => void
canPreview?: boolean
handlePreview: (index: number) => void
listMapWithChildrenAndDescendants: NotionPageTreeMap
searchValue: string
previewPageId: string
pagesMap: DataSourceNotionPageMap
}>) => {
const { t } = useTranslation()
const {
dataList,
handleToggle,
checkedIds,
disabledCheckedIds,
handleCheck,
canPreview,
handlePreview,
listMapWithChildrenAndDescendants,
searchValue,
previewPageId,
pagesMap,
} = data
const current = dataList[index]
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[current.page_id]
const hasChild = currentWithChildrenAndDescendants.descendants.size > 0
const ancestors = currentWithChildrenAndDescendants.ancestors
const breadCrumbs = ancestors.length ? [...ancestors, current.page_name] : [current.page_name]
const disabled = disabledCheckedIds.has(current.page_id)
const renderArrow = () => {
if (hasChild) {
return (
<div
className="mr-1 flex h-5 w-5 shrink-0 items-center justify-center rounded-md hover:bg-components-button-ghost-bg-hover"
style={{ marginLeft: current.depth * 8 }}
onClick={() => handleToggle(index)}
data-testid={`notion-page-toggle-${current.page_id}`}
>
{
current.expand
? <div className="i-ri-arrow-down-s-line h-4 w-4 text-text-tertiary" />
: <div className="i-ri-arrow-right-s-line h-4 w-4 text-text-tertiary" />
}
</div>
)
}
if (current.parent_id === 'root' || !pagesMap[current.parent_id]) {
return (
<div></div>
)
}
return (
<div className="mr-1 h-5 w-5 shrink-0" style={{ marginLeft: current.depth * 8 }} />
)
}
return (
<div
className={cn('group flex cursor-pointer items-center rounded-md pl-2 pr-[2px] hover:bg-state-base-hover', previewPageId === current.page_id && 'bg-state-base-hover')}
style={{ ...style, top: style.top as number + 8, left: 8, right: 8, width: 'calc(100% - 16px)' }}
data-testid={`notion-page-row-${current.page_id}`}
>
<Checkbox
className="mr-2 shrink-0"
checked={checkedIds.has(current.page_id)}
disabled={disabled}
onCheck={() => {
handleCheck(index)
}}
id={`notion-page-checkbox-${current.page_id}`}
/>
{!searchValue && renderArrow()}
<NotionIcon
className="mr-1 shrink-0"
type="page"
src={current.page_icon}
/>
<div
className="grow truncate text-[13px] font-medium leading-4 text-text-secondary"
title={current.page_name}
data-testid={`notion-page-name-${current.page_id}`}
>
{current.page_name}
</div>
{
canPreview && (
<div
className="ml-1 hidden h-6 shrink-0 cursor-pointer items-center rounded-md border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-2 text-xs
font-medium leading-4 text-components-button-secondary-text shadow-xs shadow-shadow-shadow-3 backdrop-blur-[10px]
hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover group-hover:flex"
onClick={() => handlePreview(index)}
data-testid={`notion-page-preview-${current.page_id}`}
>
{t('dataSource.notion.selector.preview', { ns: 'common' })}
</div>
)
}
{
searchValue && (
<div
className="ml-1 max-w-[120px] shrink-0 truncate text-xs text-text-quaternary"
title={breadCrumbs.join(' / ')}
>
{breadCrumbs.join(' / ')}
</div>
)
}
</div>
)
}
const Item = memo(ItemComponent, areEqual)
const PageSelector = ({
value,
@@ -197,108 +27,25 @@ const PageSelector = ({
onPreview,
}: PageSelectorProps) => {
const { t } = useTranslation()
const [dataList, setDataList] = useState<NotionPageItem[]>([])
const [localPreviewPageId, setLocalPreviewPageId] = useState('')
useEffect(() => {
setDataList(list.filter(item => item.parent_id === 'root' || !pagesMap[item.parent_id]).map((item) => {
return {
...item,
expand: false,
depth: 0,
}
}))
}, [list])
const searchDataList = list.filter((item) => {
return item.page_name.includes(searchValue)
}).map((item) => {
return {
...item,
expand: false,
depth: 0,
}
const {
currentPreviewPageId,
effectiveSearchValue,
rows,
handlePreview,
handleSelect,
handleToggle,
} = usePageSelectorModel({
checkedIds: value,
list,
onPreview,
onSelect,
pagesMap,
previewPageId,
searchValue,
selectionMode: 'multiple',
})
const currentDataList = searchValue ? searchDataList : dataList
const currentPreviewPageId = previewPageId === undefined ? localPreviewPageId : previewPageId
const listMapWithChildrenAndDescendants = useMemo(() => {
return list.reduce((prev: NotionPageTreeMap, next: DataSourceNotionPage) => {
const pageId = next.page_id
if (!prev[pageId])
prev[pageId] = { ...next, children: new Set(), descendants: new Set(), depth: 0, ancestors: [] }
recursivePushInParentDescendants(pagesMap, prev, prev[pageId], prev[pageId])
return prev
}, {})
}, [list, pagesMap])
const handleToggle = (index: number) => {
const current = dataList[index]
const pageId = current.page_id
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
const descendantsIds = Array.from(currentWithChildrenAndDescendants.descendants)
const childrenIds = Array.from(currentWithChildrenAndDescendants.children)
let newDataList = []
if (current.expand) {
current.expand = false
newDataList = dataList.filter(item => !descendantsIds.includes(item.page_id))
}
else {
current.expand = true
newDataList = [
...dataList.slice(0, index + 1),
...childrenIds.map(item => ({
...pagesMap[item],
expand: false,
depth: listMapWithChildrenAndDescendants[item].depth,
})),
...dataList.slice(index + 1),
]
}
setDataList(newDataList)
}
const copyValue = new Set(value)
const handleCheck = (index: number) => {
const current = currentDataList[index]
const pageId = current.page_id
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
if (copyValue.has(pageId)) {
if (!searchValue) {
for (const item of currentWithChildrenAndDescendants.descendants)
copyValue.delete(item)
}
copyValue.delete(pageId)
}
else {
if (!searchValue) {
for (const item of currentWithChildrenAndDescendants.descendants)
copyValue.add(item)
}
copyValue.add(pageId)
}
onSelect(new Set(copyValue))
}
const handlePreview = (index: number) => {
const current = currentDataList[index]
const pageId = current.page_id
setLocalPreviewPageId(pageId)
if (onPreview)
onPreview(pageId)
}
if (!currentDataList.length) {
if (!rows.length) {
return (
<div className="flex h-[296px] items-center justify-center text-[13px] text-text-tertiary">
{t('dataSource.notion.selector.noSearchResult', { ns: 'common' })}
@@ -307,29 +54,18 @@ const PageSelector = ({
}
return (
<List
className="py-2"
height={296}
itemCount={currentDataList.length}
itemSize={28}
width="100%"
itemKey={(index, data) => data.dataList[index].page_id}
itemData={{
dataList: currentDataList,
handleToggle,
checkedIds: value,
disabledCheckedIds: disabledValue,
handleCheck,
canPreview,
handlePreview,
listMapWithChildrenAndDescendants,
searchValue,
previewPageId: currentPreviewPageId,
pagesMap,
}}
>
{Item}
</List>
<VirtualPageList
checkedIds={value}
disabledValue={disabledValue}
onPreview={handlePreview}
onSelect={handleSelect}
onToggle={handleToggle}
previewPageId={currentPreviewPageId}
rows={rows}
searchValue={effectiveSearchValue}
selectionMode="multiple"
showPreview={canPreview}
/>
)
}

View File

@@ -0,0 +1,116 @@
import type { CSSProperties } from 'react'
import type { NotionPageRow as NotionPageRowData, NotionPageSelectionMode } from './types'
import { RiArrowDownSLine, RiArrowRightSLine } from '@remixicon/react'
import { memo } from 'react'
import { useTranslation } from 'react-i18next'
import Checkbox from '@/app/components/base/checkbox'
import NotionIcon from '@/app/components/base/notion-icon'
import Radio from '@/app/components/base/radio/ui'
import { cn } from '@/utils/classnames'
type NotionPageRowProps = {
checked: boolean
disabled: boolean
isPreviewed: boolean
onPreview: (pageId: string) => void
onSelect: (pageId: string) => void
onToggle: (pageId: string) => void
row: NotionPageRowData
searchValue: string
selectionMode: NotionPageSelectionMode
showPreview: boolean
style: CSSProperties
}
const NotionPageRow = ({
checked,
disabled,
isPreviewed,
onPreview,
onSelect,
onToggle,
row,
searchValue,
selectionMode,
showPreview,
style,
}: NotionPageRowProps) => {
const { t } = useTranslation()
const pageId = row.page.page_id
const breadcrumbs = row.ancestors.length ? [...row.ancestors, row.page.page_name] : [row.page.page_name]
return (
<div
className={cn('group flex cursor-pointer items-center rounded-md pr-[2px] pl-2 hover:bg-state-base-hover', isPreviewed && 'bg-state-base-hover')}
style={style}
data-testid={`notion-page-row-${pageId}`}
>
{selectionMode === 'multiple'
? (
<Checkbox
className="mr-2 shrink-0"
checked={checked}
disabled={disabled}
onCheck={() => onSelect(pageId)}
id={`notion-page-checkbox-${pageId}`}
/>
)
: (
<Radio
className="mr-2 shrink-0"
isChecked={checked}
disabled={disabled}
onCheck={() => onSelect(pageId)}
/>
)}
{!searchValue && row.hasChild && (
<div
className="mr-1 flex h-5 w-5 shrink-0 items-center justify-center rounded-md hover:bg-components-button-ghost-bg-hover"
style={{ marginLeft: row.depth * 8 }}
onClick={() => onToggle(pageId)}
data-testid={`notion-page-toggle-${pageId}`}
>
{row.expand
? <RiArrowDownSLine className="h-4 w-4 text-text-tertiary" />
: <RiArrowRightSLine className="h-4 w-4 text-text-tertiary" />}
</div>
)}
{!searchValue && !row.hasChild && row.parentExists && (
<div className="mr-1 h-5 w-5 shrink-0" style={{ marginLeft: row.depth * 8 }} />
)}
<NotionIcon
className="mr-1 shrink-0"
type="page"
src={row.page.page_icon}
/>
<div
className="grow truncate text-[13px] leading-4 font-medium text-text-secondary"
title={row.page.page_name}
data-testid={`notion-page-name-${pageId}`}
>
{row.page.page_name}
</div>
{showPreview && (
<div
className="ml-1 hidden h-6 shrink-0 cursor-pointer items-center rounded-md border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-2 text-xs
leading-4 font-medium text-components-button-secondary-text shadow-xs shadow-shadow-shadow-3 backdrop-blur-[10px]
group-hover:flex hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover"
onClick={() => onPreview(pageId)}
data-testid={`notion-page-preview-${pageId}`}
>
{t('dataSource.notion.selector.preview', { ns: 'common' })}
</div>
)}
{searchValue && (
<div
className="ml-1 max-w-[120px] shrink-0 truncate text-xs text-text-quaternary"
title={breadcrumbs.join(' / ')}
>
{breadcrumbs.join(' / ')}
</div>
)}
</div>
)
}
export default memo(NotionPageRow)

View File

@@ -0,0 +1,21 @@
import type { DataSourceNotionPage } from '@/models/common'
export type NotionPageSelectionMode = 'multiple' | 'single'
export type NotionPageTreeItem = {
children: Set<string>
descendants: Set<string>
depth: number
ancestors: string[]
} & DataSourceNotionPage
export type NotionPageTreeMap = Record<string, NotionPageTreeItem>
export type NotionPageRow = {
page: DataSourceNotionPage
parentExists: boolean
depth: number
expand: boolean
hasChild: boolean
ancestors: string[]
}

View File

@@ -0,0 +1,88 @@
import type { NotionPageSelectionMode } from './types'
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
import { startTransition, useCallback, useDeferredValue, useMemo, useState } from 'react'
import { buildNotionPageTree, getNextSelectedPageIds, getRootPageIds, getVisiblePageRows } from './utils'
type UsePageSelectorModelProps = {
checkedIds: Set<string>
searchValue: string
pagesMap: DataSourceNotionPageMap
list: DataSourceNotionPage[]
onSelect: (selectedPagesId: Set<string>) => void
previewPageId?: string
onPreview?: (selectedPageId: string) => void
selectionMode: NotionPageSelectionMode
}
export const usePageSelectorModel = ({
checkedIds,
searchValue,
pagesMap,
list,
onSelect,
previewPageId,
onPreview,
selectionMode,
}: UsePageSelectorModelProps) => {
const deferredSearchValue = useDeferredValue(searchValue)
const [expandedIds, setExpandedIds] = useState<Set<string>>(() => new Set())
const [localPreviewPageId, setLocalPreviewPageId] = useState('')
const treeMap = useMemo(() => buildNotionPageTree(list, pagesMap), [list, pagesMap])
const rootPageIds = useMemo(() => getRootPageIds(list, pagesMap), [list, pagesMap])
const rows = useMemo(() => {
return getVisiblePageRows({
list,
pagesMap,
searchValue: deferredSearchValue,
treeMap,
rootPageIds,
expandedIds,
})
}, [deferredSearchValue, expandedIds, list, pagesMap, rootPageIds, treeMap])
const currentPreviewPageId = previewPageId ?? localPreviewPageId
const handleToggle = useCallback((pageId: string) => {
startTransition(() => {
setExpandedIds((currentExpandedIds) => {
const nextExpandedIds = new Set(currentExpandedIds)
if (nextExpandedIds.has(pageId)) {
nextExpandedIds.delete(pageId)
treeMap[pageId]?.descendants.forEach(descendantId => nextExpandedIds.delete(descendantId))
}
else {
nextExpandedIds.add(pageId)
}
return nextExpandedIds
})
})
}, [treeMap])
const handleSelect = useCallback((pageId: string) => {
onSelect(getNextSelectedPageIds({
checkedIds,
pageId,
searchValue: deferredSearchValue,
selectionMode,
treeMap,
}))
}, [checkedIds, deferredSearchValue, onSelect, selectionMode, treeMap])
const handlePreview = useCallback((pageId: string) => {
setLocalPreviewPageId(pageId)
onPreview?.(pageId)
}, [onPreview])
return {
currentPreviewPageId,
effectiveSearchValue: deferredSearchValue,
rows,
handlePreview,
handleSelect,
handleToggle,
}
}

View File

@@ -0,0 +1,163 @@
import type { NotionPageRow, NotionPageSelectionMode, NotionPageTreeItem, NotionPageTreeMap } from './types'
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
export const recursivePushInParentDescendants = (
pagesMap: DataSourceNotionPageMap,
listTreeMap: NotionPageTreeMap,
current: NotionPageTreeItem,
leafItem: NotionPageTreeItem,
) => {
const parentId = current.parent_id
const pageId = current.page_id
if (!parentId || !pageId)
return
if (parentId !== 'root' && pagesMap[parentId]) {
if (!listTreeMap[parentId]) {
const children = new Set([pageId])
const descendants = new Set([pageId, leafItem.page_id])
listTreeMap[parentId] = {
...pagesMap[parentId],
children,
descendants,
depth: 0,
ancestors: [],
}
}
else {
listTreeMap[parentId].children.add(pageId)
listTreeMap[parentId].descendants.add(pageId)
listTreeMap[parentId].descendants.add(leafItem.page_id)
}
leafItem.depth++
leafItem.ancestors.unshift(listTreeMap[parentId].page_name)
if (listTreeMap[parentId].parent_id !== 'root')
recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap[parentId], leafItem)
}
}
export const buildNotionPageTree = (
list: DataSourceNotionPage[],
pagesMap: DataSourceNotionPageMap,
): NotionPageTreeMap => {
return list.reduce((prev: NotionPageTreeMap, next) => {
const pageId = next.page_id
if (!prev[pageId])
prev[pageId] = { ...next, children: new Set(), descendants: new Set(), depth: 0, ancestors: [] }
recursivePushInParentDescendants(pagesMap, prev, prev[pageId], prev[pageId])
return prev
}, {})
}
export const getRootPageIds = (
list: DataSourceNotionPage[],
pagesMap: DataSourceNotionPageMap,
) => {
return list
.filter(item => item.parent_id === 'root' || !pagesMap[item.parent_id])
.map(item => item.page_id)
}
export const getVisiblePageRows = ({
list,
pagesMap,
searchValue,
treeMap,
rootPageIds,
expandedIds,
}: {
list: DataSourceNotionPage[]
pagesMap: DataSourceNotionPageMap
searchValue: string
treeMap: NotionPageTreeMap
rootPageIds: string[]
expandedIds: Set<string>
}): NotionPageRow[] => {
if (searchValue) {
return list
.filter(item => item.page_name.includes(searchValue))
.map(item => ({
page: item,
parentExists: item.parent_id !== 'root' && Boolean(pagesMap[item.parent_id]),
depth: treeMap[item.page_id]?.depth ?? 0,
expand: false,
hasChild: (treeMap[item.page_id]?.children.size ?? 0) > 0,
ancestors: treeMap[item.page_id]?.ancestors ?? [],
}))
}
const rows: NotionPageRow[] = []
const visit = (pageId: string) => {
const current = treeMap[pageId]
if (!current)
return
const expand = expandedIds.has(pageId)
rows.push({
page: current,
parentExists: current.parent_id !== 'root' && Boolean(pagesMap[current.parent_id]),
depth: current.depth,
expand,
hasChild: current.children.size > 0,
ancestors: current.ancestors,
})
if (!expand)
return
current.children.forEach(visit)
}
rootPageIds.forEach(visit)
return rows
}
export const getNextSelectedPageIds = ({
checkedIds,
pageId,
searchValue,
selectionMode,
treeMap,
}: {
checkedIds: Set<string>
pageId: string
searchValue: string
selectionMode: NotionPageSelectionMode
treeMap: NotionPageTreeMap
}) => {
const nextCheckedIds = new Set(checkedIds)
const descendants = treeMap[pageId]?.descendants ?? new Set<string>()
if (selectionMode === 'single') {
if (nextCheckedIds.has(pageId)) {
nextCheckedIds.delete(pageId)
}
else {
nextCheckedIds.clear()
nextCheckedIds.add(pageId)
}
return nextCheckedIds
}
if (nextCheckedIds.has(pageId)) {
if (!searchValue)
descendants.forEach(item => nextCheckedIds.delete(item))
nextCheckedIds.delete(pageId)
return nextCheckedIds
}
if (!searchValue)
descendants.forEach(item => nextCheckedIds.add(item))
nextCheckedIds.add(pageId)
return nextCheckedIds
}

View File

@@ -0,0 +1,93 @@
'use client'
import type { NotionPageRow, NotionPageSelectionMode } from './types'
import { useVirtualizer } from '@tanstack/react-virtual'
import { useRef } from 'react'
import PageRow from './page-row'
type VirtualPageListProps = {
checkedIds: Set<string>
disabledValue: Set<string>
onPreview: (pageId: string) => void
onSelect: (pageId: string) => void
onToggle: (pageId: string) => void
previewPageId: string
rows: NotionPageRow[]
searchValue: string
selectionMode: NotionPageSelectionMode
showPreview: boolean
}
const rowHeight = 28
const VirtualPageList = ({
checkedIds,
disabledValue,
onPreview,
onSelect,
onToggle,
previewPageId,
rows,
searchValue,
selectionMode,
showPreview,
}: VirtualPageListProps) => {
const scrollRef = useRef<HTMLDivElement>(null)
const rowVirtualizer = useVirtualizer({
count: rows.length,
estimateSize: () => rowHeight,
getScrollElement: () => scrollRef.current,
overscan: 6,
paddingEnd: 8,
paddingStart: 8,
})
const virtualRows = rowVirtualizer.getVirtualItems()
return (
<div
ref={scrollRef}
className="h-[296px] overflow-auto"
data-testid="virtual-list"
>
<div
style={{
height: `${rowVirtualizer.getTotalSize()}px`,
position: 'relative',
}}
>
{virtualRows.map((virtualRow) => {
const row = rows[virtualRow.index]
const pageId = row.page.page_id
return (
<PageRow
key={pageId}
checked={checkedIds.has(pageId)}
disabled={disabledValue.has(pageId)}
isPreviewed={previewPageId === pageId}
onPreview={onPreview}
onSelect={onSelect}
onToggle={onToggle}
row={row}
searchValue={searchValue}
selectionMode={selectionMode}
showPreview={showPreview}
style={{
height: `${virtualRow.size}px`,
left: 8,
position: 'absolute',
top: 0,
transform: `translateY(${virtualRow.start}px)`,
width: 'calc(100% - 16px)',
}}
/>
)
})}
</div>
</div>
)
}
export default VirtualPageList

View File

@@ -12,6 +12,12 @@ export const Select = BaseSelect.Root
export const SelectGroup = BaseSelect.Group
export const SelectGroupLabel = BaseSelect.GroupLabel
export const SelectValue = BaseSelect.Value
/** @public */
export const SelectGroup = BaseSelect.Group
/** @public */
export const SelectGroupLabel = BaseSelect.GroupLabel
/** @public */
export const SelectSeparator = BaseSelect.Separator
const selectTriggerVariants = cva(
'',

View File

@@ -106,7 +106,7 @@ const OnlineDocuments = ({
if (!currentCredentialId)
return
getOnlineDocuments()
}, [currentCredentialId])
}, [currentCredentialId, getOnlineDocuments])
const handleSearchValueChange = useCallback((value: string) => {
const { setSearchValue } = dataSourceStore.getState()
@@ -156,6 +156,7 @@ const OnlineDocuments = ({
{documentsData?.length
? (
<PageSelector
key={`${currentCredentialId}:${supportBatchUpload ? 'multiple' : 'single'}`}
checkedIds={selectedPagesId}
disabledValue={new Set()}
searchValue={searchValue}
@@ -165,7 +166,6 @@ const OnlineDocuments = ({
canPreview={!isInPipeline}
onPreview={handlePreviewPage}
isMultipleChoice={supportBatchUpload}
currentCredentialId={currentCredentialId}
/>
)
: (

View File

@@ -1,26 +1,11 @@
import type { NotionPageTreeItem, NotionPageTreeMap } from '../index'
import type { NotionPageTreeItem, NotionPageTreeMap } from '@/app/components/base/notion-page-selector/page-selector/types'
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { recursivePushInParentDescendants } from '@/app/components/base/notion-page-selector/page-selector/utils'
import PageSelector from '../index'
import { recursivePushInParentDescendants } from '../utils'
// Mock react-window FixedSizeList - renders items directly for testing
vi.mock('react-window', () => ({
FixedSizeList: ({ children: ItemComponent, itemCount, itemData, itemKey }: { children: React.ComponentType<{ index: number, style: React.CSSProperties, data: unknown }>, itemCount: number, itemData: unknown, itemKey?: (index: number, data: unknown) => string | number }) => (
<div data-testid="virtual-list">
{Array.from({ length: itemCount }).map((_, index) => (
<ItemComponent
key={itemKey?.(index, itemData) || index}
index={index}
style={{ top: index * 28, left: 0, right: 0, width: '100%', position: 'absolute' as const }}
data={itemData}
/>
))}
</div>
),
areEqual: (prevProps: Record<string, unknown>, nextProps: Record<string, unknown>) => prevProps === nextProps,
}))
vi.mock('@tanstack/react-virtual')
// Note: NotionIcon from @/app/components/base/ is NOT mocked - using real component per testing guidelines
@@ -70,7 +55,6 @@ const createDefaultProps = (overrides?: Partial<PageSelectorProps>): PageSelecto
canPreview: true,
onPreview: vi.fn(),
isMultipleChoice: true,
currentCredentialId: 'cred-1',
...overrides,
}
}
@@ -114,7 +98,7 @@ describe('PageSelector', () => {
expect(screen.queryByTestId('virtual-list')).not.toBeInTheDocument()
})
it('should render items using FixedSizeList', () => {
it('should render items using VirtualList', () => {
const pages = [
createMockPage({ page_id: 'page-1', page_name: 'Page 1' }),
createMockPage({ page_id: 'page-2', page_name: 'Page 2' }),

View File

@@ -1,7 +1,7 @@
import type { NotionPageTreeItem, NotionPageTreeMap } from '../index'
import type { NotionPageTreeItem, NotionPageTreeMap } from '@/app/components/base/notion-page-selector/page-selector/types'
import type { DataSourceNotionPageMap } from '@/models/common'
import { describe, expect, it } from 'vitest'
import { recursivePushInParentDescendants } from '../utils'
import { recursivePushInParentDescendants } from '@/app/components/base/notion-page-selector/page-selector/utils'
const makePageEntry = (overrides: Partial<NotionPageTreeItem>): NotionPageTreeItem => ({
page_icon: null,

View File

@@ -1,9 +1,7 @@
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { FixedSizeList as List } from 'react-window'
import Item from './item'
import { recursivePushInParentDescendants } from './utils'
import { usePageSelectorModel } from '@/app/components/base/notion-page-selector/page-selector/use-page-selector-model'
import VirtualPageList from '@/app/components/base/notion-page-selector/page-selector/virtual-page-list'
type PageSelectorProps = {
checkedIds: Set<string>
@@ -15,23 +13,9 @@ type PageSelectorProps = {
canPreview?: boolean
onPreview?: (selectedPageId: string) => void
isMultipleChoice?: boolean
currentCredentialId: string
currentCredentialId?: string
}
export type NotionPageTreeItem = {
children: Set<string>
descendants: Set<string>
depth: number
ancestors: string[]
} & DataSourceNotionPage
export type NotionPageTreeMap = Record<string, NotionPageTreeItem>
type NotionPageItem = {
expand: boolean
depth: number
} & DataSourceNotionPage
const PageSelector = ({
checkedIds,
disabledValue,
@@ -42,116 +26,28 @@ const PageSelector = ({
canPreview = true,
onPreview,
isMultipleChoice = true,
currentCredentialId,
currentCredentialId: _currentCredentialId,
}: PageSelectorProps) => {
const { t } = useTranslation()
const [dataList, setDataList] = useState<NotionPageItem[]>([])
const [currentPreviewPageId, setCurrentPreviewPageId] = useState('')
useEffect(() => {
setDataList(list.filter(item => item.parent_id === 'root' || !pagesMap[item.parent_id]).map((item) => {
return {
...item,
expand: false,
depth: 0,
}
}))
}, [currentCredentialId])
const searchDataList = list.filter((item) => {
return item.page_name.includes(searchValue)
}).map((item) => {
return {
...item,
expand: false,
depth: 0,
}
const selectionMode = isMultipleChoice ? 'multiple' : 'single'
const {
currentPreviewPageId,
effectiveSearchValue,
rows,
handlePreview,
handleSelect,
handleToggle,
} = usePageSelectorModel({
checkedIds,
list,
onPreview,
onSelect,
pagesMap,
searchValue,
selectionMode,
})
const currentDataList = searchValue ? searchDataList : dataList
const listMapWithChildrenAndDescendants = useMemo(() => {
return list.reduce((prev: NotionPageTreeMap, next: DataSourceNotionPage) => {
const pageId = next.page_id
if (!prev[pageId])
prev[pageId] = { ...next, children: new Set(), descendants: new Set(), depth: 0, ancestors: [] }
recursivePushInParentDescendants(pagesMap, prev, prev[pageId], prev[pageId])
return prev
}, {})
}, [list, pagesMap])
const handleToggle = useCallback((index: number) => {
const current = dataList[index]
const pageId = current.page_id
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
const descendantsIds = Array.from(currentWithChildrenAndDescendants.descendants)
const childrenIds = Array.from(currentWithChildrenAndDescendants.children)
let newDataList = []
if (current.expand) {
current.expand = false
newDataList = dataList.filter(item => !descendantsIds.includes(item.page_id))
}
else {
current.expand = true
newDataList = [
...dataList.slice(0, index + 1),
...childrenIds.map(item => ({
...pagesMap[item],
expand: false,
depth: listMapWithChildrenAndDescendants[item].depth,
})),
...dataList.slice(index + 1),
]
}
setDataList(newDataList)
}, [dataList, listMapWithChildrenAndDescendants, pagesMap])
const handleCheck = useCallback((index: number) => {
const copyValue = new Set(checkedIds)
const current = currentDataList[index]
const pageId = current.page_id
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId]
if (copyValue.has(pageId)) {
if (!searchValue && isMultipleChoice) {
for (const item of currentWithChildrenAndDescendants.descendants)
copyValue.delete(item)
}
copyValue.delete(pageId)
}
else {
if (!searchValue && isMultipleChoice) {
for (const item of currentWithChildrenAndDescendants.descendants)
copyValue.add(item)
}
// Single choice mode, clear previous selection
if (!isMultipleChoice && copyValue.size > 0) {
copyValue.clear()
copyValue.add(pageId)
}
else {
copyValue.add(pageId)
}
}
onSelect(new Set(copyValue))
}, [currentDataList, isMultipleChoice, listMapWithChildrenAndDescendants, onSelect, searchValue, checkedIds])
const handlePreview = useCallback((index: number) => {
const current = currentDataList[index]
const pageId = current.page_id
setCurrentPreviewPageId(pageId)
if (onPreview)
onPreview(pageId)
}, [currentDataList, onPreview])
if (!currentDataList.length) {
if (!rows.length) {
return (
<div className="flex h-[296px] items-center justify-center text-[13px] text-text-tertiary">
{t('dataSource.notion.selector.noSearchResult', { ns: 'common' })}
@@ -160,30 +56,18 @@ const PageSelector = ({
}
return (
<List
className="py-2"
height={296}
itemCount={currentDataList.length}
itemSize={28}
width="100%"
itemKey={(index, data) => data.dataList[index].page_id}
itemData={{
dataList: currentDataList,
handleToggle,
checkedIds,
disabledCheckedIds: disabledValue,
handleCheck,
canPreview,
handlePreview,
listMapWithChildrenAndDescendants,
searchValue,
previewPageId: currentPreviewPageId,
pagesMap,
isMultipleChoice,
}}
>
{Item}
</List>
<VirtualPageList
checkedIds={checkedIds}
disabledValue={disabledValue}
onPreview={handlePreview}
onSelect={handleSelect}
onToggle={handleToggle}
previewPageId={currentPreviewPageId}
rows={rows}
searchValue={effectiveSearchValue}
selectionMode={selectionMode}
showPreview={canPreview}
/>
)
}

View File

@@ -1,152 +0,0 @@
import type { ListChildComponentProps } from 'react-window'
import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common'
import { RiArrowDownSLine, RiArrowRightSLine } from '@remixicon/react'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { areEqual } from 'react-window'
import Checkbox from '@/app/components/base/checkbox'
import NotionIcon from '@/app/components/base/notion-icon'
import Radio from '@/app/components/base/radio/ui'
import { cn } from '@/utils/classnames'
type NotionPageTreeItem = {
children: Set<string>
descendants: Set<string>
depth: number
ancestors: string[]
} & DataSourceNotionPage
type NotionPageTreeMap = Record<string, NotionPageTreeItem>
type NotionPageItem = {
expand: boolean
depth: number
} & DataSourceNotionPage
const Item = ({ index, style, data }: ListChildComponentProps<{
dataList: NotionPageItem[]
handleToggle: (index: number) => void
checkedIds: Set<string>
disabledCheckedIds: Set<string>
handleCheck: (index: number) => void
canPreview?: boolean
handlePreview: (index: number) => void
listMapWithChildrenAndDescendants: NotionPageTreeMap
searchValue: string
previewPageId: string
pagesMap: DataSourceNotionPageMap
isMultipleChoice?: boolean
}>) => {
const { t } = useTranslation()
const {
dataList,
handleToggle,
checkedIds,
disabledCheckedIds,
handleCheck,
canPreview,
handlePreview,
listMapWithChildrenAndDescendants,
searchValue,
previewPageId,
pagesMap,
isMultipleChoice,
} = data
const current = dataList[index]
const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[current.page_id]
const hasChild = currentWithChildrenAndDescendants.descendants.size > 0
const ancestors = currentWithChildrenAndDescendants.ancestors
const breadCrumbs = ancestors.length ? [...ancestors, current.page_name] : [current.page_name]
const disabled = disabledCheckedIds.has(current.page_id)
const renderArrow = () => {
if (hasChild) {
return (
<div
className="mr-1 flex h-5 w-5 shrink-0 items-center justify-center rounded-md hover:bg-components-button-ghost-bg-hover"
style={{ marginLeft: current.depth * 8 }}
onClick={() => handleToggle(index)}
>
{
current.expand
? <RiArrowDownSLine className="h-4 w-4 text-text-tertiary" />
: <RiArrowRightSLine className="h-4 w-4 text-text-tertiary" />
}
</div>
)
}
if (current.parent_id === 'root' || !pagesMap[current.parent_id]) {
return (
<div></div>
)
}
return (
<div className="mr-1 h-5 w-5 shrink-0" style={{ marginLeft: current.depth * 8 }} />
)
}
return (
<div
className={cn('group flex cursor-pointer items-center rounded-md pl-2 pr-[2px] hover:bg-state-base-hover', previewPageId === current.page_id && 'bg-state-base-hover')}
style={{ ...style, top: style.top as number + 8, left: 8, right: 8, width: 'calc(100% - 16px)' }}
>
{isMultipleChoice
? (
<Checkbox
className="mr-2 shrink-0"
checked={checkedIds.has(current.page_id)}
disabled={disabled}
onCheck={() => {
handleCheck(index)
}}
/>
)
: (
<Radio
className="mr-2 shrink-0"
isChecked={checkedIds.has(current.page_id)}
disabled={disabled}
onCheck={() => {
handleCheck(index)
}}
/>
)}
{!searchValue && renderArrow()}
<NotionIcon
className="mr-1 shrink-0"
type="page"
src={current.page_icon}
/>
<div
className="grow truncate text-[13px] font-medium leading-4 text-text-secondary"
title={current.page_name}
>
{current.page_name}
</div>
{
canPreview && (
<div
className="ml-1 hidden h-6 shrink-0 cursor-pointer items-center rounded-md border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-2 text-xs
font-medium leading-4 text-components-button-secondary-text shadow-xs shadow-shadow-shadow-3 backdrop-blur-[10px]
hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover group-hover:flex"
onClick={() => handlePreview(index)}
>
{t('dataSource.notion.selector.preview', { ns: 'common' })}
</div>
)
}
{
searchValue && (
<div
className="ml-1 max-w-[120px] shrink-0 truncate text-xs text-text-quaternary"
title={breadCrumbs.join(' / ')}
>
{breadCrumbs.join(' / ')}
</div>
)
}
</div>
)
}
export default React.memo(Item, areEqual)

View File

@@ -1,39 +0,0 @@
import type { NotionPageTreeItem, NotionPageTreeMap } from './index'
import type { DataSourceNotionPageMap } from '@/models/common'
export const recursivePushInParentDescendants = (
pagesMap: DataSourceNotionPageMap,
listTreeMap: NotionPageTreeMap,
current: NotionPageTreeItem,
leafItem: NotionPageTreeItem,
) => {
const parentId = current.parent_id
const pageId = current.page_id
if (!parentId || !pageId)
return
if (parentId !== 'root' && pagesMap[parentId]) {
if (!listTreeMap[parentId]) {
const children = new Set([pageId])
const descendants = new Set([pageId, leafItem.page_id])
listTreeMap[parentId] = {
...pagesMap[parentId],
children,
descendants,
depth: 0,
ancestors: [],
}
}
else {
listTreeMap[parentId].children.add(pageId)
listTreeMap[parentId].descendants.add(pageId)
listTreeMap[parentId].descendants.add(leafItem.page_id)
}
leafItem.depth++
leafItem.ancestors.unshift(listTreeMap[parentId].page_name)
if (listTreeMap[parentId].parent_id !== 'root')
recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap[parentId], leafItem)
}
}

View File

@@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import { useContext } from 'react'
import { use } from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import DataSourceProvider, { DataSourceContext } from '../provider'
@@ -11,7 +11,7 @@ vi.mock('../', () => ({
// Test consumer component that reads from context
function ContextConsumer() {
const store = useContext(DataSourceContext)
const store = use(DataSourceContext)
return (
<div data-testid="context-value" data-has-store={store !== null}>
{store ? 'has-store' : 'no-store'}
@@ -65,7 +65,7 @@ describe('DataSourceProvider', () => {
const storeValues: Array<typeof mockStore | null> = []
function StoreCapture() {
const store = useContext(DataSourceContext)
const store = use(DataSourceContext)
storeValues.push(store as typeof mockStore | null)
return null
}

View File

@@ -3,7 +3,7 @@ import type { LocalFileSliceShape } from './slices/local-file'
import type { OnlineDocumentSliceShape } from './slices/online-document'
import type { OnlineDriveSliceShape } from './slices/online-drive'
import type { WebsiteCrawlSliceShape } from './slices/website-crawl'
import { useContext } from 'react'
import { use } from 'react'
import { createStore, useStore } from 'zustand'
import { DataSourceContext } from './provider'
import { createCommonSlice } from './slices/common'
@@ -29,7 +29,7 @@ export const createDataSourceStore = () => {
}
export const useDataSourceStoreWithSelector = <T>(selector: (state: DataSourceShape) => T): T => {
const store = useContext(DataSourceContext)
const store = use(DataSourceContext)
if (!store)
throw new Error('Missing DataSourceContext.Provider in the tree')
@@ -37,7 +37,7 @@ export const useDataSourceStoreWithSelector = <T>(selector: (state: DataSourceSh
}
export const useDataSourceStore = () => {
const store = useContext(DataSourceContext)
const store = use(DataSourceContext)
if (!store)
throw new Error('Missing DataSourceContext.Provider in the tree')

View File

@@ -61,7 +61,7 @@ vi.mock('@/app/components/datasets/common/image-list', () => ({
),
}))
// Markdown uses next/dynamic and react-syntax-highlighter (ESM)
// Markdown uses next/dynamic and shiki (ESM)
vi.mock('@/app/components/base/markdown', () => ({
Markdown: ({ content, className }: { content: string, className?: string }) => (
<div data-testid="markdown" className={`markdown-body ${className || ''}`}>{content}</div>

View File

@@ -1,5 +1,5 @@
import type { ReactNode } from 'react'
import { createContext, useContext } from 'react'
import { createContext, use } from 'react'
import { cn } from '@/utils/classnames'
import styles from './quota-panel.module.css'
@@ -41,7 +41,7 @@ const SystemQuotaCard = ({
}
const Label = ({ children, className }: { children: ReactNode, className?: string }) => {
const variant = useContext(VariantContext)
const variant = use(VariantContext)
return (
<div className={cn(
'relative z-1 flex items-center gap-1 truncate px-1.5 pt-1 system-xs-medium',

View File

@@ -1,5 +1,5 @@
import { render, screen, waitFor } from '@testing-library/react'
import { useContext } from 'react'
import { use } from 'react'
import { HooksStoreContext, HooksStoreContextProvider } from '../provider'
const mockRefreshAll = vi.fn()
@@ -27,7 +27,7 @@ vi.mock('../store', async () => {
})
const Consumer = () => {
const store = useContext(HooksStoreContext)
const store = use(HooksStoreContext)
return <div>{store ? 'has-hooks-store' : 'missing-hooks-store'}</div>
}

View File

@@ -1,6 +1,6 @@
'use client'
import type { ReactNode } from 'react'
import { createContext, useContext } from 'react'
import { createContext, use } from 'react'
type MCPToolAvailabilityContextValue = {
versionSupported?: boolean
@@ -26,7 +26,7 @@ export const MCPToolAvailabilityProvider = ({
)
export const useMCPToolAvailability = (): MCPToolAvailability => {
const context = useContext(MCPToolAvailabilityContext)
const context = use(MCPToolAvailabilityContext)
if (context === undefined)
return { allowed: true }

View File

@@ -909,4 +909,19 @@
[data-theme='light'] [data-hide-on-theme='light'] {
display: none;
}
/* Shiki code block line numbers */
.shiki-line-numbers code {
counter-reset: line;
}
.shiki-line-numbers .line::before {
counter-increment: line;
content: counter(line);
display: inline-block;
width: 1rem;
margin-right: 0.75rem;
text-align: right;
color: var(--color-text-quaternary);
user-select: none;
}
}

View File

@@ -1070,9 +1070,6 @@
filter: invert(50%);
}
.markdown-body .react-syntax-highlighter-line-number {
color: var(--color-text-quaternary);
}
.markdown-body .abcjs-inline-audio .abcjs-btn {
display: flex !important;
}

View File

@@ -748,9 +748,6 @@
"no-restricted-imports": {
"count": 2
},
"react-refresh/only-export-components": {
"count": 1
},
"tailwindcss/enforce-consistent-class-order": {
"count": 1
}
@@ -3150,9 +3147,6 @@
"react/set-state-in-effect": {
"count": 7
},
"tailwindcss/enforce-consistent-class-order": {
"count": 1
},
"ts/no-explicit-any": {
"count": 9
}
@@ -3339,11 +3333,6 @@
"count": 2
}
},
"app/components/base/notion-page-selector/base.tsx": {
"react/set-state-in-effect": {
"count": 2
}
},
"app/components/base/notion-page-selector/credential-selector/index.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 2
@@ -3359,14 +3348,6 @@
"count": 1
}
},
"app/components/base/notion-page-selector/page-selector/index.tsx": {
"react/set-state-in-effect": {
"count": 1
},
"tailwindcss/enforce-consistent-class-order": {
"count": 3
}
},
"app/components/base/notion-page-selector/search-input/index.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 1
@@ -3635,17 +3616,11 @@
}
},
"app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx": {
"react-refresh/only-export-components": {
"count": 1
},
"ts/no-explicit-any": {
"count": 2
}
},
"app/components/base/prompt-editor/plugins/update-block.tsx": {
"react-refresh/only-export-components": {
"count": 2
},
"ts/no-explicit-any": {
"count": 2
}
@@ -4848,16 +4823,6 @@
"count": 1
}
},
"app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/index.tsx": {
"react/set-state-in-effect": {
"count": 1
}
},
"app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/item.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 3
}
},
"app/components/datasets/documents/create-from-pipeline/data-source/online-documents/title.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 1
@@ -7927,9 +7892,6 @@
}
},
"app/components/signin/countdown.tsx": {
"react-refresh/only-export-components": {
"count": 2
},
"tailwindcss/enforce-consistent-class-order": {
"count": 1
}
@@ -8332,7 +8294,7 @@
},
"app/components/workflow/block-selector/index-bar.tsx": {
"react-refresh/only-export-components": {
"count": 5
"count": 1
},
"tailwindcss/enforce-consistent-class-order": {
"count": 1

View File

@@ -1,14 +1,12 @@
// @ts-check
import antfu, { GLOB_MARKDOWN, GLOB_MARKDOWN_CODE, GLOB_TESTS, GLOB_TS, GLOB_TSX, isInEditorEnv, isInGitHooksOrLintStaged } from '@antfu/eslint-config'
import pluginReact from '@eslint-react/eslint-plugin'
import pluginQuery from '@tanstack/eslint-plugin-query'
import md from 'eslint-markdown'
import tailwindcss from 'eslint-plugin-better-tailwindcss'
import hyoban from 'eslint-plugin-hyoban'
import markdownPreferences from 'eslint-plugin-markdown-preferences'
import noBarrelFiles from 'eslint-plugin-no-barrel-files'
import { reactRefresh } from 'eslint-plugin-react-refresh'
import sonar from 'eslint-plugin-sonarjs'
import storybook from 'eslint-plugin-storybook'
import {
@@ -26,11 +24,14 @@ process.env.TAILWIND_MODE ??= 'ESLINT'
const disableRuleAutoFix = !(isInEditorEnv() || isInGitHooksOrLintStaged())
const plugins = pluginReact.configs.all.plugins
export default antfu(
{
react: false,
react: {
overrides: {
'react/set-state-in-effect': 'error',
'react/no-unnecessary-use-prefix': 'error',
},
},
nextjs: {
overrides: {
'next/no-img-element': 'off',
@@ -58,24 +59,6 @@ export default antfu(
e18e: false,
pnpm: false,
},
{
plugins: {
'react': plugins?.['@eslint-react'],
'react-dom': plugins?.['@eslint-react/dom'],
'react-naming-convention': plugins?.['@eslint-react/naming-convention'],
'react-rsc': plugins?.['@eslint-react/rsc'],
'react-web-api': plugins?.['@eslint-react/web-api'],
},
},
{
files: [GLOB_TS, GLOB_TSX],
rules: {
...pluginReact.configs['recommended-typescript'].rules,
'react/prefer-namespace-import': 'error',
'react/set-state-in-effect': 'error',
'react/no-unnecessary-use-prefix': 'error',
},
},
{
files: [...GLOB_TESTS, GLOB_MARKDOWN_CODE, 'vitest.setup.ts', 'test/i18n-mock.ts'],
rules: {
@@ -92,7 +75,6 @@ export default antfu(
'no-barrel-files/no-barrel-files': 'error',
},
},
reactRefresh.configs.next(),
markdownPreferences.configs.standard,
{
files: [GLOB_MARKDOWN],
@@ -231,10 +213,3 @@ export default antfu(
'tailwindcss/no-unnecessary-whitespace',
]
: [])
.renamePlugins({
'@eslint-react': 'react',
'@eslint-react/dom': 'react-dom',
'@eslint-react/naming-convention': 'react-naming-convention',
'@eslint-react/rsc': 'react-rsc',
'@eslint-react/web-api': 'react-web-api',
})

View File

@@ -6,7 +6,6 @@ export function getInitOptions(): InitOptions {
// We do not have en for fallback
load: 'currentOnly',
fallbackLng: 'en-US',
showSupportNotice: false,
partialBundledLanguages: true,
keySeparator: false,
ns: namespaces,

View File

@@ -136,6 +136,9 @@
"dataSource.notion.pagesAuthorized": "Pages authorized",
"dataSource.notion.remove": "Remove",
"dataSource.notion.selector.addPages": "Add pages",
"dataSource.notion.selector.configure": "Configure Notion",
"dataSource.notion.selector.docs": "Notion docs",
"dataSource.notion.selector.headerTitle": "Choose Notion pages",
"dataSource.notion.selector.noSearchResult": "No search results",
"dataSource.notion.selector.pageSelected": "Pages Selected",
"dataSource.notion.selector.preview": "PREVIEW",

View File

@@ -81,6 +81,7 @@
"@tailwindcss/typography": "catalog:",
"@tanstack/react-form": "catalog:",
"@tanstack/react-query": "catalog:",
"@tanstack/react-virtual": "catalog:",
"abcjs": "catalog:",
"ahooks": "catalog:",
"class-variance-authority": "catalog:",
@@ -100,6 +101,7 @@
"es-toolkit": "catalog:",
"fast-deep-equal": "catalog:",
"foxact": "catalog:",
"hast-util-to-jsx-runtime": "catalog:",
"html-entities": "catalog:",
"html-to-image": "catalog:",
"i18next": "catalog:",
@@ -134,14 +136,13 @@
"react-papaparse": "catalog:",
"react-pdf-highlighter": "catalog:",
"react-sortablejs": "catalog:",
"react-syntax-highlighter": "catalog:",
"react-textarea-autosize": "catalog:",
"react-window": "catalog:",
"reactflow": "catalog:",
"remark-breaks": "catalog:",
"remark-directive": "catalog:",
"scheduler": "catalog:",
"sharp": "catalog:",
"shiki": "catalog:",
"sortablejs": "catalog:",
"std-semver": "catalog:",
"streamdown": "catalog:",
@@ -196,8 +197,6 @@
"@types/qs": "catalog:",
"@types/react": "catalog:",
"@types/react-dom": "catalog:",
"@types/react-syntax-highlighter": "catalog:",
"@types/react-window": "catalog:",
"@types/sortablejs": "catalog:",
"@typescript-eslint/parser": "catalog:",
"@typescript/native-preview": "catalog:",
@@ -212,7 +211,6 @@
"eslint-plugin-hyoban": "catalog:",
"eslint-plugin-markdown-preferences": "catalog:",
"eslint-plugin-no-barrel-files": "catalog:",
"eslint-plugin-react-hooks": "catalog:",
"eslint-plugin-react-refresh": "catalog:",
"eslint-plugin-sonarjs": "catalog:",
"eslint-plugin-storybook": "catalog:",
@@ -221,7 +219,6 @@
"knip": "catalog:",
"postcss": "catalog:",
"react-server-dom-webpack": "catalog:",
"sass": "catalog:",
"storybook": "catalog:",
"tailwindcss": "catalog:",
"tsx": "catalog:",