Compare commits

..

12 Commits

Author SHA1 Message Date
QuantumGhost
3d0ff9463f Merge branch 'fix/redis-pubsub-perf' into feat/hitl 2026-02-06 14:42:39 +08:00
QuantumGhost
b893d2df82 docs(api): add a short note about the target_node argument 2026-02-06 14:42:04 +08:00
QuantumGhost
79b6117d80 fixup! fix(api): fix performance issue in ShardedRedisBroadcastChannel 2026-02-06 14:35:19 +08:00
WTW0313
d2ef434dec Merge branch 'main' into feat/hitl 2026-02-06 13:58:24 +08:00
QuantumGhost
aaf83c2b4c chore(api): fix linting issue 2026-02-05 16:15:32 +08:00
QuantumGhost
d898bcff90 feat(api): adjust timeout for get_message to 1s 2026-02-05 15:22:09 +08:00
twwu
b4cf146c85 Merge branch 'main' into feat/hitl 2026-02-05 14:56:02 +08:00
QuantumGhost
f21782a9a3 fix(api): fix performance issue in ShardedRedisBroadcastChannel 2026-02-05 13:28:39 +08:00
JzoNg
e4455987e7 fix: do not stop when workflow paused event recieved 2026-02-05 11:16:14 +08:00
twwu
b2ceb41dd6 Merge branch 'main' into feat/hitl 2026-02-05 11:13:40 +08:00
QuantumGhost
f614153f30 chore(api): fix circular import 2026-02-02 16:52:43 +08:00
QuantumGhost
8ca020e179 Revert "revert: revert human input relevant code (#31766)"
This reverts commit 90fe9abab7.
2026-02-01 16:21:14 +08:00
124 changed files with 1751 additions and 6401 deletions

View File

@@ -4,7 +4,8 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "build/feat/hitl"
- "feat/hitl-frontend"
- "feat/hitl-backend"
types:
- completed
@@ -13,7 +14,10 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'build/feat/hitl'
(
github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
github.event.workflow_run.head_branch == 'feat/hitl-backend'
)
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v1

View File

@@ -37,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
"--loglevel",
"INFO"
],

View File

@@ -52,12 +52,14 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
# TODO(QuantumGhost): use DI to avoid depending on global DB.
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
@@ -104,6 +106,8 @@ forbidden_modules =
core.trigger
core.variables
ignore_imports =
core.workflow.nodes.agent.agent_node -> core.db.session_factory
core.workflow.nodes.agent.agent_node -> models.tools
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.workflow_entry -> core.app.workflow.layers.observability
@@ -124,6 +128,11 @@ ignore_imports =
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
core.workflow.nodes.llm.llm_utils -> core.file.models
@@ -144,6 +153,7 @@ ignore_imports =
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
@@ -159,6 +169,9 @@ ignore_imports =
core.workflow.workflow_entry -> core.app.workflow.node_factory
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
@@ -207,6 +220,7 @@ ignore_imports =
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
core.workflow.nodes.llm.node -> core.model_manager
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
@@ -222,6 +236,7 @@ ignore_imports =
core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.llm.node -> models.dataset
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
core.workflow.nodes.llm.file_saver -> core.tools.signature
@@ -280,6 +295,8 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database

View File

@@ -122,7 +122,7 @@ These commands assume you start from the repository root.
```bash
cd api
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).

View File

@@ -1180,16 +1180,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
default=0,
)
# API token last_used_at batch update
ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: bool = Field(
description="Enable periodic batch update of API token last_used_at timestamps",
default=True,
)
API_TOKEN_LAST_USED_UPDATE_INTERVAL: int = Field(
description="Interval in minutes for batch updating API token last_used_at (default 30)",
default=30,
)
# Trigger provider refresh (simple version)
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
description="Enable trigger provider refresh poller",

View File

@@ -5,6 +5,8 @@ from enum import StrEnum
from flask_restx import Namespace
from pydantic import BaseModel, TypeAdapter
from controllers.console import console_ns
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -22,9 +24,6 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> No
def get_or_create_model(model_name: str, field_def):
# Import lazily to avoid circular imports between console controllers and schema helpers.
from controllers.console import console_ns
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)

View File

@@ -10,7 +10,6 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
@@ -132,11 +131,6 @@ class BaseApiKeyResource(Resource):
if key is None:
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")
# Invalidate cache before deleting from database
# Type assertion: key is guaranteed to be non-None here because abort() raises
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

View File

@@ -463,9 +463,8 @@ class WorkflowRunNodeExecutionListApi(Resource):
class ConsoleWorkflowPauseDetailsApi(Resource):
"""Console API for getting workflow pause details."""
@setup_required
@login_required
@account_initialization_required
@login_required
def get(self, workflow_run_id: str):
"""
Get workflow pause details.
@@ -478,14 +477,10 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
# Query WorkflowRun to determine if workflow is suspended
session_maker = sessionmaker(bind=db.engine)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker)
workflow_run = db.session.get(WorkflowRun, workflow_run_id)
if not workflow_run:
raise NotFoundError("Workflow run not found")
if workflow_run.tenant_id != current_user.current_tenant_id:
raise NotFoundError("Workflow run not found")
# Check if workflow is suspended
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
if not is_paused:

View File

@@ -55,7 +55,6 @@ from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
# Register models for flask_restx to avoid dict type issues in Swagger
@@ -821,11 +820,6 @@ class DatasetApiDeleteApi(Resource):
if key is None:
console_ns.abort(404, message="API key not found")
# Invalidate cache before deleting from database
# Type assertion: key is guaranteed to be non-None here because abort() raises
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

View File

@@ -120,7 +120,7 @@ class TagUpdateDeleteApi(Resource):
TagService.delete_tag(tag_id)
return "", 204
return 204
@console_ns.route("/tag-bindings/create")

View File

@@ -34,8 +34,6 @@ from .dataset import (
metadata,
segment,
)
from .dataset.rag_pipeline import rag_pipeline_workflow
from .end_user import end_user
from .workspace import models
__all__ = [
@@ -46,7 +44,6 @@ __all__ = [
"conversation",
"dataset",
"document",
"end_user",
"file",
"file_preview",
"hit_testing",
@@ -54,7 +51,6 @@ __all__ = [
"message",
"metadata",
"models",
"rag_pipeline_workflow",
"segment",
"site",
"workflow",

View File

@@ -396,7 +396,7 @@ class DatasetApi(DatasetApiResource):
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return "", 204
return 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:
@@ -557,7 +557,7 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
return "", 204
return 204
@service_api_ns.route("/datasets/tags/binding")
@@ -581,7 +581,7 @@ class DatasetTagBindingApi(DatasetApiResource):
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
return "", 204
return 204
@service_api_ns.route("/datasets/tags/unbinding")
@@ -605,7 +605,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
return "", 204
return 204
@service_api_ns.route("/datasets/<uuid:dataset_id>/tags")

View File

@@ -746,4 +746,4 @@ class DocumentApi(DatasetApiResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return "", 204
return 204

View File

@@ -128,7 +128,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
return "", 204
return 204
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")

View File

@@ -1,3 +1,5 @@
import string
import uuid
from collections.abc import Generator
from typing import Any
@@ -10,7 +12,6 @@ from controllers.common.errors import FilenameNotExistsError, NoFileUploadedErro
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.dataset.rag_pipeline.serializers import serialize_upload_file
from controllers.service_api.wraps import DatasetApiResource
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -40,7 +41,7 @@ register_schema_model(service_api_ns, DatasourceNodeRunPayload)
register_schema_model(service_api_ns, PipelineRunApiEntity)
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource-plugins")
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource):
"""Resource for datasource plugins."""
@@ -75,7 +76,7 @@ class DatasourcePluginsApi(DatasetApiResource):
return datasource_plugins, 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
class DatasourceNodeRunApi(DatasetApiResource):
"""Resource for datasource node run."""
@@ -130,7 +131,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
)
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/run")
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run")
class PipelineRunApi(DatasetApiResource):
"""Resource for datasource node run."""
@@ -231,4 +232,12 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return serialize_upload_file(upload_file), 201
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at,
}, 201

View File

@@ -1,22 +0,0 @@
"""
Serialization helpers for Service API knowledge pipeline endpoints.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from models.model import UploadFile
def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]:
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at.isoformat() if upload_file.created_at else None,
}

View File

@@ -233,7 +233,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
return "", 204
return 204
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
@service_api_ns.doc("update_segment")
@@ -499,7 +499,7 @@ class DatasetChildChunkApi(DatasetApiResource):
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return "", 204
return 204
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
@service_api_ns.doc("update_child_chunk")

View File

@@ -1,3 +0,0 @@
from . import end_user
__all__ = ["end_user"]

View File

@@ -1,41 +0,0 @@
from uuid import UUID
from flask_restx import Resource
from controllers.service_api import service_api_ns
from controllers.service_api.end_user.error import EndUserNotFoundError
from controllers.service_api.wraps import validate_app_token
from fields.end_user_fields import EndUserDetail
from models.model import App
from services.end_user_service import EndUserService
@service_api_ns.route("/end-users/<uuid:end_user_id>")
class EndUserApi(Resource):
"""Resource for retrieving end user details by ID."""
@service_api_ns.doc("get_end_user")
@service_api_ns.doc(description="Get an end user by ID")
@service_api_ns.doc(
params={"end_user_id": "End user ID"},
responses={
200: "End user retrieved successfully",
401: "Unauthorized - invalid API token",
404: "End user not found",
},
)
@validate_app_token
def get(self, app_model: App, end_user_id: UUID):
"""Get end user detail.
This endpoint is scoped to the current app token's tenant/app to prevent
cross-tenant/app access when an end-user ID is known.
"""
end_user = EndUserService.get_end_user_by_id(
tenant_id=app_model.tenant_id, app_id=app_model.id, end_user_id=str(end_user_id)
)
if end_user is None:
raise EndUserNotFoundError()
return EndUserDetail.model_validate(end_user).model_dump(mode="json")

View File

@@ -1,7 +0,0 @@
from libs.exception import BaseHTTPException
class EndUserNotFoundError(BaseHTTPException):
error_code = "end_user_not_found"
description = "End user not found."
code = 404

View File

@@ -1,24 +1,27 @@
import logging
import time
from collections.abc import Callable
from datetime import timedelta
from enum import StrEnum, auto
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast
from typing import Concatenate, ParamSpec, TypeVar
from flask import current_app, request
from flask_login import user_logged_in
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache, fetch_token_with_single_flight, record_token_usage
from services.end_user_service import EndUserService
from services.feature_service import FeatureService
@@ -217,8 +220,6 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
def decorator(view: Callable[Concatenate[T, P], R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token("dataset")
# get url path dataset_id from positional args or kwargs
# Flask passes URL path parameters as positional arguments
dataset_id = None
@@ -255,18 +256,12 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
# Validate dataset if dataset_id is provided
if dataset_id:
dataset_id = str(dataset_id)
dataset = (
db.session.query(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
)
.first()
)
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.")
api_token = validate_and_get_api_token("dataset")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
@@ -301,14 +296,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token with Redis caching.
This function uses a two-tier approach:
1. First checks Redis cache for the token
2. If not cached, queries database and caches the result
The last_used_at field is updated asynchronously via Celery task
to avoid blocking the request.
Validate and get API token.
"""
auth_header = request.headers.get("Authorization")
if auth_header is None or " " not in auth_header:
@@ -320,18 +308,29 @@ def validate_and_get_api_token(scope: str | None = None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
# Try to get token from cache first
# Returns a CachedApiToken (plain Python object), not a SQLAlchemy model
cached_token = ApiTokenCache.get(auth_token, scope)
if cached_token is not None:
logger.debug("Token validation served from cache for scope: %s", scope)
# Record usage in Redis for later batch update (no Celery task per request)
record_token_usage(auth_token, scope)
return cast(ApiToken, cached_token)
current_time = naive_utc_now()
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(
ApiToken.token == auth_token,
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
ApiToken.type == scope,
)
.values(last_used_at=current_time)
)
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
result = session.execute(update_stmt)
api_token = session.scalar(stmt)
# Cache miss - use Redis lock for single-flight mode
# This ensures only one request queries DB for the same token concurrently
return fetch_token_with_single_flight(auth_token, scope)
if hasattr(result, "rowcount") and result.rowcount > 0:
session.commit()
if not api_token:
raise Unauthorized("Access token is invalid")
return api_token
class DatasetApiResource(Resource):

View File

@@ -65,12 +65,15 @@ def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Re
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
# TODO(QuantumGhost): disable authorization for web app
# form api temporarily
@web_ns.route("/form/human_input/<string:form_token>")
# class HumanInputFormApi(WebApiResource):
class HumanInputFormApi(Resource):
"""API for getting and submitting human input forms via the web app."""
# NOTE(QuantumGhost): this endpoint is unauthenticated on purpose for now.
# def get(self, _app_model: App, _end_user: EndUser, form_token: str):
def get(self, form_token: str):
"""

View File

@@ -346,7 +346,7 @@ class WorkflowResponseConverter:
paused_nodes=list(event.paused_nodes),
outputs=encoded_outputs,
reasons=pause_reasons,
status=WorkflowExecutionStatus.PAUSED,
status=WorkflowExecutionStatus.PAUSED.value,
created_at=int(started_at.timestamp()),
elapsed_time=elapsed_time,
total_tokens=graph_runtime_state.total_tokens,
@@ -422,7 +422,7 @@ class WorkflowResponseConverter:
data=WorkflowFinishStreamResponse.Data(
id=run_id,
workflow_id=workflow_run.workflow_id,
status=workflow_run.status,
status=workflow_run.status.value,
outputs=encoded_outputs,
error=workflow_run.error,
elapsed_time=elapsed_time,

View File

@@ -262,7 +262,7 @@ class WorkflowPauseStreamResponse(StreamResponse):
paused_nodes: Sequence[str] = Field(default_factory=list)
outputs: Mapping[str, Any] = Field(default_factory=dict)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
status: WorkflowExecutionStatus
status: str
created_at: int
elapsed_time: float
total_tokens: int

View File

@@ -8,7 +8,6 @@ from core.file.file_manager import file_manager
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.ssrf_proxy import ssrf_proxy
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
@@ -17,7 +16,6 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.nodes.template_transform.template_renderer import (
@@ -77,7 +75,6 @@ class DifyNodeFactory(NodeFactory):
self._http_request_http_client = http_request_http_client or ssrf_proxy
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager or file_manager
self._rag_retrieval = DatasetRetrieval()
@override
def create_node(self, node_config: NodeConfigDict) -> Node:
@@ -143,15 +140,6 @@ class DifyNodeFactory(NodeFactory):
file_manager=self._http_request_file_manager,
)
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return KnowledgeRetrievalNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
rag_retrieval=self._rag_retrieval,
)
return node_class(
id=node_id,
config=node_config,

View File

@@ -1,15 +1,13 @@
import json
import logging
import math
import re
import threading
import time
from collections import Counter, defaultdict
from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy import and_, literal, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
@@ -20,7 +18,6 @@ from core.app.app_config.entities import (
)
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.db.session_factory import session_factory
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.file import File, FileTransferMethod, FileType
@@ -61,30 +58,12 @@ from core.rag.retrieval.template_prompts import (
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.repositories.rag_retrieval_protocol import (
KnowledgeRetrievalRequest,
Source,
SourceChildChunk,
SourceMetadata,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
from models.dataset import (
ChildChunk,
Dataset,
DatasetMetadata,
DatasetQuery,
DocumentSegment,
RateLimitLog,
SegmentAttachmentBinding,
)
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from models.dataset import Document as DocumentModel
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
@@ -94,8 +73,6 @@ default_retrieval_model: dict[str, Any] = {
"score_threshold_enabled": False,
}
logger = logging.getLogger(__name__)
class DatasetRetrieval:
def __init__(self, application_generate_entity=None):
@@ -114,233 +91,6 @@ class DatasetRetrieval:
else:
self._llm_usage = self._llm_usage.plus(usage)
def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
self._check_knowledge_rate_limit(request.tenant_id)
available_datasets = self._get_available_datasets(request.tenant_id, request.dataset_ids)
available_datasets_ids = [i.id for i in available_datasets]
if not available_datasets_ids:
return []
if not request.query:
return []
metadata_filter_document_ids, metadata_condition = None, None
if request.metadata_filtering_mode != "disabled":
# Convert workflow layer types to app_config layer types
if not request.metadata_model_config:
raise ValueError("metadata_model_config is required for this method")
app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump())
app_metadata_filtering_conditions = None
if request.metadata_filtering_conditions is not None:
app_metadata_filtering_conditions = MetadataFilteringCondition.model_validate(
request.metadata_filtering_conditions.model_dump()
)
query = request.query if request.query is not None else ""
metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
dataset_ids=available_datasets_ids,
query=query,
tenant_id=request.tenant_id,
user_id=request.user_id,
metadata_filtering_mode=request.metadata_filtering_mode,
metadata_model_config=app_metadata_model_config,
metadata_filtering_conditions=app_metadata_filtering_conditions,
inputs={},
)
if request.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
planning_strategy = PlanningStrategy.REACT_ROUTER
# Ensure required fields are not None for single retrieval mode
if request.model_provider is None or request.model_name is None or request.query is None:
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=request.tenant_id,
model_type=ModelType.LLM,
provider=request.model_provider,
model=request.model_name,
)
provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=request.model_name, model_type=ModelType.LLM
)
if provider_model is None:
raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise exc.ModelCredentialsNotInitializedError(
f"Model {request.model_name} credentials is not initialized."
)
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise exc.ModelNotSupportedError(f"Dify Hosted OpenAI {request.model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise exc.ModelQuotaExceededError(f"Model provider {request.model_provider} quota exceeded.")
stop = []
completion_params = (request.completion_params or {}).copy()
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
model_schema = model_type_instance.get_model_schema(request.model_name, model_credentials)
if not model_schema:
raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")
model_config = ModelConfigWithCredentialsEntity(
provider=request.model_provider,
model=request.model_name,
model_schema=model_schema,
mode=request.model_mode or "chat",
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
all_documents = self.single_retrieve(
request.app_id,
request.tenant_id,
request.user_id,
request.user_from,
request.query,
available_datasets,
model_instance,
model_config,
planning_strategy,
None, # message_id
metadata_filter_document_ids,
metadata_condition,
)
else:
all_documents = self.multiple_retrieve(
app_id=request.app_id,
tenant_id=request.tenant_id,
user_id=request.user_id,
user_from=request.user_from,
available_datasets=available_datasets,
query=request.query,
top_k=request.top_k,
score_threshold=request.score_threshold,
reranking_mode=request.reranking_mode,
reranking_model=request.reranking_model,
weights=request.weights,
reranking_enable=request.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
attachment_ids=request.attachment_ids,
)
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
source = Source(
metadata=SourceMetadata(
source="knowledge",
dataset_id=item.metadata.get("dataset_id"),
dataset_name=item.metadata.get("dataset_name"),
document_id=item.metadata.get("document_id"),
document_name=item.metadata.get("title"),
data_source_type="external",
retriever_from="workflow",
score=item.metadata.get("score"),
doc_metadata=item.metadata,
),
title=item.metadata.get("title"),
content=item.page_content,
)
retrieval_resource_list.append(source)
# deal with dify documents
if dify_documents:
records = RetrievalService.format_retrieval_documents(dify_documents)
dataset_ids = [i.segment.dataset_id for i in records]
document_ids = [i.segment.document_id for i in records]
with session_factory.create_session() as session:
datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()
dataset_map = {i.id: i for i in datasets}
document_map = {i.id: i for i in documents}
if records:
for record in records:
segment = record.segment
dataset = dataset_map.get(segment.dataset_id)
document = document_map.get(segment.document_id)
if dataset and document:
source = Source(
metadata=SourceMetadata(
source="knowledge",
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from="workflow",
score=record.score or 0.0,
segment_hit_count=segment.hit_count,
segment_word_count=segment.word_count,
segment_position=segment.position,
segment_index_node_hash=segment.index_node_hash,
doc_metadata=document.doc_metadata,
child_chunks=[
SourceChildChunk(
id=str(getattr(chunk, "id", "")),
content=str(getattr(chunk, "content", "")),
position=int(getattr(chunk, "position", 0)),
score=float(getattr(chunk, "score", 0.0)),
)
for chunk in (record.child_chunks or [])
],
position=None,
),
title=document.name,
files=list(record.files) if record.files else None,
content=segment.get_sign_content(),
)
if segment.answer:
source.content = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
if record.summary:
source.summary = record.summary
retrieval_resource_list.append(source)
if retrieval_resource_list:
def _score(item: Source) -> float:
meta = item.metadata
score = meta.score
if isinstance(score, (int, float)):
return float(score)
return 0.0
retrieval_resource_list = sorted(
retrieval_resource_list,
key=_score, # type: ignore[arg-type, return-value]
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1):
item.metadata.position = position # type: ignore[index]
return retrieval_resource_list
def retrieve(
self,
app_id: str,
@@ -400,7 +150,14 @@ class DatasetRetrieval:
if features:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
available_datasets = self._get_available_datasets(tenant_id, dataset_ids)
available_datasets = []
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
for dataset in datasets:
if dataset.available_document_count == 0 and dataset.provider != "external":
continue
available_datasets.append(dataset)
if inputs:
inputs = {key: str(value) for key, value in inputs.items()}
@@ -1404,6 +1161,7 @@ class DatasetRetrieval:
query=query or "",
)
result_text = ""
try:
# handle invoke result
invoke_result = cast(
@@ -1434,8 +1192,7 @@ class DatasetRetrieval:
"condition": item.get("comparison_operator"),
}
)
except Exception as e:
logger.warning(e, exc_info=True)
except Exception:
return None
return automatic_metadata_filters
@@ -1649,12 +1406,7 @@ class DatasetRetrieval:
usage = None
for result in invoke_result:
text = result.delta.message.content
if isinstance(text, str):
full_text += text
elif isinstance(text, list):
for i in text:
if i.data:
full_text += i.data
full_text += text
if not model:
model = result.model
@@ -1772,53 +1524,3 @@ class DatasetRetrieval:
cancel_event.set()
if thread_exceptions is not None:
thread_exceptions.append(e)
def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
with session_factory.create_session() as session:
subquery = (
session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
.where(
DocumentModel.indexing_status == "completed",
DocumentModel.enabled == True,
DocumentModel.archived == False,
DocumentModel.dataset_id.in_(dataset_ids),
)
.group_by(DocumentModel.dataset_id)
.having(func.count(DocumentModel.id) > 0)
.subquery()
)
results = (
session.query(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)
available_datasets = []
for dataset in results:
if not dataset:
continue
available_datasets.append(dataset)
return available_datasets
def _check_knowledge_rate_limit(self, tenant_id: str):
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
with session_factory.create_session() as session:
rate_limit_log = RateLimitLog(
tenant_id=tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
raise exc.RateLimitExceededError(
"you have reached the knowledge base request rate limit of your subscription."
)

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import base64
import json
import logging
from collections.abc import Generator, Mapping
from typing import Any, cast
from collections.abc import Generator
from typing import Any
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
@@ -17,7 +17,6 @@ from core.mcp.types import (
TextContent,
TextResourceContents,
)
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
@@ -47,7 +46,6 @@ class MCPTool(Tool):
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self._latest_usage = LLMUsage.empty_usage()
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.MCP
@@ -61,10 +59,6 @@ class MCPTool(Tool):
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
result = self.invoke_remote_mcp_tool(tool_parameters)
# Extract usage metadata from MCP protocol's _meta field
self._latest_usage = self._derive_usage_from_result(result)
# handle dify tool output
for content in result.content:
if isinstance(content, TextContent):
@@ -126,99 +120,6 @@ class MCPTool(Tool):
for item in json_list:
yield self.create_json_message(item)
@property
def latest_usage(self) -> LLMUsage:
return self._latest_usage
@classmethod
def _derive_usage_from_result(cls, result: CallToolResult) -> LLMUsage:
"""
Extract usage metadata from MCP tool result's _meta field.
The MCP protocol's _meta field (aliased as 'meta' in Python) can contain
usage information such as token counts, costs, and other metadata.
Args:
result: The CallToolResult from MCP tool invocation
Returns:
LLMUsage instance with values from meta or empty_usage if not found
"""
# Extract usage from the meta field if present
if result.meta:
usage_dict = cls._extract_usage_dict(result.meta)
if usage_dict is not None:
return LLMUsage.from_metadata(cast(LLMUsageMetadata, cast(object, dict(usage_dict))))
return LLMUsage.empty_usage()
@classmethod
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
"""
Recursively search for usage dictionary in the payload.
The MCP protocol's _meta field can contain usage data in various formats:
- Direct usage field: {"usage": {...}}
- Nested in metadata: {"metadata": {"usage": {...}}}
- Or nested within other fields
Args:
payload: The payload to search for usage data
Returns:
The usage dictionary if found, None otherwise
"""
# Check for direct usage field
usage_candidate = payload.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
# Check for metadata nested usage
metadata_candidate = payload.get("metadata")
if isinstance(metadata_candidate, Mapping):
usage_candidate = metadata_candidate.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
# Check for common token counting fields directly in payload
# Some MCP servers may include token counts directly
if "total_tokens" in payload or "prompt_tokens" in payload or "completion_tokens" in payload:
usage_dict: dict[str, Any] = {}
for key in (
"prompt_tokens",
"completion_tokens",
"total_tokens",
"prompt_unit_price",
"completion_unit_price",
"total_price",
"currency",
"prompt_price_unit",
"completion_price_unit",
"prompt_price",
"completion_price",
"latency",
"time_to_first_token",
"time_to_generate",
):
if key in payload:
usage_dict[key] = payload[key]
if usage_dict:
return usage_dict
# Recursively search through nested structures
for value in payload.values():
if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
elif isinstance(value, list) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
return None
def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
return MCPTool(
entity=self.entity,

View File

@@ -10,7 +10,6 @@ from pydantic import BaseModel, Field
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import NodeState
from core.workflow.runtime.graph_runtime_state import GraphExecutionProtocol
from .node_execution import NodeExecution
@@ -237,6 +236,3 @@ class GraphExecution:
def record_node_failure(self) -> None:
"""Increment the count of node failures encountered during execution."""
self.exceptions_count += 1
_: GraphExecutionProtocol = GraphExecution(workflow_id="")

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Union, cast
from packaging.version import Version
from pydantic import ValidationError
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.db.session_factory import session_factory
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@@ -49,6 +50,12 @@ from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from models.tools import (
ApiToolProvider,
BuiltinToolProvider,
MCPToolProvider,
WorkflowToolProvider,
)
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
@@ -259,7 +266,7 @@ class AgentNode(Node[AgentNodeData]):
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
provider_type = self._infer_tool_provider_type(tool, self.tenant_id)
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
@@ -748,3 +755,34 @@ class AgentNode(Node[AgentNodeData]):
llm_usage=llm_usage,
)
)
@staticmethod
def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType:
provider_type_str = tool_config.get("type")
if provider_type_str:
return ToolProviderType(provider_type_str)
provider_id = tool_config.get("provider_name")
if not provider_id:
return ToolProviderType.BUILT_IN
with session_factory.create_session() as session:
provider_map: dict[
type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]],
ToolProviderType,
] = {
WorkflowToolProvider: ToolProviderType.WORKFLOW,
MCPToolProvider: ToolProviderType.MCP,
ApiToolProvider: ToolProviderType.API,
BuiltinToolProvider: ToolProviderType.BUILT_IN,
}
for provider_model, provider_type in provider_map.items():
stmt = select(provider_model).where(
provider_model.id == provider_id,
provider_model.tenant_id == tenant_id,
)
if session.scalar(stmt):
return provider_type
raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.")

View File

@@ -20,7 +20,3 @@ class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
class InvalidModelTypeError(KnowledgeRetrievalNodeError):
"""Raised when the model is not a Large Language Model."""
class RateLimitExceededError(KnowledgeRetrievalNodeError):
"""Raised when the rate limit is exceeded."""

View File

@@ -1,10 +1,29 @@
import json
import logging
import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import (
ArrayFileSegment,
FileSegment,
@@ -17,16 +36,35 @@ from core.workflow.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import NodeRunResult
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from services.feature_service import FeatureService
from .entities import KnowledgeRetrievalNodeData
from .exc import (
InvalidModelTypeError,
KnowledgeRetrievalNodeError,
RateLimitExceededError,
ModelCredentialsNotInitializedError,
ModelNotExistError,
ModelNotSupportedError,
ModelQuotaExceededError,
)
if TYPE_CHECKING:
@@ -35,6 +73,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
"score_threshold_enabled": False,
}
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = NodeType.KNOWLEDGE_RETRIEVAL
@@ -51,7 +97,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
rag_retrieval: RAGRetrievalProtocol,
*,
llm_file_saver: LLMFileSaver | None = None,
):
@@ -63,7 +108,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs = []
self._rag_retrieval = rag_retrieval
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@@ -77,7 +121,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
return "1"
def _run(self) -> NodeRunResult:
usage = LLMUsage.empty_usage()
if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -85,7 +128,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
process_data={},
outputs={},
metadata={},
llm_usage=usage,
llm_usage=LLMUsage.empty_usage(),
)
variables: dict[str, Any] = {}
# extract variables
@@ -113,9 +156,36 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
else:
variables["attachments"] = [variable.value]
# TODO(-LAN-): Move this check outside.
# check rate limit
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{self.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
with sessionmaker(db.engine).begin() as session:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
error_type="RateLimitExceeded",
)
# retrieve knowledge
usage = LLMUsage.empty_usage()
try:
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
@@ -128,17 +198,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
},
llm_usage=usage,
)
except RateLimitExceededError as e:
logger.warning(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
llm_usage=usage,
)
except KnowledgeRetrievalNodeError as e:
logger.warning("Error when running knowledge retrieval node", exc_info=True)
logger.warning("Error when running knowledge retrieval node")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
@@ -148,7 +210,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
logger.warning(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
@@ -156,47 +217,92 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
error_type=type(e).__name__,
llm_usage=usage,
)
finally:
db.session.close()
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
) -> tuple[list[Source], LLMUsage]:
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
available_datasets = []
dataset_ids = node_data.dataset_ids
query = variables.get("query")
attachments = variables.get("attachments")
retrieval_resource_list = []
metadata_filter_document_ids = None
metadata_condition = None
metadata_usage = LLMUsage.empty_usage()
# Subquery: Count the number of available documents for each dataset
subquery = (
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.dataset_id.in_(dataset_ids),
)
.group_by(Document.dataset_id)
.having(func.count(Document.id) > 0)
.subquery()
)
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = "disabled"
if node_data.metadata_filtering_mode is not None:
metadata_filtering_mode = node_data.metadata_filtering_mode
results = (
db.session.query(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)
# avoid blocking at retrieval
db.session.close()
for dataset in results:
# pass if dataset is not available
if not dataset:
continue
available_datasets.append(dataset)
if query:
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
usage = self._merge_usage(usage, metadata_usage)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
# fetch model config
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required for single retrieval mode")
model = node_data.single_retrieval_config.model
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
raise ValueError("single_retrieval_config is required")
model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
# check model is support tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model_config.model, credentials=model_config.credentials
)
if model_schema:
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
all_documents = dataset_retrieval.single_retrieve(
available_datasets=available_datasets,
tenant_id=self.tenant_id,
user_id=self.user_id,
app_id=self.app_id,
user_from=self.user_from.value,
dataset_ids=dataset_ids,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
completion_params=model.completion_params,
model_provider=model.provider,
model_mode=model.mode,
model_name=model.name,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_mode=metadata_filtering_mode,
query=query,
model_config=model_config,
model_instance=model_instance,
planning_strategy=planning_strategy,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
)
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
reranking_model = None
weights = None
match node_data.multiple_retrieval_config.reranking_mode:
case "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
@@ -223,36 +329,284 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
},
}
case _:
# Handle any other reranking_mode values
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve(
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
user_from=self.user_from.value,
available_datasets=available_datasets,
query=query,
top_k=node_data.multiple_retrieval_config.top_k,
score_threshold=node_data.multiple_retrieval_config.score_threshold
if node_data.multiple_retrieval_config.score_threshold is not None
else 0.0,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
)
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
user_from=self.user_from.value,
dataset_ids=dataset_ids,
query=query,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,
top_k=node_data.multiple_retrieval_config.top_k,
score_threshold=node_data.multiple_retrieval_config.score_threshold
if node_data.multiple_retrieval_config.score_threshold is not None
else 0.0,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_mode=metadata_filtering_mode,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = {
"metadata": {
"_source": "knowledge",
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": "workflow",
"score": item.metadata.get("score"),
"doc_metadata": item.metadata,
},
"title": item.metadata.get("title"),
"content": item.page_content,
}
retrieval_resource_list.append(source)
# deal with dify documents
if dify_documents:
records = RetrievalService.format_retrieval_documents(dify_documents)
if records:
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
document = db.session.scalar(stmt)
if dataset and document:
source = {
"metadata": {
"_source": "knowledge",
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": "workflow",
"score": record.score or 0.0,
"child_chunks": [
{
"id": str(getattr(chunk, "id", "")),
"content": str(getattr(chunk, "content", "")),
"position": int(getattr(chunk, "position", 0)),
"score": float(getattr(chunk, "score", 0.0)),
}
for chunk in (record.child_chunks or [])
],
"segment_hit_count": segment.hit_count,
"segment_word_count": segment.word_count,
"segment_position": segment.position,
"segment_index_node_hash": segment.index_node_hash,
"doc_metadata": document.doc_metadata,
},
"title": document.name,
"files": list(record.files) if record.files else None,
}
if segment.answer:
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
else:
source["content"] = segment.get_sign_content()
# Add summary if available
if record.summary:
source["summary"] = record.summary
retrieval_resource_list.append(source)
if retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
key=self._score, # type: ignore[arg-type, return-value]
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position # type: ignore[index]
return retrieval_resource_list, usage
def _score(self, item: dict[str, Any]) -> float:
meta = item.get("metadata")
if isinstance(meta, dict):
s = meta.get("score")
if isinstance(s, (int, float)):
return float(s)
return 0.0
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
usage = LLMUsage.empty_usage()
document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
filters: list[Any] = []
metadata_condition = None
match node_data.metadata_filtering_mode:
case "disabled":
return None, None, usage
case "automatic":
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters,
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
)
case "manual":
if node_data.metadata_filtering_conditions:
conditions = []
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
case _:
raise ValueError("Invalid metadata filtering mode")
if filters:
if (
node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and"
):
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.where(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition, usage
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
# get all metadata field
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance and fetch model config
model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
metadata_fields=all_metadata_fields,
query=query or "",
)
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
memory=None,
model_config=model_config,
sys_files=[],
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
tenant_id=self.tenant_id,
)
result_text = ""
try:
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=node_data.metadata_model_config,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
node_type=self.node_type,
)
usage = self._rag_retrieval.llm_usage
return retrieval_resource_list, usage
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = self._merge_usage(usage, event.usage)
break
result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []
if "metadata_map" in result_text_json:
metadata_map = result_text_json["metadata_map"]
for item in metadata_map:
if item.get("metadata_field_name") in all_metadata_fields:
automatic_metadata_filters.append(
{
"metadata_name": item.get("metadata_field_name"),
"value": item.get("metadata_field_value"),
"condition": item.get("comparison_operator"),
}
)
except Exception:
return [], usage
return automatic_metadata_filters, usage
@classmethod
def _extract_variable_selector_to_variable_mapping(
@@ -272,3 +626,107 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if typed_node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
return variable_mapping
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model_name = model.name
provider_name = model.provider
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
)
provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_name, model_type=ModelType.LLM
)
if provider_model is None:
raise ModelNotExistError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = model.completion_params
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = model.mode
if not model_mode:
raise ModelNotExistError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
if not model_schema:
raise ModelNotExistError(f"Model {model_name} not exist.")
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
model=model_name,
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
model_mode = ModelMode(node_data.metadata_model_config.mode) # type: ignore
input_text = query
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
)
prompt_messages.append(user_prompt_message_1)
assistant_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
)
prompt_messages.append(assistant_prompt_message_2)
user_prompt_message_3 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER,
text=METADATA_FILTER_USER_PROMPT_3.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
),
)
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
return LLMNodeCompletionModelPromptTemplate(
text=METADATA_FILTER_COMPLETION_PROMPT.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
)
)
else:
raise InvalidModelTypeError(f"Model mode {model_mode} not support.")

View File

@@ -1,108 +0,0 @@
from typing import Any, Literal, Protocol
from pydantic import BaseModel, Field
from core.model_runtime.entities import LLMUsage
from core.workflow.nodes.knowledge_retrieval.entities import MetadataFilteringCondition
from core.workflow.nodes.llm.entities import ModelConfig
class SourceChildChunk(BaseModel):
id: str = Field(default="", description="Child chunk ID")
content: str = Field(default="", description="Child chunk content")
position: int = Field(default=0, description="Child chunk position")
score: float = Field(default=0.0, description="Child chunk relevance score")
class SourceMetadata(BaseModel):
source: str = Field(
default="knowledge",
serialization_alias="_source",
description="Data source identifier",
)
dataset_id: str = Field(description="Dataset unique identifier")
dataset_name: str = Field(description="Dataset display name")
document_id: str = Field(description="Document unique identifier")
document_name: str = Field(description="Document display name")
data_source_type: str = Field(description="Type of data source")
segment_id: str | None = Field(default=None, description="Segment unique identifier")
retriever_from: str = Field(default="workflow", description="Retriever source context")
score: float = Field(default=0.0, description="Retrieval relevance score")
child_chunks: list[SourceChildChunk] = Field(default=[], description="List of child chunks")
segment_hit_count: int | None = Field(default=0, description="Number of times segment was retrieved")
segment_word_count: int | None = Field(default=0, description="Word count of the segment")
segment_position: int | None = Field(default=0, description="Position of segment in document")
segment_index_node_hash: str | None = Field(default=None, description="Hash of index node for the segment")
doc_metadata: dict[str, Any] | None = Field(default=None, description="Additional document metadata")
position: int | None = Field(default=0, description="Position of the document in the dataset")
class Config:
populate_by_name = True
class Source(BaseModel):
metadata: SourceMetadata = Field(description="Source metadata information")
title: str = Field(description="Document title")
files: list[Any] | None = Field(default=None, description="Associated file references")
content: str | None = Field(description="Segment content text")
summary: str | None = Field(default=None, description="Content summary if available")
class KnowledgeRetrievalRequest(BaseModel):
tenant_id: str = Field(description="Tenant unique identifier")
user_id: str = Field(description="User unique identifier")
app_id: str = Field(description="Application unique identifier")
user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')")
dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from")
query: str | None = Field(default=None, description="Query text for knowledge retrieval")
retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'")
model_provider: str | None = Field(default=None, description="Model provider name (e.g., 'openai', 'anthropic')")
completion_params: dict[str, Any] | None = Field(
default=None, description="Model completion parameters (e.g., temperature, max_tokens)"
)
model_mode: str | None = Field(default=None, description="Model mode (e.g., 'chat', 'completion')")
model_name: str | None = Field(default=None, description="Model name (e.g., 'gpt-4', 'claude-3-opus')")
metadata_model_config: ModelConfig | None = Field(
default=None, description="Model config for metadata-based filtering"
)
metadata_filtering_conditions: MetadataFilteringCondition | None = Field(
default=None, description="Conditions for filtering by metadata"
)
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = Field(
default="disabled", description="Metadata filtering mode: 'disabled', 'automatic', or 'manual'"
)
top_k: int = Field(default=0, description="Number of top results to return")
score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")
class RAGRetrievalProtocol(Protocol):
"""Protocol for RAG-based knowledge retrieval implementations.
Implementations of this protocol handle knowledge retrieval from datasets
including rate limiting, dataset filtering, and document retrieval.
"""
@property
def llm_usage(self) -> LLMUsage:
"""Return accumulated LLM usage for retrieval operations."""
...
def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
"""Retrieve knowledge from datasets based on the provided request.
Args:
request: Knowledge retrieval request with search parameters
Returns:
List of sources matching the search criteria
Raises:
RateLimitExceededError: If rate limit is exceeded
ModelNotExistError: If specified model doesn't exist
"""
...

View File

@@ -64,7 +64,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool
error: Exception | None
exceptions_count: int
pause_reasons: list[PauseReason]
pause_reasons: Sequence[PauseReason]
def start(self) -> None:
"""Transition execution into the running state."""
@@ -446,7 +446,7 @@ class GraphRuntimeState:
graph_execution_cls = module.GraphExecution
workflow_id = self._pending_graph_execution_workflow_id or ""
self._pending_graph_execution_workflow_id = None
return graph_execution_cls(workflow_id=workflow_id) # type: ignore[invalid-return-type]
return graph_execution_cls(workflow_id=workflow_id)
def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
# Lazily import to keep the runtime domain decoupled from graph_engine modules.

View File

@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"

View File

@@ -190,14 +190,6 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh",
"schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL),
}
if dify_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK:
imports.append("schedule.update_api_token_last_used_task")
beat_schedule["batch_update_api_token_last_used"] = {
"task": "schedule.update_api_token_last_used_task.batch_update_api_token_last_used",
"schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL),
}
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app

View File

@@ -1,9 +1,7 @@
from __future__ import annotations
from datetime import datetime
from flask_restx import fields
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict
simple_end_user_fields = {
"id": fields.String,
@@ -12,19 +10,6 @@ simple_end_user_fields = {
"session_id": fields.String,
}
end_user_detail_fields = {
"id": fields.String,
"tenant_id": fields.String,
"app_id": fields.String,
"type": fields.String,
"external_user_id": fields.String,
"name": fields.String,
"is_anonymous": fields.Boolean,
"session_id": fields.String,
"created_at": fields.DateTime,
"updated_at": fields.DateTime,
}
class ResponseModel(BaseModel):
model_config = ConfigDict(
@@ -41,23 +26,3 @@ class SimpleEndUser(ResponseModel):
type: str
is_anonymous: bool
session_id: str | None = None
class EndUserDetail(ResponseModel):
"""Full EndUser record for API responses.
Note: The SQLAlchemy model defines an `is_anonymous` property for Flask-Login semantics
(always False). The database column is exposed as `_is_anonymous`, so this DTO maps
`is_anonymous` from `_is_anonymous` to return the stored value.
"""
id: str
tenant_id: str
app_id: str | None = None
type: str
external_user_id: str | None = None
name: str | None = None
is_anonymous: bool = Field(validation_alias="_is_anonymous")
session_id: str
created_at: datetime
updated_at: datetime

View File

@@ -33,7 +33,7 @@ class ShardedTopic:
return self
def publish(self, payload: bytes) -> None:
self._client.spublish(self._topic, payload) # type: ignore[attr-defined,union-attr]
self._client.spublish(self._topic, payload) # type: ignore[attr-defined]
def as_subscriber(self) -> Subscriber:
return self
@@ -75,7 +75,7 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
#
# Here we specify the `target_node` to mitigate this problem.
node = self._client.get_node_from_key(self._topic)
return self._pubsub.get_sharded_message( # type: ignore[attr-defined]
return self._pubsub.get_sharded_message(
ignore_subscribe_messages=False,
timeout=1,
target_node=node,

View File

@@ -20,7 +20,7 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('account_trial_app_records',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('account_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('count', sa.Integer(), nullable=False),
@@ -33,17 +33,17 @@ def upgrade():
batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
op.create_table('exporle_banners',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('content', sa.JSON(), nullable=False),
sa.Column('link', sa.String(length=255), nullable=False),
sa.Column('sort', sa.Integer(), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'"), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'"), nullable=False),
sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
)
op.create_table('trial_apps',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),

View File

@@ -1,34 +0,0 @@
"""drop server_default for app trail related tables
Revision ID: c3df22613c99
Revises: e8c3b3c46151
Create Date: 2026-02-09 09:50:46.181969
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'c3df22613c99'
down_revision = 'e8c3b3c46151'
branch_labels = None
depends_on = None
def upgrade():
op.alter_column("account_trial_app_records", "id", server_default=None)
op.alter_column("exporle_banners", "id", server_default=None)
op.alter_column("trial_apps", "id", server_default=None)
def downgrade():
# This migration is primarily for schema consistence
# between database and model definitions.
#
# The original
# DROP SERVER default is idemponent.
# Besides, the original migration has been updated to
# reflect the
pass

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from uuid import uuid4
from sqlalchemy import DateTime, func
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
@@ -51,16 +50,3 @@ class DefaultFieldsMixin:
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(id={self.id})>"
def gen_uuidv4_string() -> str:
"""gen_uuidv4_string generate a UUIDv4 string.
NOTE: This function exists only for historical reasons. New models should use uuidv7 for primary key generation.
"""
return str(uuid4())
def gen_uuidv7_string() -> str:
"""gen_uuidv4_string generate a UUIDv4 string."""
return str(uuidv7())

View File

@@ -26,7 +26,7 @@ from libs.helper import generate_string # type: ignore[import-not-found]
from libs.uuid_utils import uuidv7
from .account import Account, Tenant
from .base import Base, TypeBase, gen_uuidv4_string
from .base import Base, TypeBase
from .engine import db
from .enums import CreatorUserRole
from .provider_ids import GenericProviderID
@@ -620,7 +620,7 @@ class TrialApp(Base):
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
)
id = mapped_column(StringUUID, default=gen_uuidv4_string)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@@ -640,7 +640,7 @@ class AccountTrialAppRecord(Base):
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
)
id = mapped_column(StringUUID, default=gen_uuidv4_string)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
account_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
count = mapped_column(sa.Integer, nullable=False, default=0)
@@ -660,7 +660,7 @@ class AccountTrialAppRecord(Base):
class ExporleBanner(TypeBase):
__tablename__ = "exporle_banners"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv4_string, init=False)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
link: Mapped[str] = mapped_column(String(255), nullable=False)
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)

View File

@@ -81,7 +81,7 @@ dependencies = [
"starlette==0.49.1",
"tiktoken~=0.9.0",
"transformers~=4.56.1",
"unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
"yarl~=1.18.3",
"webvtt-py~=0.5.1",
"sseclient-py~=1.8.0",

View File

@@ -1,114 +0,0 @@
"""
Scheduled task to batch-update API token last_used_at timestamps.
Instead of updating the database on every request, token usage is recorded
in Redis as lightweight SET keys (api_token_active:{scope}:{token}).
This task runs periodically (default every 30 minutes) to flush those
records into the database in a single batch operation.
"""
import logging
import time
from datetime import datetime
import click
from sqlalchemy import update
from sqlalchemy.orm import Session
import app
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import ApiToken
from services.api_token_service import ACTIVE_TOKEN_KEY_PREFIX
logger = logging.getLogger(__name__)
@app.celery.task(queue="api_token")
def batch_update_api_token_last_used():
"""
Batch update last_used_at for all recently active API tokens.
Scans Redis for api_token_active:* keys, parses the token and scope
from each key, and performs a batch database update.
"""
click.echo(click.style("batch_update_api_token_last_used: start.", fg="green"))
start_at = time.perf_counter()
updated_count = 0
scanned_count = 0
try:
# Collect all active token keys and their values (the actual usage timestamps)
token_entries: list[tuple[str, str | None, datetime]] = [] # (token, scope, usage_time)
keys_to_delete: list[str | bytes] = []
for key in redis_client.scan_iter(match=f"{ACTIVE_TOKEN_KEY_PREFIX}*", count=200):
if isinstance(key, bytes):
key = key.decode("utf-8")
scanned_count += 1
# Read the value (ISO timestamp recorded at actual request time)
value = redis_client.get(key)
if not value:
keys_to_delete.append(key)
continue
if isinstance(value, bytes):
value = value.decode("utf-8")
try:
usage_time = datetime.fromisoformat(value)
except (ValueError, TypeError):
logger.warning("Invalid timestamp in key %s: %s", key, value)
keys_to_delete.append(key)
continue
# Parse token info from key: api_token_active:{scope}:{token}
suffix = key[len(ACTIVE_TOKEN_KEY_PREFIX) :]
parts = suffix.split(":", 1)
if len(parts) == 2:
scope_str, token = parts
scope = None if scope_str == "None" else scope_str
token_entries.append((token, scope, usage_time))
keys_to_delete.append(key)
if not token_entries:
click.echo(click.style("batch_update_api_token_last_used: no active tokens found.", fg="yellow"))
# Still clean up any invalid keys
if keys_to_delete:
redis_client.delete(*keys_to_delete)
return
# Update each token in its own short transaction to avoid long transactions
for token, scope, usage_time in token_entries:
with Session(db.engine, expire_on_commit=False) as session, session.begin():
stmt = (
update(ApiToken)
.where(
ApiToken.token == token,
ApiToken.type == scope,
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < usage_time)),
)
.values(last_used_at=usage_time)
)
result = session.execute(stmt)
rowcount = getattr(result, "rowcount", 0)
if rowcount > 0:
updated_count += 1
# Delete processed keys from Redis
if keys_to_delete:
redis_client.delete(*keys_to_delete)
except Exception:
logger.exception("batch_update_api_token_last_used failed")
elapsed = time.perf_counter() - start_at
click.echo(
click.style(
f"batch_update_api_token_last_used: done. "
f"scanned={scanned_count}, updated={updated_count}, elapsed={elapsed:.2f}s",
fg="green",
)
)

View File

@@ -1,330 +0,0 @@
"""
API Token Service
Handles all API token caching, validation, and usage recording.
Includes Redis cache operations, database queries, and single-flight concurrency control.
"""
import logging
from datetime import datetime
from typing import Any
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
from libs.datetime_utils import naive_utc_now
from models.model import ApiToken
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------
# Pydantic DTO
# ---------------------------------------------------------------------
class CachedApiToken(BaseModel):
"""
Pydantic model for cached API token data.
This is NOT a SQLAlchemy model instance, but a plain Pydantic model
that mimics the ApiToken model interface for read-only access.
"""
id: str
app_id: str | None
tenant_id: str | None
type: str
token: str
last_used_at: datetime | None
created_at: datetime | None
def __repr__(self) -> str:
return f"<CachedApiToken id={self.id} type={self.type}>"
# ---------------------------------------------------------------------
# Cache configuration
# ---------------------------------------------------------------------
CACHE_KEY_PREFIX = "api_token"
CACHE_TTL_SECONDS = 600 # 10 minutes
CACHE_NULL_TTL_SECONDS = 60 # 1 minute for non-existent tokens
ACTIVE_TOKEN_KEY_PREFIX = "api_token_active:"
# ---------------------------------------------------------------------
# Cache class
# ---------------------------------------------------------------------
class ApiTokenCache:
"""
Redis cache wrapper for API tokens.
Handles serialization, deserialization, and cache invalidation.
"""
@staticmethod
def make_active_key(token: str, scope: str | None = None) -> str:
"""Generate Redis key for recording token usage."""
return f"{ACTIVE_TOKEN_KEY_PREFIX}{scope}:{token}"
@staticmethod
def _make_tenant_index_key(tenant_id: str) -> str:
"""Generate Redis key for tenant token index."""
return f"tenant_tokens:{tenant_id}"
@staticmethod
def _make_cache_key(token: str, scope: str | None = None) -> str:
"""Generate cache key for the given token and scope."""
scope_str = scope or "any"
return f"{CACHE_KEY_PREFIX}:{scope_str}:{token}"
@staticmethod
def _serialize_token(api_token: Any) -> bytes:
"""Serialize ApiToken object to JSON bytes."""
if isinstance(api_token, CachedApiToken):
return api_token.model_dump_json().encode("utf-8")
cached = CachedApiToken(
id=str(api_token.id),
app_id=str(api_token.app_id) if api_token.app_id else None,
tenant_id=str(api_token.tenant_id) if api_token.tenant_id else None,
type=api_token.type,
token=api_token.token,
last_used_at=api_token.last_used_at,
created_at=api_token.created_at,
)
return cached.model_dump_json().encode("utf-8")
@staticmethod
def _deserialize_token(cached_data: bytes | str) -> Any:
"""Deserialize JSON bytes/string back to a CachedApiToken Pydantic model."""
if cached_data in {b"null", "null"}:
return None
try:
if isinstance(cached_data, bytes):
cached_data = cached_data.decode("utf-8")
return CachedApiToken.model_validate_json(cached_data)
except (ValueError, Exception) as e:
logger.warning("Failed to deserialize token from cache: %s", e)
return None
@staticmethod
@redis_fallback(default_return=None)
def get(token: str, scope: str | None) -> Any | None:
"""Get API token from cache."""
cache_key = ApiTokenCache._make_cache_key(token, scope)
cached_data = redis_client.get(cache_key)
if cached_data is None:
logger.debug("Cache miss for token key: %s", cache_key)
return None
logger.debug("Cache hit for token key: %s", cache_key)
return ApiTokenCache._deserialize_token(cached_data)
@staticmethod
def _add_to_tenant_index(tenant_id: str | None, cache_key: str) -> None:
"""Add cache key to tenant index for efficient invalidation."""
if not tenant_id:
return
try:
index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
redis_client.sadd(index_key, cache_key)
redis_client.expire(index_key, CACHE_TTL_SECONDS + 60)
except Exception as e:
logger.warning("Failed to update tenant index: %s", e)
@staticmethod
def _remove_from_tenant_index(tenant_id: str | None, cache_key: str) -> None:
"""Remove cache key from tenant index."""
if not tenant_id:
return
try:
index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
redis_client.srem(index_key, cache_key)
except Exception as e:
logger.warning("Failed to remove from tenant index: %s", e)
@staticmethod
@redis_fallback(default_return=False)
def set(token: str, scope: str | None, api_token: Any | None, ttl: int = CACHE_TTL_SECONDS) -> bool:
"""Set API token in cache."""
cache_key = ApiTokenCache._make_cache_key(token, scope)
if api_token is None:
cached_value = b"null"
ttl = CACHE_NULL_TTL_SECONDS
else:
cached_value = ApiTokenCache._serialize_token(api_token)
try:
redis_client.setex(cache_key, ttl, cached_value)
if api_token is not None and hasattr(api_token, "tenant_id"):
ApiTokenCache._add_to_tenant_index(api_token.tenant_id, cache_key)
logger.debug("Cached token with key: %s, ttl: %ss", cache_key, ttl)
return True
except Exception as e:
logger.warning("Failed to cache token: %s", e)
return False
@staticmethod
@redis_fallback(default_return=False)
def delete(token: str, scope: str | None = None) -> bool:
"""Delete API token from cache."""
if scope is None:
pattern = f"{CACHE_KEY_PREFIX}:*:{token}"
try:
keys_to_delete = list(redis_client.scan_iter(match=pattern))
if keys_to_delete:
redis_client.delete(*keys_to_delete)
logger.info("Deleted %d cache entries for token", len(keys_to_delete))
return True
except Exception as e:
logger.warning("Failed to delete token cache with pattern: %s", e)
return False
else:
cache_key = ApiTokenCache._make_cache_key(token, scope)
try:
tenant_id = None
try:
cached_data = redis_client.get(cache_key)
if cached_data and cached_data != b"null":
cached_token = ApiTokenCache._deserialize_token(cached_data)
if cached_token:
tenant_id = cached_token.tenant_id
except Exception as e:
logger.debug("Failed to get tenant_id for cache cleanup: %s", e)
redis_client.delete(cache_key)
if tenant_id:
ApiTokenCache._remove_from_tenant_index(tenant_id, cache_key)
logger.info("Deleted cache for key: %s", cache_key)
return True
except Exception as e:
logger.warning("Failed to delete token cache: %s", e)
return False
@staticmethod
@redis_fallback(default_return=False)
def invalidate_by_tenant(tenant_id: str) -> bool:
"""Invalidate all API token caches for a specific tenant via tenant index."""
try:
index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
cache_keys = redis_client.smembers(index_key)
if cache_keys:
deleted_count = 0
for cache_key in cache_keys:
if isinstance(cache_key, bytes):
cache_key = cache_key.decode("utf-8")
redis_client.delete(cache_key)
deleted_count += 1
redis_client.delete(index_key)
logger.info(
"Invalidated %d token cache entries for tenant: %s",
deleted_count,
tenant_id,
)
else:
logger.info(
"No tenant index found for %s, relying on TTL expiration",
tenant_id,
)
return True
except Exception as e:
logger.warning("Failed to invalidate tenant token cache: %s", e)
return False
# ---------------------------------------------------------------------
# Token usage recording (for batch update)
# ---------------------------------------------------------------------
def record_token_usage(auth_token: str, scope: str | None) -> None:
"""
Record token usage in Redis for later batch update by a scheduled job.
Instead of dispatching a Celery task per request, we simply SET a key in Redis.
A Celery Beat scheduled task will periodically scan these keys and batch-update
last_used_at in the database.
"""
try:
key = ApiTokenCache.make_active_key(auth_token, scope)
redis_client.set(key, naive_utc_now().isoformat(), ex=3600)
except Exception as e:
logger.warning("Failed to record token usage: %s", e)
# ---------------------------------------------------------------------
# Database query + single-flight
# ---------------------------------------------------------------------
def query_token_from_db(auth_token: str, scope: str | None) -> ApiToken:
"""
Query API token from database and cache the result.
Raises Unauthorized if token is invalid.
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
if not api_token:
ApiTokenCache.set(auth_token, scope, None)
raise Unauthorized("Access token is invalid")
ApiTokenCache.set(auth_token, scope, api_token)
record_token_usage(auth_token, scope)
return api_token
def fetch_token_with_single_flight(auth_token: str, scope: str | None) -> ApiToken | Any:
"""
Fetch token from DB with single-flight pattern using Redis lock.
Ensures only one concurrent request queries the database for the same token.
Falls back to direct query if lock acquisition fails.
"""
logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope)
lock_key = f"api_token_query_lock:{scope}:{auth_token}"
lock = redis_client.lock(lock_key, timeout=10, blocking_timeout=5)
try:
if lock.acquire(blocking=True):
try:
cached_token = ApiTokenCache.get(auth_token, scope)
if cached_token is not None:
logger.debug("Token cached by concurrent request, using cached version")
return cached_token
return query_token_from_db(auth_token, scope)
finally:
lock.release()
else:
logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10])
return query_token_from_db(auth_token, scope)
except Unauthorized:
raise
except Exception as e:
logger.warning("Redis lock failed for token query: %s, proceeding anyway", e)
return query_token_from_db(auth_token, scope)

View File

@@ -155,11 +155,11 @@ class AsyncWorkflowService:
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
task = execute_workflow_professional.delay(task_data_dict) # type: ignore
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
task = execute_workflow_team.delay(task_data_dict) # type: ignore
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
task = execute_workflow_sandbox.delay(task_data_dict) # type: ignore
# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED
@@ -170,7 +170,7 @@ class AsyncWorkflowService:
return AsyncTriggerResponse(
workflow_trigger_log_id=trigger_log.id,
task_id=task.id,
task_id=task.id, # type: ignore
status="queued",
queue=queue_name,
)

View File

@@ -1696,18 +1696,13 @@ class DocumentService:
for document in documents
if document.data_source_type == "upload_file" and document.data_source_info_dict
]
if dataset.doc_form is not None:
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
# Delete documents first, then dispatch cleanup task after commit
# to avoid deadlock between main transaction and async task
for document in documents:
db.session.delete(document)
db.session.commit()
# Dispatch cleanup task after commit to avoid lock contention
# Task cleans up segments, files, and vector indexes
if dataset.doc_form is not None:
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
@staticmethod
def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
assert isinstance(current_user, Account)

View File

@@ -16,25 +16,6 @@ class EndUserService:
Service for managing end users.
"""
@classmethod
def get_end_user_by_id(cls, *, tenant_id: str, app_id: str, end_user_id: str) -> EndUser | None:
"""Get an end user by primary key.
This is scoped to the provided tenant and app to prevent cross-tenant/app access
when an end-user ID is known.
"""
with Session(db.engine, expire_on_commit=False) as session:
return (
session.query(EndUser)
.where(
EndUser.id == end_user_id,
EndUser.tenant_id == tenant_id,
EndUser.app_id == app_id,
)
.first()
)
@classmethod
def get_or_create_end_user(cls, app_model: App, user_id: str | None = None) -> EndUser:
"""

View File

@@ -1329,24 +1329,10 @@ class RagPipelineService:
"""
Get datasource plugins
"""
dataset: Dataset | None = (
db.session.query(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
)
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
)
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
@@ -1427,24 +1413,10 @@ class RagPipelineService:
"""
Get pipeline
"""
dataset: Dataset | None = (
db.session.query(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
)
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
)
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
return pipeline

View File

@@ -10,7 +10,6 @@ from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db

View File

@@ -335,7 +335,7 @@ def _build_node_finished_event(
inputs=None,
process_data=None,
outputs=None,
status=WorkflowNodeExecutionStatus(snapshot.status),
status=snapshot.status,
error=None,
elapsed_time=snapshot.elapsed_time,
execution_metadata=None,
@@ -373,7 +373,7 @@ def _build_pause_event(
paused_nodes=paused_nodes,
outputs=outputs,
reasons=reasons,
status=workflow_run.status,
status=workflow_run.status.value,
created_at=int(workflow_run.created_at.timestamp()),
elapsed_time=float(workflow_run.elapsed_time or 0.0),
total_tokens=int(workflow_run.total_tokens or 0),

View File

@@ -6,6 +6,7 @@ from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@@ -57,3 +58,5 @@ def add_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()

View File

@@ -5,6 +5,7 @@ import click
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@@ -39,3 +40,5 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Annotation deleted index failed")
finally:
db.session.close()

View File

@@ -6,6 +6,7 @@ from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@@ -58,3 +59,5 @@ def update_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()

View File

@@ -14,9 +14,6 @@ from models.model import UploadFile
logger = logging.getLogger(__name__)
# Batch size for database operations to keep transactions short
BATCH_SIZE = 1000
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
@@ -34,179 +31,63 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
if not doc_form:
raise ValueError("doc_form is required")
storage_keys_to_delete: list[str] = []
index_node_ids: list[str] = []
segment_ids: list[str] = []
total_image_upload_file_ids: list[str] = []
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
).delete(synchronize_session=False)
try:
# ============ Step 1: Query segment and file data (short read-only transaction) ============
with session_factory.create_session() as session:
# Get segments info
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
# Collect image file IDs from segment content
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
total_image_upload_file_ids.extend(image_upload_file_ids)
# Query storage keys for image files
if total_image_upload_file_ids:
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(total_image_upload_file_ids))
).all()
storage_keys_to_delete.extend([f.key for f in image_files if f and f.key])
# Query storage keys for document files
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
for image_file in image_files:
try:
if image_file and image_file.key:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(stmt)
session.delete(segment)
if file_ids:
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
storage_keys_to_delete.extend([f.key for f in files if f and f.key])
for file in files:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
# ============ Step 2: Clean vector index (external service, fresh session for dataset) ============
if index_node_ids:
try:
# Fetch dataset in a fresh session to avoid DetachedInstanceError
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id)
else:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
except Exception:
logger.exception(
"Failed to clean vector index for dataset_id: %s, document_ids: %s, index_node_ids count: %d",
dataset_id,
document_ids,
len(index_node_ids),
)
session.commit()
# ============ Step 3: Delete metadata binding (separate short transaction) ============
try:
with session_factory.create_session() as session:
deleted_count = (
session.query(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
)
.delete(synchronize_session=False)
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
fg="green",
)
session.commit()
logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id)
)
except Exception:
logger.exception(
"Failed to delete metadata bindings for dataset_id: %s, document_ids: %s",
dataset_id,
document_ids,
)
# ============ Step 4: Batch delete UploadFile records (multiple short transactions) ============
if total_image_upload_file_ids:
failed_batches = 0
total_batches = (len(total_image_upload_file_ids) + BATCH_SIZE - 1) // BATCH_SIZE
for i in range(0, len(total_image_upload_file_ids), BATCH_SIZE):
batch = total_image_upload_file_ids[i : i + BATCH_SIZE]
try:
with session_factory.create_session() as session:
stmt = delete(UploadFile).where(UploadFile.id.in_(batch))
session.execute(stmt)
session.commit()
except Exception:
failed_batches += 1
logger.exception(
"Failed to delete image UploadFile batch %d-%d for dataset_id: %s",
i,
i + len(batch),
dataset_id,
)
if failed_batches > 0:
logger.warning(
"Image UploadFile deletion: %d/%d batches failed for dataset_id: %s",
failed_batches,
total_batches,
dataset_id,
)
# ============ Step 5: Batch delete DocumentSegment records (multiple short transactions) ============
if segment_ids:
failed_batches = 0
total_batches = (len(segment_ids) + BATCH_SIZE - 1) // BATCH_SIZE
for i in range(0, len(segment_ids), BATCH_SIZE):
batch = segment_ids[i : i + BATCH_SIZE]
try:
with session_factory.create_session() as session:
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(batch))
session.execute(segment_delete_stmt)
session.commit()
except Exception:
failed_batches += 1
logger.exception(
"Failed to delete DocumentSegment batch %d-%d for dataset_id: %s, document_ids: %s",
i,
i + len(batch),
dataset_id,
document_ids,
)
if failed_batches > 0:
logger.warning(
"DocumentSegment deletion: %d/%d batches failed, document_ids: %s",
failed_batches,
total_batches,
document_ids,
)
# ============ Step 6: Delete document-associated files (separate short transaction) ============
if file_ids:
try:
with session_factory.create_session() as session:
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
session.commit()
except Exception:
logger.exception(
"Failed to delete document UploadFile records for dataset_id: %s, file_ids: %s",
dataset_id,
file_ids,
)
# ============ Step 7: Delete storage files (I/O operations, no DB transaction) ============
storage_delete_failures = 0
for storage_key in storage_keys_to_delete:
try:
storage.delete(storage_key)
except Exception:
storage_delete_failures += 1
logger.exception("Failed to delete file from storage, key: %s", storage_key)
if storage_delete_failures > 0:
logger.warning(
"Storage file deletion completed with %d failures out of %d total files for dataset_id: %s",
storage_delete_failures,
len(storage_keys_to_delete),
dataset_id,
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at:.2f}s, "
f"dataset_id: {dataset_id}, document_ids: {document_ids}, "
f"segments: {len(segment_ids)}, image_files: {len(total_image_upload_file_ids)}, "
f"storage_files: {len(storage_keys_to_delete)}",
fg="green",
)
)
except Exception:
logger.exception(
"Batch clean documents failed for dataset_id: %s, document_ids: %s",
dataset_id,
document_ids,
)
logger.exception("Cleaned documents when documents deleted failed")

View File

@@ -48,11 +48,6 @@ def batch_create_segment_to_index_task(
indexing_cache_key = f"segment_batch_import_{job_id}"
# Initialize variables with default values
upload_file_key: str | None = None
dataset_config: dict | None = None
document_config: dict | None = None
with session_factory.create_session() as session:
try:
dataset = session.get(Dataset, dataset_id)
@@ -74,115 +69,86 @@ def batch_create_segment_to_index_task(
if not upload_file:
raise ValueError("UploadFile not found.")
dataset_config = {
"id": dataset.id,
"indexing_technique": dataset.indexing_technique,
"tenant_id": dataset.tenant_id,
"embedding_model_provider": dataset.embedding_model_provider,
"embedding_model": dataset.embedding_model,
}
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
document_config = {
"id": dataset_document.id,
"doc_form": dataset_document.doc_form,
"word_count": dataset_document.word_count or 0,
}
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
upload_file_key = upload_file.key
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
return
# Ensure required variables are set before proceeding
if upload_file_key is None or dataset_config is None or document_config is None:
logger.error("Required configuration not set due to session error")
redis_client.setex(indexing_cache_key, 600, "error")
return
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file_key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file_key, file_path)
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if document_config["doc_form"] == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
texts=[segment["content"] for segment in content]
)
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
tokens_list = [0] * len(content)
document_segments = []
embedding_model = None
if dataset_config["indexing_technique"] == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset_config["tenant_id"],
provider=dataset_config["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=dataset_config["embedding_model"],
)
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if dataset_document.doc_form == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
else:
tokens_list = [0] * len(content)
with session_factory.create_session() as session, session.begin():
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document_config["id"])
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if document_config["doc_form"] == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)
with session_factory.create_session() as session, session.begin():
dataset_document = session.get(Document, document_id)
if dataset_document:
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
session.add(dataset_document)
with session_factory.create_session() as session:
dataset = session.get(Dataset, dataset_id)
if dataset:
VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"])
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")

View File

@@ -28,7 +28,6 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
"""
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter()
total_attachment_files = []
with session_factory.create_session() as session:
try:
@@ -48,91 +47,78 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
SegmentAttachmentBinding.document_id == document_id,
)
).all()
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])
index_node_ids = [segment.index_node_id for segment in segments]
segment_contents = [segment.content for segment in segments]
except Exception:
logger.exception("Cleaned document when document deleted failed")
return
# check segment is exist
if index_node_ids:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
total_image_files = []
with session_factory.create_session() as session, session.begin():
for segment_content in segment_contents:
image_upload_file_ids = get_image_upload_file_ids(segment_content)
image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all()
total_image_files.extend([image_file.key for image_file in image_files])
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
).all()
for image_file in image_files:
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)
session.delete(segment)
for image_file_key in total_image_files:
try:
storage.delete(image_file_key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file_key,
session.commit()
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)
# delete segment attachments
if attachments_with_bindings:
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)
binding_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(binding_ids)
)
session.execute(binding_delete_stmt)
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
with session_factory.create_session() as session, session.begin():
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)
with session_factory.create_session() as session, session.begin():
# delete segment attachments
if attachment_ids:
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)
if binding_ids:
binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids))
session.execute(binding_delete_stmt)
for attachment_file_key in total_attachment_files:
try:
storage.delete(attachment_file_key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
attachment_file_key,
)
with session_factory.create_session() as session, session.begin():
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
logger.exception("Cleaned document when document deleted failed")

View File

@@ -23,40 +23,40 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
"""
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
total_index_node_ids = []
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
session.execute(document_delete_stmt)
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
session.execute(document_delete_stmt)
for document_id in document_ids:
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
total_index_node_ids.extend([segment.index_node_id for segment in segments])
for document_id in document_ids:
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
index_processor.clean(
dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
)
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
session.execute(segment_delete_stmt)
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when import form notion document deleted failed")

View File

@@ -3,7 +3,6 @@ import time
import click
from celery import shared_task
from sqlalchemy import delete
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -68,14 +67,8 @@ def delete_segment_from_index_task(
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
segment_attachment_bind_ids = [i.id for i in segment_attachment_bindings]
for i in range(0, len(segment_attachment_bind_ids), 1000):
segment_attachment_bind_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(segment_attachment_bind_ids[i : i + 1000])
)
session.execute(segment_attachment_bind_delete_stmt)
for binding in segment_attachment_bindings:
session.delete(binding)
# delete upload file
session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
session.commit()

View File

@@ -27,129 +27,104 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
"""
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter()
tenant_id = None
with session_factory.create_session() as session, session.begin():
with session_factory.create_session() as session:
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
if document.indexing_status == "parsing":
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
return
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
data_source_info = document.data_source_info_dict
if document.data_source_type != "notion_import":
logger.info(click.style(f"Document {document_id} is not a notion_import, skipping", fg="yellow"))
return
if document.data_source_type == "notion_import":
if (
not data_source_info
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
if (
not data_source_info
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=document.tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
tenant_id = document.tenant_id
index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
tenant_id,
credential_id,
)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
document.tenant_id,
credential_id,
)
document.indexing_status = "error"
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
return
session.commit()
return
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=tenant_id,
)
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=document.tenant_id,
)
last_edited_time = loader.get_notion_last_edited_time()
if last_edited_time == page_edited_time:
logger.info(click.style(f"Document {document_id} content unchanged, skipping sync", fg="yellow"))
return
last_edited_time = loader.get_notion_last_edited_time()
logger.info(click.style(f"Document {document_id} content changed, starting sync", fg="green"))
# check the page is updated
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()
try:
index_processor = IndexProcessorFactory(index_type).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
except Exception:
logger.exception("Failed to clean vector index for document %s", document_id)
# delete all document segment and index
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if not document:
logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
return
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
data_source_info = document.data_source_info_dict
data_source_info["last_edited_time"] = last_edited_time
document.data_source_info = data_source_info
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
logger.info(click.style(f"Deleted segments for document {document_id}", fg="green"))
try:
indexing_runner = IndexingRunner()
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
if document:
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"Sync completed for document {document_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception as e:
logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)

View File

@@ -81,35 +81,26 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
session.commit()
return
# Phase 1: Update status to parsing (short transaction)
with session_factory.create_session() as session, session.begin():
documents = (
session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
)
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
for document in documents:
if document:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
session.add(document)
# Transaction committed and closed
session.commit()
# Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions)
has_error = False
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
has_error = True
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
has_error = True
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
if not has_error:
with session_factory.create_session() as session:
# Trigger summary index generation for completed documents if enabled
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
# Re-query dataset to get latest summary_index_setting (in case it was updated)
@@ -124,18 +115,17 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
# expire all session to get latest document's indexing status
session.expire_all()
# Check each document's indexing status and trigger summary generation if completed
documents = (
session.query(Document)
.where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
.all()
)
for document in documents:
for document_id in document_ids:
# Re-query document to get latest status (IndexingRunner may have updated it)
document = (
session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
logger.info(
"Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
document.id,
document_id,
document.indexing_status,
document.doc_form,
document.need_summary,
@@ -146,36 +136,46 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
and document.need_summary is True
):
try:
generate_summary_index_task.delay(dataset.id, document.id, None)
generate_summary_index_task.delay(dataset.id, document_id, None)
logger.info(
"Queued summary index generation task for document %s in dataset %s "
"after indexing completed",
document.id,
document_id,
dataset.id,
)
except Exception:
logger.exception(
"Failed to queue summary index generation task for document %s",
document.id,
document_id,
)
# Don't fail the entire indexing process if summary task queuing fails
else:
logger.info(
"Skipping summary generation for document %s: "
"status=%s, doc_form=%s, need_summary=%s",
document.id,
document_id,
document.indexing_status,
document.doc_form,
document.need_summary,
)
else:
logger.warning("Document %s not found after indexing", document.id)
logger.warning("Document %s not found after indexing", document_id)
else:
logger.info(
"Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
dataset.id,
summary_index_setting.get("enable") if summary_index_setting else None,
)
else:
logger.info(
"Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
dataset.id,
dataset.indexing_technique,
)
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
def _document_indexing_with_tenant_queue(

View File

@@ -36,19 +36,25 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return
# delete all document segment and index
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
clean_success = False
try:
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if index_node_ids:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
end_at = time.perf_counter()
logger.info(
click.style(
@@ -58,21 +64,15 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
fg="green",
)
)
clean_success = True
except Exception:
logger.exception("Failed to clean document index during update, document_id: %s", document_id)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
if clean_success:
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)

View File

@@ -48,7 +48,6 @@ from models.workflow import (
WorkflowArchiveLog,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.api_token_service import ApiTokenCache
logger = logging.getLogger(__name__)
@@ -135,12 +134,6 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(session, api_token_id: str):
# Fetch token details for cache invalidation
token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
if token_obj:
# Invalidate cache before deletion
ApiTokenCache.delete(token_obj.token, token_obj.type)
session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(

View File

@@ -6,8 +6,9 @@ improving performance by offloading storage operations to background workers.
"""
from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
from extensions.ext_database import db
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
@@ -16,6 +17,6 @@ def save_workflow_execution_task(
self,
deletions: list[DraftVarFileDeletion],
):
with session_factory.create_session() as session, session.begin():
with Session(bind=db.engine) as session, session.begin():
srv = WorkflowDraftVariableService(session=session)
srv.delete_workflow_draft_variable_file(deletions=deletions)

View File

@@ -1,375 +0,0 @@
"""
Integration tests for API Token Cache with Redis.
These tests require:
- Redis server running
- Test database configured
"""
import time
from datetime import datetime, timedelta
from unittest.mock import patch
import pytest
from extensions.ext_redis import redis_client
from models.model import ApiToken
from services.api_token_service import ApiTokenCache, CachedApiToken
class TestApiTokenCacheRedisIntegration:
"""Integration tests with real Redis."""
def setup_method(self):
"""Setup test fixtures and clean Redis."""
self.test_token = "test-integration-token-123"
self.test_scope = "app"
self.cache_key = f"api_token:{self.test_scope}:{self.test_token}"
# Clean up any existing test data
self._cleanup()
def teardown_method(self):
"""Cleanup test data from Redis."""
self._cleanup()
def _cleanup(self):
"""Remove test data from Redis."""
try:
redis_client.delete(self.cache_key)
redis_client.delete(ApiTokenCache._make_tenant_index_key("test-tenant-id"))
redis_client.delete(ApiTokenCache.make_active_key(self.test_token, self.test_scope))
except Exception:
pass # Ignore cleanup errors
def test_cache_set_and_get_with_real_redis(self):
"""Test cache set and get operations with real Redis."""
from unittest.mock import MagicMock
mock_token = MagicMock()
mock_token.id = "test-id-123"
mock_token.app_id = "test-app-456"
mock_token.tenant_id = "test-tenant-789"
mock_token.type = "app"
mock_token.token = self.test_token
mock_token.last_used_at = datetime.now()
mock_token.created_at = datetime.now() - timedelta(days=30)
# Set in cache
result = ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
assert result is True
# Verify in Redis
cached_data = redis_client.get(self.cache_key)
assert cached_data is not None
# Get from cache
cached_token = ApiTokenCache.get(self.test_token, self.test_scope)
assert cached_token is not None
assert isinstance(cached_token, CachedApiToken)
assert cached_token.id == "test-id-123"
assert cached_token.app_id == "test-app-456"
assert cached_token.tenant_id == "test-tenant-789"
assert cached_token.type == "app"
assert cached_token.token == self.test_token
def test_cache_ttl_with_real_redis(self):
"""Test cache TTL is set correctly."""
from unittest.mock import MagicMock
mock_token = MagicMock()
mock_token.id = "test-id"
mock_token.app_id = "test-app"
mock_token.tenant_id = "test-tenant"
mock_token.type = "app"
mock_token.token = self.test_token
mock_token.last_used_at = None
mock_token.created_at = datetime.now()
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
ttl = redis_client.ttl(self.cache_key)
assert 595 <= ttl <= 600 # Should be around 600 seconds (10 minutes)
def test_cache_null_value_for_invalid_token(self):
"""Test caching null value for invalid tokens."""
result = ApiTokenCache.set(self.test_token, self.test_scope, None)
assert result is True
cached_data = redis_client.get(self.cache_key)
assert cached_data == b"null"
cached_token = ApiTokenCache.get(self.test_token, self.test_scope)
assert cached_token is None
ttl = redis_client.ttl(self.cache_key)
assert 55 <= ttl <= 60
def test_cache_delete_with_real_redis(self):
"""Test cache deletion with real Redis."""
from unittest.mock import MagicMock
mock_token = MagicMock()
mock_token.id = "test-id"
mock_token.app_id = "test-app"
mock_token.tenant_id = "test-tenant"
mock_token.type = "app"
mock_token.token = self.test_token
mock_token.last_used_at = None
mock_token.created_at = datetime.now()
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
assert redis_client.exists(self.cache_key) == 1
result = ApiTokenCache.delete(self.test_token, self.test_scope)
assert result is True
assert redis_client.exists(self.cache_key) == 0
def test_tenant_index_creation(self):
"""Test tenant index is created when caching token."""
from unittest.mock import MagicMock
tenant_id = "test-tenant-id"
mock_token = MagicMock()
mock_token.id = "test-id"
mock_token.app_id = "test-app"
mock_token.tenant_id = tenant_id
mock_token.type = "app"
mock_token.token = self.test_token
mock_token.last_used_at = None
mock_token.created_at = datetime.now()
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
assert redis_client.exists(index_key) == 1
members = redis_client.smembers(index_key)
cache_keys = [m.decode("utf-8") if isinstance(m, bytes) else m for m in members]
assert self.cache_key in cache_keys
def test_invalidate_by_tenant_via_index(self):
"""Test tenant-wide cache invalidation using index (fast path)."""
from unittest.mock import MagicMock
tenant_id = "test-tenant-id"
for i in range(3):
token_value = f"test-token-{i}"
mock_token = MagicMock()
mock_token.id = f"test-id-{i}"
mock_token.app_id = "test-app"
mock_token.tenant_id = tenant_id
mock_token.type = "app"
mock_token.token = token_value
mock_token.last_used_at = None
mock_token.created_at = datetime.now()
ApiTokenCache.set(token_value, "app", mock_token)
for i in range(3):
key = f"api_token:app:test-token-{i}"
assert redis_client.exists(key) == 1
result = ApiTokenCache.invalidate_by_tenant(tenant_id)
assert result is True
for i in range(3):
key = f"api_token:app:test-token-{i}"
assert redis_client.exists(key) == 0
assert redis_client.exists(ApiTokenCache._make_tenant_index_key(tenant_id)) == 0
def test_concurrent_cache_access(self):
"""Test concurrent cache access doesn't cause issues."""
import concurrent.futures
from unittest.mock import MagicMock
mock_token = MagicMock()
mock_token.id = "test-id"
mock_token.app_id = "test-app"
mock_token.tenant_id = "test-tenant"
mock_token.type = "app"
mock_token.token = self.test_token
mock_token.last_used_at = None
mock_token.created_at = datetime.now()
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
def get_from_cache():
return ApiTokenCache.get(self.test_token, self.test_scope)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [executor.submit(get_from_cache) for _ in range(50)]
results = [f.result() for f in concurrent.futures.as_completed(futures)]
assert len(results) == 50
assert all(r is not None for r in results)
assert all(isinstance(r, CachedApiToken) for r in results)
class TestTokenUsageRecording:
"""Tests for recording token usage in Redis (batch update approach)."""
def setup_method(self):
self.test_token = "test-usage-token"
self.test_scope = "app"
self.active_key = ApiTokenCache.make_active_key(self.test_token, self.test_scope)
def teardown_method(self):
try:
redis_client.delete(self.active_key)
except Exception:
pass
def test_record_token_usage_sets_redis_key(self):
"""Test that record_token_usage writes an active key to Redis."""
from services.api_token_service import record_token_usage
record_token_usage(self.test_token, self.test_scope)
# Key should exist
assert redis_client.exists(self.active_key) == 1
# Value should be an ISO timestamp
value = redis_client.get(self.active_key)
if isinstance(value, bytes):
value = value.decode("utf-8")
datetime.fromisoformat(value) # Should not raise
def test_record_token_usage_has_ttl(self):
"""Test that active keys have a TTL as safety net."""
from services.api_token_service import record_token_usage
record_token_usage(self.test_token, self.test_scope)
ttl = redis_client.ttl(self.active_key)
assert 3595 <= ttl <= 3600 # ~1 hour
def test_record_token_usage_overwrites(self):
"""Test that repeated calls overwrite the same key (no accumulation)."""
from services.api_token_service import record_token_usage
record_token_usage(self.test_token, self.test_scope)
first_value = redis_client.get(self.active_key)
time.sleep(0.01) # Tiny delay so timestamp differs
record_token_usage(self.test_token, self.test_scope)
second_value = redis_client.get(self.active_key)
# Key count should still be 1 (overwritten, not accumulated)
assert redis_client.exists(self.active_key) == 1
class TestEndToEndCacheFlow:
"""End-to-end integration test for complete cache flow."""
@pytest.mark.usefixtures("db_session")
def test_complete_flow_cache_miss_then_hit(self, db_session):
"""
Test complete flow:
1. First request (cache miss) -> query DB -> cache result
2. Second request (cache hit) -> return from cache
3. Verify Redis state
"""
test_token_value = "test-e2e-token"
test_scope = "app"
test_token = ApiToken()
test_token.id = "test-e2e-id"
test_token.token = test_token_value
test_token.type = test_scope
test_token.app_id = "test-app"
test_token.tenant_id = "test-tenant"
test_token.last_used_at = None
test_token.created_at = datetime.now()
db_session.add(test_token)
db_session.commit()
try:
# Step 1: Cache miss - set token in cache
ApiTokenCache.set(test_token_value, test_scope, test_token)
cache_key = f"api_token:{test_scope}:{test_token_value}"
assert redis_client.exists(cache_key) == 1
# Step 2: Cache hit - get from cache
cached_token = ApiTokenCache.get(test_token_value, test_scope)
assert cached_token is not None
assert cached_token.id == test_token.id
assert cached_token.token == test_token_value
# Step 3: Verify tenant index
index_key = ApiTokenCache._make_tenant_index_key(test_token.tenant_id)
assert redis_client.exists(index_key) == 1
assert cache_key.encode() in redis_client.smembers(index_key)
# Step 4: Delete and verify cleanup
ApiTokenCache.delete(test_token_value, test_scope)
assert redis_client.exists(cache_key) == 0
assert cache_key.encode() not in redis_client.smembers(index_key)
finally:
db_session.delete(test_token)
db_session.commit()
redis_client.delete(f"api_token:{test_scope}:{test_token_value}")
redis_client.delete(ApiTokenCache._make_tenant_index_key(test_token.tenant_id))
def test_high_concurrency_simulation(self):
"""Simulate high concurrency access to cache."""
import concurrent.futures
from unittest.mock import MagicMock
test_token_value = "test-concurrent-token"
test_scope = "app"
mock_token = MagicMock()
mock_token.id = "concurrent-id"
mock_token.app_id = "test-app"
mock_token.tenant_id = "test-tenant"
mock_token.type = test_scope
mock_token.token = test_token_value
mock_token.last_used_at = datetime.now()
mock_token.created_at = datetime.now()
ApiTokenCache.set(test_token_value, test_scope, mock_token)
try:
def read_cache():
return ApiTokenCache.get(test_token_value, test_scope)
start_time = time.time()
with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
futures = [executor.submit(read_cache) for _ in range(100)]
results = [f.result() for f in concurrent.futures.as_completed(futures)]
elapsed = time.time() - start_time
assert len(results) == 100
assert all(r is not None for r in results)
assert elapsed < 1.0, f"Too slow: {elapsed}s for 100 cache reads"
finally:
ApiTokenCache.delete(test_token_value, test_scope)
redis_client.delete(ApiTokenCache._make_tenant_index_key(mock_token.tenant_id))
class TestRedisFailover:
"""Test behavior when Redis is unavailable."""
@patch("services.api_token_service.redis_client")
def test_graceful_degradation_when_redis_fails(self, mock_redis):
"""Test system degrades gracefully when Redis is unavailable."""
from redis import RedisError
mock_redis.get.side_effect = RedisError("Connection failed")
mock_redis.setex.side_effect = RedisError("Connection failed")
result_get = ApiTokenCache.get("test-token", "app")
assert result_get is None
result_set = ApiTokenCache.set("test-token", "app", None)
assert result_set is False

View File

@@ -1,29 +0,0 @@
"""
Integration tests for KnowledgeRetrievalNode.
This module provides integration tests for KnowledgeRetrievalNode with real database interactions.
Note: These tests require database setup and are more complex than unit tests.
For now, we focus on unit tests which provide better coverage for the node logic.
"""
import pytest
class TestKnowledgeRetrievalNodeIntegration:
"""
Integration test suite for KnowledgeRetrievalNode.
Note: Full integration tests require:
- Database setup with datasets and documents
- Vector store for embeddings
- Model providers for retrieval
For now, unit tests provide comprehensive coverage of the node logic.
"""
@pytest.mark.skip(reason="Integration tests require full database and vector store setup")
def test_end_to_end_knowledge_retrieval(self):
"""Test end-to-end knowledge retrieval workflow."""
# TODO: Implement with real database
pass

View File

@@ -1,614 +0,0 @@
import uuid
from unittest.mock import patch
import pytest
from faker import Faker
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest
from models.dataset import Dataset, Document
from services.account_service import AccountService, TenantService
class TestGetAvailableDatasetsIntegration:
def test_returns_datasets_with_available_documents(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
# Create account and tenant
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create dataset
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
indexing_technique="high_quality",
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
# Create documents with completed status, enabled, not archived
for i in range(3):
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
name=f"Document {i}",
created_from="web",
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
enabled=True,
archived=False,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Act
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
# Assert
assert len(result) == 1
assert result[0].id == dataset.id
assert result[0].tenant_id == tenant.id
assert result[0].name == dataset.name
def test_filters_out_datasets_with_only_archived_documents(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
)
db_session_with_containers.add(dataset)
# Create only archived documents
for i in range(2):
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
created_from="web",
name=f"Archived Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
enabled=True,
archived=True, # Archived
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Act
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
# Assert
assert len(result) == 0
def test_filters_out_datasets_with_only_disabled_documents(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
)
db_session_with_containers.add(dataset)
# Create only disabled documents
for i in range(2):
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
created_from="web",
name=f"Disabled Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
enabled=False, # Disabled
archived=False,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Act
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
# Assert
assert len(result) == 0
def test_filters_out_datasets_with_non_completed_documents(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
)
db_session_with_containers.add(dataset)
# Create documents with non-completed status
for i, status in enumerate(["indexing", "parsing", "splitting"]):
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
created_from="web",
name=f"Document {status}",
created_by=account.id,
doc_form="text_model",
indexing_status=status, # Not completed
enabled=True,
archived=False,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Act
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
# Assert
assert len(result) == 0
def test_includes_external_datasets_without_documents(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that external datasets are returned even with no available documents.
External datasets (e.g., from external knowledge bases) don't have
documents stored in Dify's database, so they should always be available.
Verifies:
- External datasets are included in results
- No document count check for external datasets
"""
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="external", # External provider
data_source_type="external",
created_by=account.id,
)
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Act
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
# Assert
assert len(result) == 1
assert result[0].id == dataset.id
assert result[0].provider == "external"
def test_filters_by_tenant_id(self, db_session_with_containers, mock_external_service_dependencies):
# Arrange
fake = Faker()
# Create two accounts/tenants
account1 = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company())
tenant1 = account1.current_tenant
account2 = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company())
tenant2 = account2.current_tenant
# Create dataset for tenant1
dataset1 = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant1.id,
name="Tenant 1 Dataset",
provider="dify",
data_source_type="upload_file",
created_by=account1.id,
)
db_session_with_containers.add(dataset1)
# Create dataset for tenant2
dataset2 = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant2.id,
name="Tenant 2 Dataset",
provider="dify",
data_source_type="upload_file",
created_by=account2.id,
)
db_session_with_containers.add(dataset2)
# Add documents to both datasets
for dataset, account in [(dataset1, account1), (dataset2, account2)]:
document = Document(
id=str(uuid.uuid4()),
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
created_from="web",
name=f"Document for {dataset.name}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
enabled=True,
archived=False,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Act - request from tenant1, should only get tenant1's dataset
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant1.id, [dataset1.id, dataset2.id])
# Assert
assert len(result) == 1
assert result[0].id == dataset1.id
assert result[0].tenant_id == tenant1.id
def test_returns_empty_list_when_no_datasets_found(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Don't create any datasets
# Act
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant.id, [str(uuid.uuid4())])
# Assert
assert result == []
def test_returns_only_requested_dataset_ids(self, db_session_with_containers, mock_external_service_dependencies):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create multiple datasets
datasets = []
for i in range(3):
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=f"Dataset {i}",
provider="dify",
data_source_type="upload_file",
created_by=account.id,
)
db_session_with_containers.add(dataset)
datasets.append(dataset)
# Add document
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
created_from="web",
name=f"Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
enabled=True,
archived=False,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Act - request only dataset 0 and 2, not dataset 1
dataset_retrieval = DatasetRetrieval()
requested_ids = [datasets[0].id, datasets[2].id]
result = dataset_retrieval._get_available_datasets(tenant.id, requested_ids)
# Assert
assert len(result) == 2
returned_ids = {d.id for d in result}
assert returned_ids == {datasets[0].id, datasets[2].id}
class TestKnowledgeRetrievalIntegration:
def test_knowledge_retrieval_with_available_datasets(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
indexing_technique="high_quality",
)
db_session_with_containers.add(dataset)
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
created_from="web",
name=fake.sentence(),
created_by=account.id,
indexing_status="completed",
enabled=True,
archived=False,
doc_form="text_model",
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Create request
request = KnowledgeRetrievalRequest(
tenant_id=tenant.id,
user_id=account.id,
app_id=str(uuid.uuid4()),
user_from="web",
dataset_ids=[dataset.id],
query="test query",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock rate limit check and retrieval
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]):
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert isinstance(result, list)
def test_knowledge_retrieval_no_available_datasets(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create dataset but no documents
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
)
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
request = KnowledgeRetrievalRequest(
tenant_id=tenant.id,
user_id=account.id,
app_id=str(uuid.uuid4()),
user_from="web",
dataset_ids=[dataset.id],
query="test query",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock rate limit check
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert result == []
def test_knowledge_retrieval_rate_limit_exceeded(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
)
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
request = KnowledgeRetrievalRequest(
tenant_id=tenant.id,
user_id=account.id,
app_id=str(uuid.uuid4()),
user_from="web",
dataset_ids=[dataset.id],
query="test query",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock rate limit check to raise exception
with patch.object(
dataset_retrieval,
"_check_knowledge_rate_limit",
side_effect=Exception("Rate limit exceeded"),
):
# Act & Assert
with pytest.raises(Exception, match="Rate limit exceeded"):
dataset_retrieval.knowledge_retrieval(request)
@pytest.fixture
def mock_external_service_dependencies():
with (
patch("services.account_service.FeatureService") as mock_account_feature_service,
):
# Setup default mock returns for account service
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
yield {
"account_feature_service": mock_account_feature_service,
}

View File

@@ -605,20 +605,26 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage.download.side_effect = mock_download
# Execute the task - should raise ValueError for empty CSV
# Execute the task
job_id = str(uuid.uuid4())
with pytest.raises(ValueError, match="The CSV file is empty"):
batch_create_segment_to_index_task(
job_id=job_id,
upload_file_id=upload_file.id,
dataset_id=dataset.id,
document_id=document.id,
tenant_id=tenant.id,
user_id=account.id,
)
batch_create_segment_to_index_task(
job_id=job_id,
upload_file_id=upload_file.id,
dataset_id=dataset.id,
document_id=document.id,
tenant_id=tenant.id,
user_id=account.id,
)
# Verify error handling
# Since exception was raised, no segments should be created
# Check Redis cache was set to error status
from extensions.ext_redis import redis_client
cache_key = f"segment_batch_import_{job_id}"
cache_value = redis_client.get(cache_key)
assert cache_value == b"error"
# Verify no segments were created
from extensions.ext_database import db
segments = db.session.query(DocumentSegment).all()

View File

@@ -153,7 +153,8 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task
clean_notion_document_task(document_ids, dataset.id)
# Verify segments are deleted
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 0
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(document_ids))
@@ -161,9 +162,9 @@ class TestCleanNotionDocumentTask:
== 0
)
# Verify index processor was called
# Verify index processor was called for each document
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_called_once()
assert mock_processor.clean.call_count == len(document_ids)
# This test successfully verifies:
# 1. Document records are properly deleted from the database
@@ -185,12 +186,12 @@ class TestCleanNotionDocumentTask:
non_existent_dataset_id = str(uuid.uuid4())
document_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
# Execute cleanup task with non-existent dataset - expect exception
with pytest.raises(Exception, match="Document has no dataset"):
clean_notion_document_task(document_ids, non_existent_dataset_id)
# Execute cleanup task with non-existent dataset
clean_notion_document_task(document_ids, non_existent_dataset_id)
# Verify that the index processor factory was not used
mock_index_processor_factory.return_value.init_index_processor.assert_not_called()
# Verify that the index processor was not called
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_not_called()
def test_clean_notion_document_task_empty_document_list(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@@ -228,13 +229,9 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task with empty document list
clean_notion_document_task([], dataset.id)
# Verify that the index processor was called once with empty node list
# Verify that the index processor was not called
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
assert mock_processor.clean.call_count == 1
args, kwargs = mock_processor.clean.call_args
# args: (dataset, total_index_node_ids)
assert isinstance(args[0], Dataset)
assert args[1] == []
mock_processor.clean.assert_not_called()
def test_clean_notion_document_task_with_different_index_types(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@@ -318,7 +315,8 @@ class TestCleanNotionDocumentTask:
# Note: This test successfully verifies cleanup with different document types.
# The task properly handles various index types and document configurations.
# Verify segments are deleted
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == document.id)
@@ -406,7 +404,8 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task
clean_notion_document_task([document.id], dataset.id)
# Verify segments are deleted
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
== 0
@@ -509,7 +508,8 @@ class TestCleanNotionDocumentTask:
clean_notion_document_task(documents_to_clean, dataset.id)
# Verify only specified documents' segments are deleted
# Verify only specified documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id.in_(documents_to_clean)).count() == 0
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(documents_to_clean))
@@ -697,12 +697,11 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()
# Mock index processor to raise an exception
mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_index_processor = mock_index_processor_factory.init_index_processor.return_value
mock_index_processor.clean.side_effect = Exception("Index processor error")
# Execute cleanup task - current implementation propagates the exception
with pytest.raises(Exception, match="Index processor error"):
clean_notion_document_task([document.id], dataset.id)
# Execute cleanup task - it should handle the exception gracefully
clean_notion_document_task([document.id], dataset.id)
# Note: This test demonstrates the task's error handling capability.
# Even with external service errors, the database operations complete successfully.
@@ -804,7 +803,8 @@ class TestCleanNotionDocumentTask:
all_document_ids = [doc.id for doc in documents]
clean_notion_document_task(all_document_ids, dataset.id)
# Verify all segments are deleted
# Verify all documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
== 0
@@ -914,7 +914,8 @@ class TestCleanNotionDocumentTask:
clean_notion_document_task([target_document.id], target_dataset.id)
# Verify only documents' segments from target dataset are deleted
# Verify only documents from target dataset are deleted
assert db_session_with_containers.query(Document).filter(Document.id == target_document.id).count() == 0
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == target_document.id)
@@ -1029,7 +1030,8 @@ class TestCleanNotionDocumentTask:
all_document_ids = [doc.id for doc in documents]
clean_notion_document_task(all_document_ids, dataset.id)
# Verify all segments are deleted regardless of status
# Verify all documents and segments are deleted regardless of status
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
== 0
@@ -1140,7 +1142,8 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task
clean_notion_document_task([document.id], dataset.id)
# Verify segments are deleted
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
== 0

View File

@@ -9,7 +9,6 @@ from flask import Flask
from controllers.console import wraps as console_wraps
from controllers.console.app import workflow_run as workflow_run_module
from controllers.web.error import NotFoundError
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes.human_input.entities import FormInput, UserAction
@@ -54,7 +53,6 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte
monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com")
workflow_run = Mock(spec=WorkflowRun)
workflow_run.tenant_id = "tenant-123"
workflow_run.status = WorkflowExecutionStatus.PAUSED
workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0)
fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run))
@@ -91,20 +89,3 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte
== "https://web.example.com/form/backstage-token"
)
assert "pending_human_inputs" not in response
def test_pause_details_tenant_isolation(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
account = _make_account()
_patch_console_guards(monkeypatch, account)
monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com")
workflow_run = Mock(spec=WorkflowRun)
workflow_run.tenant_id = "tenant-456"
workflow_run.status = WorkflowExecutionStatus.PAUSED
workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0)
fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run))
monkeypatch.setattr(workflow_run_module, "db", fake_db)
with pytest.raises(NotFoundError):
with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"):
response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1")

View File

@@ -1,62 +0,0 @@
"""
Unit tests for Service API knowledge pipeline file-upload serialization.
"""
import importlib.util
from datetime import UTC, datetime
from pathlib import Path
class FakeUploadFile:
id: str
name: str
size: int
extension: str
mime_type: str
created_by: str
created_at: datetime | None
def _load_serialize_upload_file():
api_dir = Path(__file__).resolve().parents[5]
serializers_path = api_dir / "controllers" / "service_api" / "dataset" / "rag_pipeline" / "serializers.py"
spec = importlib.util.spec_from_file_location("rag_pipeline_serializers", serializers_path)
assert spec
assert spec.loader
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore[attr-defined]
return module.serialize_upload_file
def test_file_upload_created_at_is_isoformat_string():
serialize_upload_file = _load_serialize_upload_file()
created_at = datetime(2026, 2, 8, 12, 0, 0, tzinfo=UTC)
upload_file = FakeUploadFile()
upload_file.id = "file-1"
upload_file.name = "test.pdf"
upload_file.size = 123
upload_file.extension = "pdf"
upload_file.mime_type = "application/pdf"
upload_file.created_by = "account-1"
upload_file.created_at = created_at
result = serialize_upload_file(upload_file)
assert result["created_at"] == created_at.isoformat()
def test_file_upload_created_at_none_serializes_to_null():
serialize_upload_file = _load_serialize_upload_file()
upload_file = FakeUploadFile()
upload_file.id = "file-1"
upload_file.name = "test.pdf"
upload_file.size = 123
upload_file.extension = "pdf"
upload_file.mime_type = "application/pdf"
upload_file.created_by = "account-1"
upload_file.created_at = None
result = serialize_upload_file(upload_file)
assert result["created_at"] is None

View File

@@ -1,54 +0,0 @@
"""
Unit tests for Service API knowledge pipeline route registration.
"""
import ast
from pathlib import Path
def test_rag_pipeline_routes_registered():
api_dir = Path(__file__).resolve().parents[5]
service_api_init = api_dir / "controllers" / "service_api" / "__init__.py"
rag_pipeline_workflow = (
api_dir / "controllers" / "service_api" / "dataset" / "rag_pipeline" / "rag_pipeline_workflow.py"
)
assert service_api_init.exists()
assert rag_pipeline_workflow.exists()
init_tree = ast.parse(service_api_init.read_text(encoding="utf-8"))
import_found = False
for node in ast.walk(init_tree):
if not isinstance(node, ast.ImportFrom):
continue
if node.module != "dataset.rag_pipeline" or node.level != 1:
continue
if any(alias.name == "rag_pipeline_workflow" for alias in node.names):
import_found = True
break
assert import_found, "from .dataset.rag_pipeline import rag_pipeline_workflow not found in service_api/__init__.py"
workflow_tree = ast.parse(rag_pipeline_workflow.read_text(encoding="utf-8"))
route_paths: set[str] = set()
for node in ast.walk(workflow_tree):
if not isinstance(node, ast.ClassDef):
continue
for decorator in node.decorator_list:
if not isinstance(decorator, ast.Call):
continue
if not isinstance(decorator.func, ast.Attribute):
continue
if decorator.func.attr != "route":
continue
if not decorator.args:
continue
first_arg = decorator.args[0]
if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str):
route_paths.add(first_arg.value)
assert "/datasets/<uuid:dataset_id>/pipeline/datasource-plugins" in route_paths
assert "/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run" in route_paths
assert "/datasets/<uuid:dataset_id>/pipeline/run" in route_paths
assert "/datasets/pipeline/file-upload" in route_paths

View File

@@ -1,61 +0,0 @@
from datetime import UTC, datetime
from unittest.mock import Mock
from uuid import UUID, uuid4
import pytest
from controllers.service_api.end_user.end_user import EndUserApi
from controllers.service_api.end_user.error import EndUserNotFoundError
from models.model import App, EndUser
class TestEndUserApi:
@pytest.fixture
def resource(self) -> EndUserApi:
return EndUserApi()
@pytest.fixture
def app_model(self) -> App:
app = Mock(spec=App)
app.id = str(uuid4())
app.tenant_id = str(uuid4())
return app
def test_get_end_user_returns_all_attributes(self, mocker, resource: EndUserApi, app_model: App) -> None:
end_user = Mock(spec=EndUser)
end_user.id = str(uuid4())
end_user.tenant_id = app_model.tenant_id
end_user.app_id = app_model.id
end_user.type = "service_api"
end_user.external_user_id = "external-123"
end_user.name = "Alice"
end_user._is_anonymous = True
end_user.session_id = "session-xyz"
end_user.created_at = datetime(2024, 1, 1, tzinfo=UTC)
end_user.updated_at = datetime(2024, 1, 2, tzinfo=UTC)
get_end_user_by_id = mocker.patch(
"controllers.service_api.end_user.end_user.EndUserService.get_end_user_by_id", return_value=end_user
)
result = EndUserApi.get.__wrapped__(resource, app_model=app_model, end_user_id=UUID(end_user.id))
get_end_user_by_id.assert_called_once_with(
tenant_id=app_model.tenant_id, app_id=app_model.id, end_user_id=end_user.id
)
assert result["id"] == end_user.id
assert result["tenant_id"] == end_user.tenant_id
assert result["app_id"] == end_user.app_id
assert result["type"] == end_user.type
assert result["external_user_id"] == end_user.external_user_id
assert result["name"] == end_user.name
assert result["is_anonymous"] is True
assert result["session_id"] == end_user.session_id
assert result["created_at"].startswith("2024-01-01T00:00:00")
assert result["updated_at"].startswith("2024-01-02T00:00:00")
def test_get_end_user_not_found(self, mocker, resource: EndUserApi, app_model: App) -> None:
mocker.patch("controllers.service_api.end_user.end_user.EndUserService.get_end_user_by_id", return_value=None)
with pytest.raises(EndUserNotFoundError):
EndUserApi.get.__wrapped__(resource, app_model=app_model, end_user_id=uuid4())

View File

@@ -1,715 +0,0 @@
from unittest.mock import MagicMock, Mock, patch
from uuid import uuid4
import pytest
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest
from models.dataset import Dataset
# ==================== Helper Functions ====================
def create_mock_dataset(
dataset_id: str | None = None,
tenant_id: str | None = None,
provider: str = "dify",
indexing_technique: str = "high_quality",
available_document_count: int = 10,
) -> Mock:
"""
Create a mock Dataset object for testing.
Args:
dataset_id: Unique identifier for the dataset
tenant_id: Tenant ID for the dataset
provider: Provider type ("dify" or "external")
indexing_technique: Indexing technique ("high_quality" or "economy")
available_document_count: Number of available documents
Returns:
Mock: A properly configured Dataset mock
"""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id or str(uuid4())
dataset.tenant_id = tenant_id or str(uuid4())
dataset.name = "test_dataset"
dataset.provider = provider
dataset.indexing_technique = indexing_technique
dataset.available_document_count = available_document_count
dataset.embedding_model = "text-embedding-ada-002"
dataset.embedding_model_provider = "openai"
dataset.retrieval_model = {
"search_method": "semantic_search",
"reranking_enable": False,
"top_k": 4,
"score_threshold_enabled": False,
}
return dataset
def create_mock_document(
content: str,
doc_id: str,
score: float = 0.8,
provider: str = "dify",
additional_metadata: dict | None = None,
) -> Document:
"""
Create a mock Document object for testing.
Args:
content: The text content of the document
doc_id: Unique identifier for the document chunk
score: Relevance score (0.0 to 1.0)
provider: Document provider ("dify" or "external")
additional_metadata: Optional extra metadata fields
Returns:
Document: A properly structured Document object
"""
metadata = {
"doc_id": doc_id,
"document_id": str(uuid4()),
"dataset_id": str(uuid4()),
"score": score,
}
if additional_metadata:
metadata.update(additional_metadata)
return Document(
page_content=content,
metadata=metadata,
provider=provider,
)
# ==================== Test _check_knowledge_rate_limit ====================
class TestCheckKnowledgeRateLimit:
"""
Test suite for _check_knowledge_rate_limit method.
The _check_knowledge_rate_limit method validates whether a tenant has
exceeded their knowledge retrieval rate limit. This is important for:
- Preventing abuse of the knowledge retrieval system
- Enforcing subscription plan limits
- Tracking usage for billing purposes
Test Cases:
============
1. Rate limit disabled - no exception raised
2. Rate limit enabled but not exceeded - no exception raised
3. Rate limit enabled and exceeded - RateLimitExceededError raised
4. Redis operations are performed correctly
5. RateLimitLog is created when limit is exceeded
"""
@patch("core.rag.retrieval.dataset_retrieval.FeatureService")
@patch("core.rag.retrieval.dataset_retrieval.redis_client")
def test_rate_limit_disabled_no_exception(self, mock_redis, mock_feature_service):
"""
Test that when rate limit is disabled, no exception is raised.
This test verifies the behavior when the tenant's subscription
does not have rate limiting enabled.
Verifies:
- FeatureService.get_knowledge_rate_limit is called
- No Redis operations are performed
- No exception is raised
- Retrieval proceeds normally
"""
# Arrange
tenant_id = str(uuid4())
dataset_retrieval = DatasetRetrieval()
# Mock rate limit disabled
mock_limit = Mock()
mock_limit.enabled = False
mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit
# Act & Assert - should not raise any exception
dataset_retrieval._check_knowledge_rate_limit(tenant_id)
# Verify FeatureService was called
mock_feature_service.get_knowledge_rate_limit.assert_called_once_with(tenant_id)
# Verify no Redis operations were performed
assert not mock_redis.zadd.called
assert not mock_redis.zremrangebyscore.called
assert not mock_redis.zcard.called
@patch("core.rag.retrieval.dataset_retrieval.session_factory")
@patch("core.rag.retrieval.dataset_retrieval.FeatureService")
@patch("core.rag.retrieval.dataset_retrieval.redis_client")
@patch("core.rag.retrieval.dataset_retrieval.time")
def test_rate_limit_enabled_not_exceeded(self, mock_time, mock_redis, mock_feature_service, mock_session_factory):
"""
Test that when rate limit is enabled but not exceeded, no exception is raised.
This test simulates a tenant making requests within their rate limit.
The Redis sorted set stores timestamps of recent requests, and old
requests (older than 60 seconds) are removed.
Verifies:
- Redis zadd is called to track the request
- Redis zremrangebyscore removes old entries
- Redis zcard returns count within limit
- No exception is raised
"""
# Arrange
tenant_id = str(uuid4())
dataset_retrieval = DatasetRetrieval()
# Mock rate limit enabled with limit of 100 requests per minute
mock_limit = Mock()
mock_limit.enabled = True
mock_limit.limit = 100
mock_limit.subscription_plan = "professional"
mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit
# Mock time
current_time = 1234567890000 # Current time in milliseconds
mock_time.time.return_value = current_time / 1000 # Return seconds
mock_time.time.__mul__ = lambda self, x: int(self * x) # Multiply to get milliseconds
# Mock Redis operations
# zcard returns 50 (within limit of 100)
mock_redis.zcard.return_value = 50
# Mock session_factory.create_session
mock_session = MagicMock()
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_session_factory.create_session.return_value.__exit__.return_value = None
# Act & Assert - should not raise any exception
dataset_retrieval._check_knowledge_rate_limit(tenant_id)
# Verify Redis operations
expected_key = f"rate_limit_{tenant_id}"
mock_redis.zadd.assert_called_once_with(expected_key, {current_time: current_time})
mock_redis.zremrangebyscore.assert_called_once_with(expected_key, 0, current_time - 60000)
mock_redis.zcard.assert_called_once_with(expected_key)
@patch("core.rag.retrieval.dataset_retrieval.session_factory")
@patch("core.rag.retrieval.dataset_retrieval.FeatureService")
@patch("core.rag.retrieval.dataset_retrieval.redis_client")
@patch("core.rag.retrieval.dataset_retrieval.time")
def test_rate_limit_enabled_exceeded_raises_exception(
self, mock_time, mock_redis, mock_feature_service, mock_session_factory
):
"""
Test that when rate limit is enabled and exceeded, RateLimitExceededError is raised.
This test simulates a tenant exceeding their rate limit. When the count
of recent requests exceeds the limit, an exception should be raised and
a RateLimitLog should be created.
Verifies:
- Redis zcard returns count exceeding limit
- RateLimitExceededError is raised with correct message
- RateLimitLog is created in database
- Session operations are performed correctly
"""
# Arrange
tenant_id = str(uuid4())
dataset_retrieval = DatasetRetrieval()
# Mock rate limit enabled with limit of 100 requests per minute
mock_limit = Mock()
mock_limit.enabled = True
mock_limit.limit = 100
mock_limit.subscription_plan = "professional"
mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit
# Mock time
current_time = 1234567890000
mock_time.time.return_value = current_time / 1000
# Mock Redis operations - return count exceeding limit
mock_redis.zcard.return_value = 150 # Exceeds limit of 100
# Mock session_factory.create_session
mock_session = MagicMock()
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_session_factory.create_session.return_value.__exit__.return_value = None
# Act & Assert
with pytest.raises(exc.RateLimitExceededError) as exc_info:
dataset_retrieval._check_knowledge_rate_limit(tenant_id)
# Verify exception message
assert "knowledge base request rate limit" in str(exc_info.value)
# Verify RateLimitLog was created
mock_session.add.assert_called_once()
added_log = mock_session.add.call_args[0][0]
assert added_log.tenant_id == tenant_id
assert added_log.subscription_plan == "professional"
assert added_log.operation == "knowledge"
# ==================== Test _get_available_datasets ====================
class TestGetAvailableDatasets:
"""
Test suite for _get_available_datasets method.
The _get_available_datasets method retrieves datasets that are available
for retrieval. A dataset is considered available if:
- It belongs to the specified tenant
- It's in the list of requested dataset_ids
- It has at least one completed, enabled, non-archived document OR
- It's an external provider dataset
Note: Due to SQLAlchemy subquery complexity, full testing is done in
integration tests. Unit tests here verify basic behavior.
"""
def test_method_exists_and_has_correct_signature(self):
"""
Test that the method exists and has the correct signature.
Verifies:
- Method exists on DatasetRetrieval class
- Accepts tenant_id and dataset_ids parameters
"""
# Arrange
dataset_retrieval = DatasetRetrieval()
# Assert - method exists
assert hasattr(dataset_retrieval, "_get_available_datasets")
# Assert - method is callable
assert callable(dataset_retrieval._get_available_datasets)
# ==================== Test knowledge_retrieval ====================
class TestDatasetRetrievalKnowledgeRetrieval:
"""
Test suite for knowledge_retrieval method.
The knowledge_retrieval method is the main entry point for retrieving
knowledge from datasets. It orchestrates the entire retrieval process:
1. Checks rate limits
2. Gets available datasets
3. Applies metadata filtering if enabled
4. Performs retrieval (single or multiple mode)
5. Formats and returns results
Test Cases:
============
1. Single mode retrieval
2. Multiple mode retrieval
3. Metadata filtering disabled
4. Metadata filtering automatic
5. Metadata filtering manual
6. External documents handling
7. Dify documents handling
8. Empty results handling
9. Rate limit exceeded
10. No available datasets
"""
def test_knowledge_retrieval_single_mode_basic(self):
"""
Test knowledge_retrieval in single retrieval mode - basic check.
Note: Full single mode testing requires complex model mocking and
is better suited for integration tests. This test verifies the
method accepts single mode requests.
Verifies:
- Method can accept single mode request
- Request parameters are correctly structured
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id],
query="What is Python?",
retrieval_mode="single",
model_provider="openai",
model_name="gpt-4",
model_mode="chat",
completion_params={"temperature": 0.7},
)
# Assert - request is properly structured
assert request.retrieval_mode == "single"
assert request.model_provider == "openai"
assert request.model_name == "gpt-4"
assert request.model_mode == "chat"
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
@patch("core.rag.retrieval.dataset_retrieval.session_factory")
def test_knowledge_retrieval_multiple_mode(self, mock_session_factory, mock_data_processor):
"""
Test knowledge_retrieval in multiple retrieval mode.
In multiple mode, retrieval is performed across all datasets and
results are combined and reranked.
Verifies:
- Rate limit is checked
- Available datasets are retrieved
- Multiple retrieval is performed
- Results are combined and reranked
- Results are formatted correctly
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id1 = str(uuid4())
dataset_id2 = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id1, dataset_id2],
query="What is Python?",
retrieval_mode="multiple",
top_k=5,
score_threshold=0.7,
reranking_enable=True,
reranking_mode="reranking_model",
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
)
dataset_retrieval = DatasetRetrieval()
# Mock _check_knowledge_rate_limit
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
# Mock _get_available_datasets
mock_dataset1 = create_mock_dataset(dataset_id=dataset_id1, tenant_id=tenant_id)
mock_dataset2 = create_mock_dataset(dataset_id=dataset_id2, tenant_id=tenant_id)
with patch.object(
dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset1, mock_dataset2]
):
# Mock get_metadata_filter_condition
with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
# Mock multiple_retrieve to return documents
doc1 = create_mock_document("Python is great", "doc1", score=0.9)
doc2 = create_mock_document("Python is awesome", "doc2", score=0.8)
with patch.object(
dataset_retrieval, "multiple_retrieve", return_value=[doc1, doc2]
) as mock_multiple_retrieve:
# Mock format_retrieval_documents
mock_record = Mock()
mock_record.segment = Mock()
mock_record.segment.dataset_id = dataset_id1
mock_record.segment.document_id = str(uuid4())
mock_record.segment.index_node_hash = "hash123"
mock_record.segment.hit_count = 5
mock_record.segment.word_count = 100
mock_record.segment.position = 1
mock_record.segment.get_sign_content.return_value = "Python is great"
mock_record.segment.answer = None
mock_record.score = 0.9
mock_record.child_chunks = []
mock_record.summary = None
mock_record.files = None
mock_retrieval_service = Mock()
mock_retrieval_service.format_retrieval_documents.return_value = [mock_record]
with patch(
"core.rag.retrieval.dataset_retrieval.RetrievalService",
return_value=mock_retrieval_service,
):
# Mock database queries
mock_session = MagicMock()
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_session_factory.create_session.return_value.__exit__.return_value = None
mock_dataset_from_db = Mock()
mock_dataset_from_db.id = dataset_id1
mock_dataset_from_db.name = "test_dataset"
mock_document = Mock()
mock_document.id = str(uuid4())
mock_document.name = "test_doc"
mock_document.data_source_type = "upload_file"
mock_document.doc_metadata = {}
mock_session.query.return_value.filter.return_value.all.return_value = [
mock_dataset_from_db
]
mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter(
[mock_dataset_from_db, mock_document]
)
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert isinstance(result, list)
mock_multiple_retrieve.assert_called_once()
def test_knowledge_retrieval_metadata_filtering_disabled(self):
"""
Test knowledge_retrieval with metadata filtering disabled.
When metadata filtering is disabled, get_metadata_filter_condition is
NOT called (the method checks metadata_filtering_mode != "disabled").
Verifies:
- get_metadata_filter_condition is NOT called when mode is "disabled"
- Retrieval proceeds without metadata filters
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id],
query="What is Python?",
retrieval_mode="multiple",
metadata_filtering_mode="disabled",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock dependencies
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id)
with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]):
# Mock get_metadata_filter_condition - should NOT be called when disabled
with patch.object(
dataset_retrieval,
"get_metadata_filter_condition",
return_value=(None, None),
) as mock_get_metadata:
with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]):
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert isinstance(result, list)
# get_metadata_filter_condition should NOT be called when mode is "disabled"
mock_get_metadata.assert_not_called()
def test_knowledge_retrieval_with_external_documents(self):
"""
Test knowledge_retrieval with external documents.
External documents come from external knowledge bases and should
be formatted differently than Dify documents.
Verifies:
- External documents are handled correctly
- Provider is set to "external"
- Metadata includes external-specific fields
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id],
query="What is Python?",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock dependencies
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id, provider="external")
with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]):
with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
# Create external document
external_doc = create_mock_document(
"External knowledge",
"doc1",
score=0.9,
provider="external",
additional_metadata={
"dataset_id": dataset_id,
"dataset_name": "external_kb",
"document_id": "ext_doc1",
"title": "External Document",
},
)
with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[external_doc]):
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert isinstance(result, list)
if result:
assert result[0].metadata.data_source_type == "external"
def test_knowledge_retrieval_empty_results(self):
"""
Test knowledge_retrieval when no documents are found.
Verifies:
- Empty list is returned
- No errors are raised
- All dependencies are still called
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id],
query="What is Python?",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock dependencies
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id)
with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]):
with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
# Mock multiple_retrieve to return empty list
with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]):
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert result == []
def test_knowledge_retrieval_rate_limit_exceeded(self):
"""
Test knowledge_retrieval when rate limit is exceeded.
Verifies:
- RateLimitExceededError is raised
- No further processing occurs
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id],
query="What is Python?",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock _check_knowledge_rate_limit to raise exception
with patch.object(
dataset_retrieval,
"_check_knowledge_rate_limit",
side_effect=exc.RateLimitExceededError("Rate limit exceeded"),
):
# Act & Assert
with pytest.raises(exc.RateLimitExceededError):
dataset_retrieval.knowledge_retrieval(request)
def test_knowledge_retrieval_no_available_datasets(self):
"""
Test knowledge_retrieval when no datasets are available.
Verifies:
- Empty list is returned
- No retrieval is attempted
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id],
query="What is Python?",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock dependencies
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
# Mock _get_available_datasets to return empty list
with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[]):
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert result == []
def test_knowledge_retrieval_handles_multiple_documents_with_different_scores(self):
"""
Test that knowledge_retrieval processes multiple documents with different scores.
Note: Full sorting and position testing requires complex SQLAlchemy mocking
which is better suited for integration tests. This test verifies documents
with different scores can be created and have their metadata.
Verifies:
- Documents can be created with different scores
- Score metadata is properly set
"""
# Create documents with different scores
doc1 = create_mock_document("Low score", "doc1", score=0.6)
doc2 = create_mock_document("High score", "doc2", score=0.95)
doc3 = create_mock_document("Medium score", "doc3", score=0.8)
# Assert - each document has the correct score
assert doc1.metadata["score"] == 0.6
assert doc2.metadata["score"] == 0.95
assert doc3.metadata["score"] == 0.8
# Assert - documents are correctly sorted (not the retrieval result, just the list)
unsorted = [doc1, doc2, doc3]
sorted_docs = sorted(unsorted, key=lambda d: d.metadata["score"], reverse=True)
assert [d.metadata["score"] for d in sorted_docs] == [0.95, 0.8, 0.6]

View File

@@ -0,0 +1,197 @@
from unittest.mock import MagicMock, patch
import pytest
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.agent.agent_node import AgentNode
class TestInferToolProviderType:
"""Test cases for AgentNode._infer_tool_provider_type method."""
def test_infer_type_from_config_workflow(self):
"""Test inferring workflow provider type from config."""
tool_config = {
"type": "workflow",
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.WORKFLOW
def test_infer_type_from_config_builtin(self):
"""Test inferring builtin provider type from config."""
tool_config = {
"type": "builtin",
"provider_name": "builtin-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
def test_infer_type_from_config_api(self):
"""Test inferring API provider type from config."""
tool_config = {
"type": "api",
"provider_name": "api-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.API
def test_infer_type_from_config_mcp(self):
"""Test inferring MCP provider type from config."""
tool_config = {
"type": "mcp",
"provider_name": "mcp-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.MCP
def test_infer_type_invalid_config_value_raises_error(self):
"""Test that invalid type value in config raises ValueError."""
tool_config = {
"type": "invalid-type",
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
with pytest.raises(ValueError):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
def test_infer_workflow_type_from_database(self):
"""Test inferring workflow provider type from database."""
tool_config = {
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns a result
mock_session.scalar.return_value = True
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.WORKFLOW
# Should only query once (after finding WorkflowToolProvider)
assert mock_session.scalar.call_count == 1
def test_infer_mcp_type_from_database(self):
"""Test inferring MCP provider type from database."""
tool_config = {
"provider_name": "mcp-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns None
# Second query (MCPToolProvider) returns a result
mock_session.scalar.side_effect = [None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.MCP
assert mock_session.scalar.call_count == 2
def test_infer_api_type_from_database(self):
"""Test inferring API provider type from database."""
tool_config = {
"provider_name": "api-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns None
# Second query (MCPToolProvider) returns None
# Third query (ApiToolProvider) returns a result
mock_session.scalar.side_effect = [None, None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.API
assert mock_session.scalar.call_count == 3
def test_infer_builtin_type_from_database(self):
"""Test inferring builtin provider type from database."""
tool_config = {
"provider_name": "builtin-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First three queries return None
# Fourth query (BuiltinToolProvider) returns a result
mock_session.scalar.side_effect = [None, None, None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
assert mock_session.scalar.call_count == 4
def test_infer_type_default_when_not_found(self):
"""Test raising AgentNodeError when provider is not found in database."""
tool_config = {
"provider_name": "unknown-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# All queries return None
mock_session.scalar.return_value = None
# Current implementation raises AgentNodeError when provider not found
from core.workflow.nodes.agent.exc import AgentNodeError
with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
def test_infer_type_default_when_no_provider_name(self):
"""Test defaulting to BUILT_IN when provider_name is missing."""
tool_config = {}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
def test_infer_type_database_exception_propagates(self):
"""Test that database exception propagates (current implementation doesn't catch it)."""
tool_config = {
"provider_name": "provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# Database query raises exception
mock_session.scalar.side_effect = Exception("Database error")
# Current implementation doesn't catch exceptions, so it propagates
with pytest.raises(Exception, match="Database error"):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)

View File

@@ -1,595 +0,0 @@
import time
import uuid
from unittest.mock import Mock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import StringSegment
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.nodes.knowledge_retrieval.entities import (
KnowledgeRetrievalNodeData,
MultipleRetrievalConfig,
RerankingModelConfig,
SingleRetrievalConfig,
)
from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.repositories.rag_retrieval_protocol import RAGRetrievalProtocol, Source
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
@pytest.fixture
def mock_graph_init_params():
"""Create mock GraphInitParams."""
return GraphInitParams(
tenant_id=str(uuid.uuid4()),
app_id=str(uuid.uuid4()),
workflow_id=str(uuid.uuid4()),
graph_config={},
user_id=str(uuid.uuid4()),
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
@pytest.fixture
def mock_graph_runtime_state():
"""Create mock GraphRuntimeState."""
variable_pool = VariablePool(
system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
@pytest.fixture
def mock_rag_retrieval():
"""Create mock RAGRetrievalProtocol."""
mock_retrieval = Mock(spec=RAGRetrievalProtocol)
mock_retrieval.knowledge_retrieval.return_value = []
mock_retrieval.llm_usage = LLMUsage.empty_usage()
return mock_retrieval
@pytest.fixture
def sample_node_data():
"""Create sample KnowledgeRetrievalNodeData."""
return KnowledgeRetrievalNodeData(
title="Knowledge Retrieval",
type="knowledge-retrieval",
dataset_ids=[str(uuid.uuid4())],
retrieval_mode="multiple",
multiple_retrieval_config=MultipleRetrievalConfig(
top_k=5,
score_threshold=0.7,
reranking_mode="reranking_model",
reranking_enable=True,
reranking_model=RerankingModelConfig(
provider="cohere",
model="rerank-v2",
),
),
)
class TestKnowledgeRetrievalNode:
"""
Test suite for KnowledgeRetrievalNode.
"""
def test_node_initialization(self, mock_graph_init_params, mock_graph_runtime_state, mock_rag_retrieval):
"""Test KnowledgeRetrievalNode initialization."""
# Arrange
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": {
"title": "Knowledge Retrieval",
"type": "knowledge-retrieval",
"dataset_ids": [str(uuid.uuid4())],
"retrieval_mode": "multiple",
},
}
# Act
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Assert
assert node.id == node_id
assert node._rag_retrieval == mock_rag_retrieval
assert node._llm_file_saver is not None
def test_run_with_no_query_or_attachment(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _run returns success when no query or attachment is provided."""
# Arrange
sample_node_data.query_variable_selector = None
sample_node_data.query_attachment_selector = None
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs == {}
assert mock_rag_retrieval.knowledge_retrieval.call_count == 0
def test_run_with_query_variable_single_mode(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
):
"""Test _run with query variable in single mode."""
# Arrange
from core.workflow.nodes.llm.entities import ModelConfig
query = "What is Python?"
query_selector = ["start", "query"]
# Add query to variable pool
mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
node_data = KnowledgeRetrievalNodeData(
title="Knowledge Retrieval",
type="knowledge-retrieval",
dataset_ids=[str(uuid.uuid4())],
retrieval_mode="single",
query_variable_selector=query_selector,
single_retrieval_config=SingleRetrievalConfig(
model=ModelConfig(
provider="openai",
name="gpt-4",
mode="chat",
completion_params={"temperature": 0.7},
)
),
)
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": node_data.model_dump(),
}
# Mock retrieval response
mock_source = Mock(spec=Source)
mock_source.model_dump.return_value = {"content": "Python is a programming language"}
mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source]
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert "result" in result.outputs
assert mock_rag_retrieval.knowledge_retrieval.called
def test_run_with_query_variable_multiple_mode(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _run with query variable in multiple mode."""
# Arrange
query = "What is Python?"
query_selector = ["start", "query"]
# Add query to variable pool
mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
sample_node_data.query_variable_selector = query_selector
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
# Mock retrieval response
mock_source = Mock(spec=Source)
mock_source.model_dump.return_value = {"content": "Python is a programming language"}
mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source]
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert "result" in result.outputs
assert mock_rag_retrieval.knowledge_retrieval.called
def test_run_with_invalid_query_variable_type(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _run fails when query variable is not StringSegment."""
# Arrange
query_selector = ["start", "query"]
# Add non-string variable to variable pool
mock_graph_runtime_state.variable_pool.add(query_selector, [1, 2, 3])
sample_node_data.query_variable_selector = query_selector
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Query variable is not string type" in result.error
def test_run_with_invalid_attachment_variable_type(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _run fails when attachment variable is not FileSegment or ArrayFileSegment."""
# Arrange
attachment_selector = ["start", "attachments"]
# Add non-file variable to variable pool
mock_graph_runtime_state.variable_pool.add(attachment_selector, "not a file")
sample_node_data.query_attachment_selector = attachment_selector
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Attachments variable is not array file or file type" in result.error
def test_run_with_rate_limit_exceeded(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _run handles RateLimitExceededError properly."""
# Arrange
query = "What is Python?"
query_selector = ["start", "query"]
mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
sample_node_data.query_variable_selector = query_selector
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
# Mock retrieval to raise RateLimitExceededError
mock_rag_retrieval.knowledge_retrieval.side_effect = RateLimitExceededError(
"knowledge base request rate limit exceeded"
)
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "rate limit" in result.error.lower()
def test_run_with_generic_exception(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _run handles generic exceptions properly."""
# Arrange
query = "What is Python?"
query_selector = ["start", "query"]
mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
sample_node_data.query_variable_selector = query_selector
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
# Mock retrieval to raise generic exception
mock_rag_retrieval.knowledge_retrieval.side_effect = Exception("Unexpected error")
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Unexpected error" in result.error
def test_extract_variable_selector_to_variable_mapping(self):
"""Test _extract_variable_selector_to_variable_mapping class method."""
# Arrange
node_id = "knowledge_node_1"
node_data = {
"type": "knowledge-retrieval",
"title": "Knowledge Retrieval",
"dataset_ids": [str(uuid.uuid4())],
"retrieval_mode": "multiple",
"query_variable_selector": ["start", "query"],
"query_attachment_selector": ["start", "attachments"],
}
graph_config = {}
# Act
mapping = KnowledgeRetrievalNode._extract_variable_selector_to_variable_mapping(
graph_config=graph_config,
node_id=node_id,
node_data=node_data,
)
# Assert
assert mapping[f"{node_id}.query"] == ["start", "query"]
assert mapping[f"{node_id}.queryAttachment"] == ["start", "attachments"]
class TestFetchDatasetRetriever:
"""
Test suite for _fetch_dataset_retriever method.
"""
def test_fetch_dataset_retriever_single_mode(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
):
"""Test _fetch_dataset_retriever in single mode."""
# Arrange
from core.workflow.nodes.llm.entities import ModelConfig
query = "What is Python?"
variables = {"query": query}
node_data = KnowledgeRetrievalNodeData(
title="Knowledge Retrieval",
type="knowledge-retrieval",
dataset_ids=[str(uuid.uuid4())],
retrieval_mode="single",
single_retrieval_config=SingleRetrievalConfig(
model=ModelConfig(
provider="openai",
name="gpt-4",
mode="chat",
completion_params={"temperature": 0.7},
)
),
)
# Mock retrieval response
mock_source = Mock(spec=Source)
mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source]
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node_id = str(uuid.uuid4())
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
results, usage = node._fetch_dataset_retriever(node_data=node_data, variables=variables)
# Assert
assert len(results) == 1
assert isinstance(usage, LLMUsage)
assert mock_rag_retrieval.knowledge_retrieval.called
def test_fetch_dataset_retriever_multiple_mode_with_reranking(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _fetch_dataset_retriever in multiple mode with reranking."""
# Arrange
query = "What is Python?"
variables = {"query": query}
# Mock retrieval response
mock_rag_retrieval.knowledge_retrieval.return_value = []
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
results, usage = node._fetch_dataset_retriever(node_data=sample_node_data, variables=variables)
# Assert
assert isinstance(results, list)
assert isinstance(usage, LLMUsage)
assert mock_rag_retrieval.knowledge_retrieval.called
# Verify reranking parameters via request object
call_args = mock_rag_retrieval.knowledge_retrieval.call_args
request = call_args[1]["request"]
assert request.reranking_enable is True
assert request.reranking_mode == "reranking_model"
def test_fetch_dataset_retriever_multiple_mode_without_reranking(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
):
"""Test _fetch_dataset_retriever in multiple mode without reranking."""
# Arrange
query = "What is Python?"
variables = {"query": query}
node_data = KnowledgeRetrievalNodeData(
title="Knowledge Retrieval",
type="knowledge-retrieval",
dataset_ids=[str(uuid.uuid4())],
retrieval_mode="multiple",
multiple_retrieval_config=MultipleRetrievalConfig(
top_k=5,
score_threshold=0.7,
reranking_enable=False,
reranking_mode="reranking_model",
),
)
# Mock retrieval response
mock_rag_retrieval.knowledge_retrieval.return_value = []
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": node_data.model_dump(),
}
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
results, usage = node._fetch_dataset_retriever(node_data=node_data, variables=variables)
# Assert
assert isinstance(results, list)
assert mock_rag_retrieval.knowledge_retrieval.called
# Verify reranking is disabled
call_args = mock_rag_retrieval.knowledge_retrieval.call_args
request = call_args[1]["request"]
assert request.reranking_enable is False
def test_version_method(self):
"""Test version class method."""
# Act
version = KnowledgeRetrievalNode.version()
# Assert
assert version == "1"

View File

@@ -133,8 +133,6 @@ class TestCelerySSLConfiguration:
mock_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK = 0
mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False
mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15
mock_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK = False
mock_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL = 30
with patch("extensions.ext_celery.dify_config", mock_config):
from dify_app import DifyApp

View File

@@ -4,7 +4,7 @@ from typing import Any
from uuid import uuid4
import pytest
from hypothesis import HealthCheck, given, settings
from hypothesis import given, settings
from hypothesis import strategies as st
from core.file import File, FileTransferMethod, FileType
@@ -493,7 +493,7 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]:
)
@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@settings(max_examples=50)
@given(_scalar_value())
def test_build_segment_and_extract_values_for_scalar_types(value):
seg = variable_factory.build_segment(value)
@@ -504,7 +504,7 @@ def test_build_segment_and_extract_values_for_scalar_types(value):
assert seg.value == value
@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@settings(max_examples=50)
@given(values=st.lists(_scalar_value(), max_size=20))
def test_build_segment_and_extract_values_for_array_types(values):
seg = variable_factory.build_segment(values)

View File

@@ -859,7 +859,7 @@ class TestRedisShardedSubscription:
client.get_node_from_key.assert_called_once_with("test-sharded-topic")
mock_pubsub.get_sharded_message.assert_called_once_with(
ignore_subscribe_messages=False,
timeout=1,
timeout=0.1,
target_node="node-1",
)
assert result == mock_pubsub.get_sharded_message.return_value

View File

@@ -1,250 +0,0 @@
"""
Unit tests for API Token Cache module.
"""
import json
from datetime import datetime
from unittest.mock import MagicMock, patch
from services.api_token_service import (
CACHE_KEY_PREFIX,
CACHE_NULL_TTL_SECONDS,
CACHE_TTL_SECONDS,
ApiTokenCache,
CachedApiToken,
)
class TestApiTokenCache:
"""Test cases for ApiTokenCache class."""
def setup_method(self):
"""Setup test fixtures."""
self.mock_token = MagicMock()
self.mock_token.id = "test-token-id-123"
self.mock_token.app_id = "test-app-id-456"
self.mock_token.tenant_id = "test-tenant-id-789"
self.mock_token.type = "app"
self.mock_token.token = "test-token-value-abc"
self.mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0)
self.mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)
def test_make_cache_key(self):
"""Test cache key generation."""
# Test with scope
key = ApiTokenCache._make_cache_key("my-token", "app")
assert key == f"{CACHE_KEY_PREFIX}:app:my-token"
# Test without scope
key = ApiTokenCache._make_cache_key("my-token", None)
assert key == f"{CACHE_KEY_PREFIX}:any:my-token"
def test_serialize_token(self):
"""Test token serialization."""
serialized = ApiTokenCache._serialize_token(self.mock_token)
data = json.loads(serialized)
assert data["id"] == "test-token-id-123"
assert data["app_id"] == "test-app-id-456"
assert data["tenant_id"] == "test-tenant-id-789"
assert data["type"] == "app"
assert data["token"] == "test-token-value-abc"
assert data["last_used_at"] == "2026-02-03T10:00:00"
assert data["created_at"] == "2026-01-01T00:00:00"
def test_serialize_token_with_nulls(self):
"""Test token serialization with None values."""
mock_token = MagicMock()
mock_token.id = "test-id"
mock_token.app_id = None
mock_token.tenant_id = None
mock_token.type = "dataset"
mock_token.token = "test-token"
mock_token.last_used_at = None
mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)
serialized = ApiTokenCache._serialize_token(mock_token)
data = json.loads(serialized)
assert data["app_id"] is None
assert data["tenant_id"] is None
assert data["last_used_at"] is None
def test_deserialize_token(self):
"""Test token deserialization."""
cached_data = json.dumps(
{
"id": "test-id",
"app_id": "test-app",
"tenant_id": "test-tenant",
"type": "app",
"token": "test-token",
"last_used_at": "2026-02-03T10:00:00",
"created_at": "2026-01-01T00:00:00",
}
)
result = ApiTokenCache._deserialize_token(cached_data)
assert isinstance(result, CachedApiToken)
assert result.id == "test-id"
assert result.app_id == "test-app"
assert result.tenant_id == "test-tenant"
assert result.type == "app"
assert result.token == "test-token"
assert result.last_used_at == datetime(2026, 2, 3, 10, 0, 0)
assert result.created_at == datetime(2026, 1, 1, 0, 0, 0)
def test_deserialize_null_token(self):
"""Test deserialization of null token (cached miss)."""
result = ApiTokenCache._deserialize_token("null")
assert result is None
def test_deserialize_invalid_json(self):
"""Test deserialization with invalid JSON."""
result = ApiTokenCache._deserialize_token("invalid-json{")
assert result is None
@patch("services.api_token_service.redis_client")
def test_get_cache_hit(self, mock_redis):
"""Test cache hit scenario."""
cached_data = json.dumps(
{
"id": "test-id",
"app_id": "test-app",
"tenant_id": "test-tenant",
"type": "app",
"token": "test-token",
"last_used_at": "2026-02-03T10:00:00",
"created_at": "2026-01-01T00:00:00",
}
).encode("utf-8")
mock_redis.get.return_value = cached_data
result = ApiTokenCache.get("test-token", "app")
assert result is not None
assert isinstance(result, CachedApiToken)
assert result.app_id == "test-app"
mock_redis.get.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token")
@patch("services.api_token_service.redis_client")
def test_get_cache_miss(self, mock_redis):
"""Test cache miss scenario."""
mock_redis.get.return_value = None
result = ApiTokenCache.get("test-token", "app")
assert result is None
mock_redis.get.assert_called_once()
@patch("services.api_token_service.redis_client")
def test_set_valid_token(self, mock_redis):
"""Test setting a valid token in cache."""
result = ApiTokenCache.set("test-token", "app", self.mock_token)
assert result is True
mock_redis.setex.assert_called_once()
args = mock_redis.setex.call_args[0]
assert args[0] == f"{CACHE_KEY_PREFIX}:app:test-token"
assert args[1] == CACHE_TTL_SECONDS
@patch("services.api_token_service.redis_client")
def test_set_null_token(self, mock_redis):
"""Test setting a null token (cache penetration prevention)."""
result = ApiTokenCache.set("invalid-token", "app", None)
assert result is True
mock_redis.setex.assert_called_once()
args = mock_redis.setex.call_args[0]
assert args[0] == f"{CACHE_KEY_PREFIX}:app:invalid-token"
assert args[1] == CACHE_NULL_TTL_SECONDS
assert args[2] == b"null"
@patch("services.api_token_service.redis_client")
def test_delete_with_scope(self, mock_redis):
"""Test deleting token cache with specific scope."""
result = ApiTokenCache.delete("test-token", "app")
assert result is True
mock_redis.delete.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token")
@patch("services.api_token_service.redis_client")
def test_delete_without_scope(self, mock_redis):
"""Test deleting token cache without scope (delete all)."""
# Mock scan_iter to return an iterator of keys
mock_redis.scan_iter.return_value = iter(
[
b"api_token:app:test-token",
b"api_token:dataset:test-token",
]
)
result = ApiTokenCache.delete("test-token", None)
assert result is True
# Verify scan_iter was called with the correct pattern
mock_redis.scan_iter.assert_called_once()
call_args = mock_redis.scan_iter.call_args
assert call_args[1]["match"] == f"{CACHE_KEY_PREFIX}:*:test-token"
# Verify delete was called with all matched keys
mock_redis.delete.assert_called_once_with(
b"api_token:app:test-token",
b"api_token:dataset:test-token",
)
@patch("services.api_token_service.redis_client")
def test_redis_fallback_on_exception(self, mock_redis):
"""Test Redis fallback when Redis is unavailable."""
from redis import RedisError
mock_redis.get.side_effect = RedisError("Connection failed")
result = ApiTokenCache.get("test-token", "app")
# Should return None (fallback) instead of raising exception
assert result is None
class TestApiTokenCacheIntegration:
"""Integration test scenarios."""
@patch("services.api_token_service.redis_client")
def test_full_cache_lifecycle(self, mock_redis):
"""Test complete cache lifecycle: set -> get -> delete."""
# Setup mock token
mock_token = MagicMock()
mock_token.id = "id-123"
mock_token.app_id = "app-456"
mock_token.tenant_id = "tenant-789"
mock_token.type = "app"
mock_token.token = "token-abc"
mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0)
mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)
# 1. Set token in cache
ApiTokenCache.set("token-abc", "app", mock_token)
assert mock_redis.setex.called
# 2. Simulate cache hit
cached_data = ApiTokenCache._serialize_token(mock_token)
mock_redis.get.return_value = cached_data # bytes from model_dump_json().encode()
retrieved = ApiTokenCache.get("token-abc", "app")
assert retrieved is not None
assert isinstance(retrieved, CachedApiToken)
# 3. Delete from cache
ApiTokenCache.delete("token-abc", "app")
assert mock_redis.delete.called
@patch("services.api_token_service.redis_client")
def test_cache_penetration_prevention(self, mock_redis):
"""Test that non-existent tokens are cached as null."""
# Set null token (cache miss)
ApiTokenCache.set("non-existent-token", "app", None)
args = mock_redis.setex.call_args[0]
assert args[2] == b"null"
assert args[1] == CACHE_NULL_TTL_SECONDS # Shorter TTL for null values

View File

@@ -492,45 +492,3 @@ class TestEndUserServiceGetOrCreateEndUserByType:
# Assert
added_user = mock_session.add.call_args[0][0]
assert added_user.type == invoke_type
class TestEndUserServiceGetEndUserById:
"""Unit tests for EndUserService.get_end_user_by_id."""
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_end_user_by_id_returns_end_user(self, mock_db, mock_session_class):
tenant_id = "tenant-123"
app_id = "app-456"
end_user_id = "end-user-789"
existing_user = MagicMock(spec=EndUser)
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = existing_user
result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id)
assert result == existing_user
mock_session.query.assert_called_once_with(EndUser)
mock_query.where.assert_called_once()
assert len(mock_query.where.call_args[0]) == 3
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_end_user_by_id_returns_none(self, mock_db, mock_session_class):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None
result = EndUserService.get_end_user_by_id(tenant_id="tenant", app_id="app", end_user_id="end-user")
assert result is None

View File

@@ -4,7 +4,6 @@ from unittest.mock import MagicMock
import pytest
from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from models.model import App
from models.tools import WorkflowToolProvider
@@ -90,12 +89,6 @@ def _build_fake_session(app) -> SimpleNamespace:
return SimpleNamespace(query=query)
def _build_parameters() -> list[WorkflowToolParameterConfiguration]:
return [
WorkflowToolParameterConfiguration(name="input", description="input", form=ToolParameter.ToolParameterForm.LLM),
]
def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch):
workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]})
app = SimpleNamespace(workflow=workflow)
@@ -107,6 +100,8 @@ def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch):
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)
mock_invalidate = MagicMock()
parameters = [{"name": "input", "description": "input", "form": "form"}]
with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool(
user_id="user-id",
@@ -116,7 +111,7 @@ def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch):
label="Tool",
icon={"type": "emoji", "emoji": "tool"},
description="desc",
parameters=_build_parameters(),
parameters=parameters,
)
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
@@ -139,6 +134,7 @@ def test_create_workflow_tool_success(monkeypatch):
mock_from_db = MagicMock()
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)
parameters = [{"name": "input", "description": "input", "form": "form"}]
icon = {"type": "emoji", "emoji": "tool"}
result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool(
@@ -149,7 +145,7 @@ def test_create_workflow_tool_success(monkeypatch):
label="Tool",
icon=icon,
description="desc",
parameters=_build_parameters(),
parameters=parameters,
)
assert result == {"result": "success"}

View File

@@ -83,127 +83,23 @@ def mock_documents(document_ids, dataset_id):
def mock_db_session():
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
sessions = [] # Track all created sessions
# Shared mock data that all sessions will access
shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None}
session = MagicMock()
# Ensure tests that expect session.close() to be called can observe it via the context manager
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
# Link __exit__ to session.close so "close" expectations reflect context manager teardown
def create_session_side_effect():
session = MagicMock()
session.close = MagicMock()
def _exit_side_effect(*args, **kwargs):
session.close()
# Track commit calls
commit_mock = MagicMock()
session.commit = commit_mock
cm = MagicMock()
cm.__enter__.return_value = session
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
# Support session.begin() for transactions
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
def begin_exit_side_effect(*args, **kwargs):
# Auto-commit on transaction exit (like SQLAlchemy)
session.commit()
# Also mark wrapper's commit as called
if sessions:
sessions[0].commit()
begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect)
session.begin = MagicMock(return_value=begin_cm)
sessions.append(session)
# Setup query with side_effect to handle both Dataset and Document queries
def query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
# Support both .first() and .all() calls with chaining
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)
# Create an iterator for .first() calls if not exists
if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query
session.query = MagicMock(side_effect=query_side_effect)
return cm
mock_sf.create_session.side_effect = create_session_side_effect
# Create a wrapper that behaves like the first session but has access to all sessions
class SessionWrapper:
def __init__(self):
self._sessions = sessions
self._shared_data = shared_mock_data
# Create a default session for setup phase
self._default_session = MagicMock()
self._default_session.close = MagicMock()
self._default_session.commit = MagicMock()
# Support session.begin() for default session too
begin_cm = MagicMock()
begin_cm.__enter__.return_value = self._default_session
def default_begin_exit_side_effect(*args, **kwargs):
self._default_session.commit()
begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect)
self._default_session.begin = MagicMock(return_value=begin_cm)
def default_query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)
if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query
self._default_session.query = MagicMock(side_effect=default_query_side_effect)
def __getattr__(self, name):
# Forward all attribute access to the first session, or default if none created yet
target_session = self._sessions[0] if self._sessions else self._default_session
return getattr(target_session, name)
@property
def all_sessions(self):
"""Access all created sessions for testing."""
return self._sessions
wrapper = SessionWrapper()
yield wrapper
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session
@pytest.fixture
@@ -356,9 +252,18 @@ class TestTaskEnqueuing:
use the deprecated function.
"""
# Arrange
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
# Return documents one by one for each call
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -399,9 +304,21 @@ class TestBatchProcessing:
doc.processing_started_at = None
mock_documents.append(doc)
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Create an iterator for documents
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
# Return documents one by one for each call
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -440,9 +357,19 @@ class TestBatchProcessing:
doc.stopped_at = None
mock_documents.append(doc)
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
mock_feature_service.get_features.return_value.billing.enabled = True
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL
@@ -480,9 +407,19 @@ class TestBatchProcessing:
doc.stopped_at = None
mock_documents.append(doc)
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
mock_feature_service.get_features.return_value.billing.enabled = True
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX
@@ -507,10 +444,7 @@ class TestBatchProcessing:
"""
# Arrange
document_ids = []
# Set shared mock data with empty documents list
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -548,9 +482,19 @@ class TestProgressTracking:
doc.processing_started_at = None
mock_documents.append(doc)
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -584,9 +528,19 @@ class TestProgressTracking:
doc.processing_started_at = None
mock_documents.append(doc)
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -681,9 +635,19 @@ class TestErrorHandling:
doc.stopped_at = None
mock_documents.append(doc)
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set up to trigger vector space limit error
mock_feature_service.get_features.return_value.billing.enabled = True
@@ -710,9 +674,17 @@ class TestErrorHandling:
Errors during indexing should be caught and logged, but not crash the task.
"""
# Arrange
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = Exception("Indexing failed")
@@ -736,9 +708,17 @@ class TestErrorHandling:
but not treated as a failure.
"""
# Arrange
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Make IndexingRunner raise DocumentIsPausedError
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused")
@@ -873,9 +853,17 @@ class TestTaskCancellation:
Session cleanup should happen in finally block.
"""
# Arrange
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -895,9 +883,17 @@ class TestTaskCancellation:
Session cleanup should happen even when errors occur.
"""
# Arrange
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = Exception("Test error")
@@ -966,7 +962,6 @@ class TestAdvancedScenarios:
document_ids = [str(uuid.uuid4()) for _ in range(3)]
# Create only 2 documents (simulate one missing)
# The new code uses .all() which will only return existing documents
mock_documents = []
for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one
doc = MagicMock(spec=Document)
@@ -976,9 +971,21 @@ class TestAdvancedScenarios:
doc.processing_started_at = None
mock_documents.append(doc)
# Set shared mock data - .all() will only return existing documents
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Create iterator that returns None for missing document
doc_responses = [mock_documents[0], None, mock_documents[1]]
doc_iter = iter(doc_responses)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -1068,9 +1075,19 @@ class TestAdvancedScenarios:
doc.stopped_at = None
mock_documents.append(doc)
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set vector space exactly at limit
mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1202,9 +1219,19 @@ class TestAdvancedScenarios:
doc.processing_started_at = None
mock_documents.append(doc)
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Billing disabled - limits should not be checked
mock_feature_service.get_features.return_value.billing.enabled = False
@@ -1246,9 +1273,19 @@ class TestIntegration:
# Set up rpop to return None for concurrency check (no more tasks)
mock_redis.rpop.side_effect = [None]
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -1284,9 +1321,19 @@ class TestIntegration:
# Set up rpop to return None for concurrency check (no more tasks)
mock_redis.rpop.side_effect = [None]
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -1368,9 +1415,17 @@ class TestEdgeCases:
mock_document.indexing_status = "waiting"
mock_document.processing_started_at = None
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = [mock_document]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: mock_document
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -1410,9 +1465,17 @@ class TestEdgeCases:
mock_document.indexing_status = "waiting"
mock_document.processing_started_at = None
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = [mock_document]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: mock_document
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -1492,9 +1555,19 @@ class TestEdgeCases:
doc.processing_started_at = None
mock_documents.append(doc)
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set vector space limit to 0 (unlimited)
mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1539,9 +1612,19 @@ class TestEdgeCases:
doc.processing_started_at = None
mock_documents.append(doc)
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set negative vector space limit
mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1592,9 +1675,19 @@ class TestPerformanceScenarios:
doc.processing_started_at = None
mock_documents.append(doc)
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Configure billing with sufficient limits
mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1733,9 +1826,19 @@ class TestRobustness:
doc.processing_started_at = None
mock_documents.append(doc)
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")
@@ -1763,7 +1866,7 @@ class TestRobustness:
- No exceptions occur
Expected behavior:
- All database sessions are closed
- Database session is closed
- No connection leaks
"""
# Arrange
@@ -1776,9 +1879,19 @@ class TestRobustness:
doc.processing_started_at = None
mock_documents.append(doc)
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -1786,11 +1899,10 @@ class TestRobustness:
# Act
_document_indexing(dataset_id, document_ids)
# Assert - All created sessions should be closed
# The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary)
assert len(mock_db_session.all_sessions) >= 1
for session in mock_db_session.all_sessions:
assert session.close.called, "All sessions should be closed"
# Assert
assert mock_db_session.close.called
# Verify close is called exactly once
assert mock_db_session.close.call_count == 1
def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
"""

View File

@@ -109,87 +109,25 @@ def mock_document_segments(document_id):
@pytest.fixture
def mock_db_session():
"""Mock database session via session_factory.create_session().
After session split refactor, the code calls create_session() multiple times.
This fixture creates shared query mocks so all sessions use the same
query configuration, simulating database persistence across sessions.
The fixture automatically converts side_effect to cycle to prevent StopIteration.
Tests configure mocks the same way as before, but behind the scenes the values
are cycled infinitely for all sessions.
"""
from itertools import cycle
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
sessions = []
session = MagicMock()
# Ensure tests can observe session.close() via context manager teardown
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
# Shared query mocks - all sessions use these
shared_query = MagicMock()
shared_filter_by = MagicMock()
shared_scalars_result = MagicMock()
def _exit_side_effect(*args, **kwargs):
session.close()
# Create custom first mock that auto-cycles side_effect
class CyclicMock(MagicMock):
def __setattr__(self, name, value):
if name == "side_effect" and value is not None:
# Convert list/tuple to infinite cycle
if isinstance(value, (list, tuple)):
value = cycle(value)
super().__setattr__(name, value)
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
shared_query.where.return_value.first = CyclicMock()
shared_filter_by.first = CyclicMock()
def _create_session():
"""Create a new mock session for each create_session() call."""
session = MagicMock()
session.close = MagicMock()
session.commit = MagicMock()
# Mock session.begin() context manager
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
def _begin_exit_side_effect(exc_type, exc, tb):
# commit on success
if exc_type is None:
session.commit()
# return False to propagate exceptions
return False
begin_cm.__exit__.side_effect = _begin_exit_side_effect
session.begin.return_value = begin_cm
# Mock create_session() context manager
cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(exc_type, exc, tb):
session.close()
return False
cm.__exit__.side_effect = _exit_side_effect
# All sessions use the same shared query mocks
session.query.return_value = shared_query
shared_query.where.return_value = shared_query
shared_query.filter_by.return_value = shared_filter_by
session.scalars.return_value = shared_scalars_result
sessions.append(session)
# Attach helpers on the first created session for assertions across all sessions
if len(sessions) == 1:
session.get_all_sessions = lambda: sessions
session.any_close_called = lambda: any(s.close.called for s in sessions)
session.any_commit_called = lambda: any(s.commit.called for s in sessions)
return cm
mock_sf.create_session.side_effect = _create_session
# Create first session and return it
_create_session()
yield sessions[0]
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
session.scalars.return_value = MagicMock()
yield session
@pytest.fixture
@@ -248,8 +186,8 @@ class TestDocumentIndexingSyncTask:
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert - at least one session should have been closed
assert mock_db_session.any_close_called()
# Assert
mock_db_session.close.assert_called_once()
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when notion_workspace_id is missing."""
@@ -292,7 +230,6 @@ class TestDocumentIndexingSyncTask:
"""Test that task handles missing credentials by updating document status."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_datasource_provider_service.get_datasource_credentials.return_value = None
# Act
@@ -302,8 +239,8 @@ class TestDocumentIndexingSyncTask:
assert mock_document.indexing_status == "error"
assert "Datasource credential not found" in mock_document.error
assert mock_document.stopped_at is not None
assert mock_db_session.any_commit_called()
assert mock_db_session.any_close_called()
mock_db_session.commit.assert_called()
mock_db_session.close.assert_called()
def test_page_not_updated(
self,
@@ -317,7 +254,6 @@ class TestDocumentIndexingSyncTask:
"""Test that task does nothing when page has not been updated."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
# Return same time as stored in document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
@@ -327,8 +263,8 @@ class TestDocumentIndexingSyncTask:
# Assert
# Document status should remain unchanged
assert mock_document.indexing_status == "completed"
# At least one session should have been closed via context manager teardown
assert mock_db_session.any_close_called()
# Session should still be closed via context manager teardown
assert mock_db_session.close.called
def test_successful_sync_when_page_updated(
self,
@@ -345,20 +281,7 @@ class TestDocumentIndexingSyncTask:
):
"""Test successful sync flow when Notion page has been updated."""
# Arrange
# Set exact sequence of returns across calls to `.first()`:
# 1) document (initial fetch)
# 2) dataset (pre-check)
# 3) dataset (cleaning phase)
# 4) document (pre-indexing update)
# 5) document (indexing runner fetch)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
mock_document,
mock_document,
]
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
# NotionExtractor returns updated time
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
@@ -376,40 +299,28 @@ class TestDocumentIndexingSyncTask:
mock_processor.clean.assert_called_once()
# Verify segments were deleted from database in batch (DELETE FROM document_segments)
# Aggregate execute calls across all created sessions
execute_sqls = []
for s in mock_db_session.get_all_sessions():
execute_sqls.extend([" ".join(str(c[0][0]).split()) for c in s.execute.call_args_list])
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
# Verify indexing runner was called
mock_indexing_runner.run.assert_called_once_with([mock_document])
# Verify session operations (across any created session)
assert mock_db_session.any_commit_called()
assert mock_db_session.any_close_called()
# Verify session operations
assert mock_db_session.commit.called
mock_db_session.close.assert_called_once()
def test_dataset_not_found_during_cleaning(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_indexing_runner,
mock_document,
dataset_id,
document_id,
):
"""Test that task handles dataset not found during cleaning phase."""
# Arrange
# Sequence: document (initial), dataset (pre-check), None (cleaning), document (update), document (indexing)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
None,
mock_document,
mock_document,
]
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None]
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
@@ -418,8 +329,8 @@ class TestDocumentIndexingSyncTask:
# Assert
# Document should still be set to parsing
assert mock_document.indexing_status == "parsing"
# At least one session should be closed after error
assert mock_db_session.any_close_called()
# Session should be closed after error
mock_db_session.close.assert_called_once()
def test_cleaning_error_continues_to_indexing(
self,
@@ -435,14 +346,8 @@ class TestDocumentIndexingSyncTask:
):
"""Test that indexing continues even if cleaning fails."""
# Arrange
from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
# Make the cleaning step fail but not the segment fetch
processor = mock_index_processor_factory.return_value.init_index_processor.return_value
processor.clean.side_effect = Exception("Cleaning error")
mock_db_session.scalars.return_value.all.return_value = []
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
@@ -451,7 +356,7 @@ class TestDocumentIndexingSyncTask:
# Assert
# Indexing should still be attempted despite cleaning error
mock_indexing_runner.run.assert_called_once_with([mock_document])
assert mock_db_session.any_close_called()
mock_db_session.close.assert_called_once()
def test_indexing_runner_document_paused_error(
self,
@@ -468,10 +373,7 @@ class TestDocumentIndexingSyncTask:
):
"""Test that DocumentIsPausedError is handled gracefully."""
# Arrange
from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
@@ -481,7 +383,7 @@ class TestDocumentIndexingSyncTask:
# Assert
# Session should be closed after handling error
assert mock_db_session.any_close_called()
mock_db_session.close.assert_called_once()
def test_indexing_runner_general_error(
self,
@@ -498,10 +400,7 @@ class TestDocumentIndexingSyncTask:
):
"""Test that general exceptions during indexing are handled."""
# Arrange
from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = Exception("Indexing error")
@@ -511,7 +410,7 @@ class TestDocumentIndexingSyncTask:
# Assert
# Session should be closed after error
assert mock_db_session.any_close_called()
mock_db_session.close.assert_called_once()
def test_notion_extractor_initialized_with_correct_params(
self,
@@ -618,14 +517,7 @@ class TestDocumentIndexingSyncTask:
):
"""Test that index processor clean is called with correct parameters."""
# Arrange
# Sequence: document (initial), dataset (pre-check), dataset (cleaning), document (update), document (indexing)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
mock_document,
mock_document,
]
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"

View File

@@ -1,5 +1,4 @@
import base64
from decimal import Decimal
from unittest.mock import Mock, patch
import pytest
@@ -10,10 +9,8 @@ from core.mcp.types import (
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
TextResourceContents,
)
from core.model_runtime.entities.llm_entities import LLMUsage
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
@@ -123,231 +120,3 @@ class TestMCPToolInvoke:
# Validate values
values = {m.message.variable_name: m.message.variable_value for m in var_msgs}
assert values == {"a": 1, "b": "x"}
class TestMCPToolUsageExtraction:
"""Test usage metadata extraction from MCP tool results."""
def test_extract_usage_dict_from_direct_usage_field(self) -> None:
"""Test extraction when usage is directly in meta.usage field."""
meta = {
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"total_price": "0.001",
"currency": "USD",
}
}
usage_dict = MCPTool._extract_usage_dict(meta)
assert usage_dict is not None
assert usage_dict["prompt_tokens"] == 100
assert usage_dict["completion_tokens"] == 50
assert usage_dict["total_tokens"] == 150
assert usage_dict["total_price"] == "0.001"
assert usage_dict["currency"] == "USD"
def test_extract_usage_dict_from_nested_metadata(self) -> None:
"""Test extraction when usage is nested in meta.metadata.usage."""
meta = {
"metadata": {
"usage": {
"prompt_tokens": 200,
"completion_tokens": 100,
"total_tokens": 300,
}
}
}
usage_dict = MCPTool._extract_usage_dict(meta)
assert usage_dict is not None
assert usage_dict["prompt_tokens"] == 200
assert usage_dict["total_tokens"] == 300
def test_extract_usage_dict_from_flat_token_fields(self) -> None:
"""Test extraction when token counts are directly in meta."""
meta = {
"prompt_tokens": 150,
"completion_tokens": 75,
"total_tokens": 225,
"currency": "EUR",
}
usage_dict = MCPTool._extract_usage_dict(meta)
assert usage_dict is not None
assert usage_dict["prompt_tokens"] == 150
assert usage_dict["completion_tokens"] == 75
assert usage_dict["total_tokens"] == 225
assert usage_dict["currency"] == "EUR"
def test_extract_usage_dict_recursive(self) -> None:
"""Test recursive search through nested structures."""
meta = {
"custom": {
"nested": {
"usage": {
"total_tokens": 500,
"prompt_tokens": 300,
"completion_tokens": 200,
}
}
}
}
usage_dict = MCPTool._extract_usage_dict(meta)
assert usage_dict is not None
assert usage_dict["total_tokens"] == 500
def test_extract_usage_dict_from_list(self) -> None:
"""Test extraction from nested list structures."""
meta = {
"items": [
{"usage": {"total_tokens": 100}},
{"other": "data"},
]
}
usage_dict = MCPTool._extract_usage_dict(meta)
assert usage_dict is not None
assert usage_dict["total_tokens"] == 100
def test_extract_usage_dict_returns_none_when_missing(self) -> None:
"""Test that None is returned when no usage data is present."""
meta = {"other": "data", "custom": {"nested": {"value": 123}}}
usage_dict = MCPTool._extract_usage_dict(meta)
assert usage_dict is None
def test_extract_usage_dict_empty_meta(self) -> None:
"""Test with empty meta dict."""
usage_dict = MCPTool._extract_usage_dict({})
assert usage_dict is None
def test_derive_usage_from_result_with_meta(self) -> None:
"""Test _derive_usage_from_result with populated meta."""
meta = {
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
"total_price": "0.0015",
"currency": "USD",
}
}
result = CallToolResult(content=[], _meta=meta)
usage = MCPTool._derive_usage_from_result(result)
assert isinstance(usage, LLMUsage)
assert usage.prompt_tokens == 100
assert usage.completion_tokens == 50
assert usage.total_tokens == 150
assert usage.total_price == Decimal("0.0015")
assert usage.currency == "USD"
def test_derive_usage_from_result_without_meta(self) -> None:
"""Test _derive_usage_from_result with no meta returns empty usage."""
result = CallToolResult(content=[], meta=None)
usage = MCPTool._derive_usage_from_result(result)
assert isinstance(usage, LLMUsage)
assert usage.total_tokens == 0
assert usage.prompt_tokens == 0
assert usage.completion_tokens == 0
def test_derive_usage_from_result_calculates_total_tokens(self) -> None:
"""Test that total_tokens is calculated when missing."""
meta = {
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
# total_tokens is missing
}
}
result = CallToolResult(content=[], _meta=meta)
usage = MCPTool._derive_usage_from_result(result)
assert usage.total_tokens == 150 # 100 + 50
assert usage.prompt_tokens == 100
assert usage.completion_tokens == 50
def test_invoke_sets_latest_usage_from_meta(self) -> None:
"""Test that _invoke sets _latest_usage from result meta."""
tool = _make_mcp_tool()
meta = {
"usage": {
"prompt_tokens": 200,
"completion_tokens": 100,
"total_tokens": 300,
"total_price": "0.003",
"currency": "USD",
}
}
result = CallToolResult(content=[TextContent(type="text", text="test")], _meta=meta)
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
list(tool._invoke(user_id="test_user", tool_parameters={}))
# Verify latest_usage was set correctly
assert tool.latest_usage.prompt_tokens == 200
assert tool.latest_usage.completion_tokens == 100
assert tool.latest_usage.total_tokens == 300
assert tool.latest_usage.total_price == Decimal("0.003")
def test_invoke_with_no_meta_returns_empty_usage(self) -> None:
"""Test that _invoke returns empty usage when no meta is present."""
tool = _make_mcp_tool()
result = CallToolResult(content=[TextContent(type="text", text="test")], _meta=None)
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
list(tool._invoke(user_id="test_user", tool_parameters={}))
# Verify latest_usage is empty
assert tool.latest_usage.total_tokens == 0
assert tool.latest_usage.prompt_tokens == 0
assert tool.latest_usage.completion_tokens == 0
def test_latest_usage_property_returns_llm_usage(self) -> None:
"""Test that latest_usage property returns LLMUsage instance."""
tool = _make_mcp_tool()
assert isinstance(tool.latest_usage, LLMUsage)
def test_initial_usage_is_empty(self) -> None:
"""Test that MCPTool is initialized with empty usage."""
tool = _make_mcp_tool()
assert tool.latest_usage.total_tokens == 0
assert tool.latest_usage.prompt_tokens == 0
assert tool.latest_usage.completion_tokens == 0
assert tool.latest_usage.total_price == Decimal(0)
@pytest.mark.parametrize(
"meta_data",
[
# Direct usage field
{"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}},
# Nested metadata
{"metadata": {"usage": {"total_tokens": 100}}},
# Flat token fields
{"total_tokens": 50, "prompt_tokens": 30, "completion_tokens": 20},
# With price info
{
"usage": {
"total_tokens": 150,
"total_price": "0.002",
"currency": "EUR",
}
},
# Deep nested
{"level1": {"level2": {"usage": {"total_tokens": 200}}}},
],
)
def test_various_meta_formats(self, meta_data) -> None:
"""Test that various meta formats are correctly parsed."""
result = CallToolResult(content=[], _meta=meta_data)
usage = MCPTool._derive_usage_from_result(result)
assert isinstance(usage, LLMUsage)
# Should have at least some usage data
if meta_data.get("usage", {}).get("total_tokens") or meta_data.get("total_tokens"):
expected_total = (
meta_data.get("usage", {}).get("total_tokens")
or meta_data.get("total_tokens")
or meta_data.get("metadata", {}).get("usage", {}).get("total_tokens")
or meta_data.get("level1", {}).get("level2", {}).get("usage", {}).get("total_tokens")
)
if expected_total:
assert usage.total_tokens == expected_total

11
api/uv.lock generated
View File

@@ -1653,7 +1653,7 @@ requires-dist = [
{ name = "starlette", specifier = "==0.49.1" },
{ name = "tiktoken", specifier = "~=0.9.0" },
{ name = "transformers", specifier = "~=4.56.1" },
{ name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.18.18" },
{ name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" },
{ name = "weave", specifier = ">=0.52.16" },
{ name = "weaviate-client", specifier = "==4.17.0" },
{ name = "webvtt-py", specifier = "~=0.5.1" },
@@ -6814,12 +6814,12 @@ wheels = [
[[package]]
name = "unstructured"
version = "0.18.31"
version = "0.16.25"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "backoff" },
{ name = "beautifulsoup4" },
{ name = "charset-normalizer" },
{ name = "chardet" },
{ name = "dataclasses-json" },
{ name = "emoji" },
{ name = "filetype" },
@@ -6827,7 +6827,6 @@ dependencies = [
{ name = "langdetect" },
{ name = "lxml" },
{ name = "nltk" },
{ name = "numba" },
{ name = "numpy" },
{ name = "psutil" },
{ name = "python-iso639" },
@@ -6840,9 +6839,9 @@ dependencies = [
{ name = "unstructured-client" },
{ name = "wrapt" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a9/5f/64285bd69a538bc28753f1423fcaa9d64cd79a9e7c097171b1f0d27e9cdb/unstructured-0.18.31.tar.gz", hash = "sha256:af4bbe32d1894ae6e755f0da6fc0dd307a1d0adeebe0e7cc6278f6cf744339ca", size = 1707700, upload-time = "2026-01-27T15:33:05.378Z" }
sdist = { url = "https://files.pythonhosted.org/packages/64/31/98c4c78e305d1294888adf87fd5ee30577a4c393951341ca32b43f167f1e/unstructured-0.16.25.tar.gz", hash = "sha256:73b9b0f51dbb687af572ecdb849a6811710b9cac797ddeab8ee80fa07d8aa5e6", size = 1683097, upload-time = "2025-03-07T11:19:39.507Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/4a/9c43f39d9e443c9bc3f2e379b305bca27110adc653b071221b3132c18de5/unstructured-0.18.31-py3-none-any.whl", hash = "sha256:fab4641176cb9b192ed38048758aa0d9843121d03626d18f42275afb31e5b2d3", size = 1794889, upload-time = "2026-01-27T15:33:03.136Z" },
{ url = "https://files.pythonhosted.org/packages/12/4f/ad08585b5c8a33c82ea119494c4d3023f4796958c56e668b15cc282ec0a0/unstructured-0.16.25-py3-none-any.whl", hash = "sha256:14719ccef2830216cf1c5bf654f75e2bf07b17ca5dcee9da5ac74618130fd337", size = 1769286, upload-time = "2025-03-07T11:19:37.299Z" },
]
[package.optional-dependencies]

View File

@@ -109,7 +109,6 @@ const AgentTools: FC = () => {
tool_parameters: paramsWithDefaultValue,
notAuthor: !tool.is_team_authorization,
enabled: true,
type: tool.provider_type as CollectionType,
}
}
const handleSelectTool = (tool: ToolDefaultValue) => {

View File

@@ -194,11 +194,11 @@ const ConfigContent: FC<Props> = ({
</div>
{type === RETRIEVE_TYPE.multiWay && (
<>
<div className="my-2 flex flex-col items-center py-1">
<div className="system-xs-semibold-uppercase mb-2 mr-2 shrink-0 text-text-secondary">
<div className="my-2 flex h-6 items-center py-1">
<div className="system-xs-semibold-uppercase mr-2 shrink-0 text-text-secondary">
{t('rerankSettings', { ns: 'dataset' })}
</div>
<Divider bgStyle="gradient" className="m-0 !h-px" />
<Divider bgStyle="gradient" className="mx-0 !h-px" />
</div>
{
selectedDatasetsMode.inconsistentEmbeddingModel

View File

@@ -7,7 +7,6 @@ import dayjs from 'dayjs'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import DatePicker from '@/app/components/base/date-and-time-picker/date-picker'
import { useAppContext } from '@/context/app-context'
import useTimestamp from '@/hooks/use-timestamp'
import { cn } from '@/utils/classnames'
@@ -22,7 +21,7 @@ const WrappedDatePicker = ({
onChange,
}: Props) => {
const { t } = useTranslation()
const { userProfile: { timezone } } = useAppContext()
// const { userProfile: { timezone } } = useAppContext()
const { formatTime: formatTimestamp } = useTimestamp()
const handleDateChange = useCallback((date?: dayjs.Dayjs) => {
@@ -65,7 +64,6 @@ const WrappedDatePicker = ({
return (
<DatePicker
value={dayjs(value ? value * 1000 : Date.now())}
timezone={timezone}
onChange={handleDateChange}
onClear={handleDateChange}
renderTrigger={renderTrigger}

View File

@@ -273,71 +273,6 @@ The text generation application offers non-session support and is ideal for tran
</Row>
---
<Heading
url='/end-users/:end_user_id'
method='GET'
title='Get End User'
name='#end-user'
/>
<Row>
<Col>
Retrieve an end user by ID.
This is useful when other APIs return an end-user ID (e.g. `created_by` from File Upload).
### Path Parameters
- `end_user_id` (uuid) Required
End user ID.
### Response
Returns an EndUser object.
- `id` (uuid) ID
- `tenant_id` (uuid) Tenant ID
- `app_id` (uuid) App ID
- `type` (string) End user type
- `external_user_id` (string) External user ID
- `name` (string) Name
- `is_anonymous` (boolean) Whether anonymous
- `session_id` (string) Session ID
- `created_at` (string) ISO 8601 datetime
- `updated_at` (string) ISO 8601 datetime
### Errors
- 404, `end_user_not_found`, end user not found
- 500, internal server error
</Col>
<Col sticky>
### Request Example
<CodeGroup
title="Request"
tag="GET"
label="/end-users/:end_user_id"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/end-users/6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13' \\
--header 'Authorization: Bearer {api_key}'`}
/>
### Response Example
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"id": "6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13",
"tenant_id": "8c0f3f3a-66b0-4b55-a0bf-8b8e0d6aee7d",
"app_id": "6c8c3f41-2c6f-4e1b-8f4f-7f11c8f2ad2a",
"type": "service_api",
"external_user_id": "abc-123",
"name": "Alice",
"is_anonymous": false,
"session_id": "abc-123",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z"
}
```
</CodeGroup>
</Col>
</Row>
---
<Heading
url='/files/:file_id/preview'
method='GET'

View File

@@ -272,71 +272,6 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
</Row>
---
<Heading
url='/end-users/:end_user_id'
method='GET'
title='エンドユーザーを取得'
name='#end-user'
/>
<Row>
<Col>
エンドユーザー ID からエンドユーザー情報を取得します。
他の API がエンドユーザー IDファイルアップロードの `created_by`)を返す場合に利用できます。
### パスパラメータ
- `end_user_id` (uuid) 必須
エンドユーザー ID。
### レスポンス
EndUser オブジェクトを返します。
- `id` (uuid) ID
- `tenant_id` (uuid) テナント ID
- `app_id` (uuid) アプリ ID
- `type` (string) エンドユーザー種別
- `external_user_id` (string) 外部ユーザー ID
- `name` (string) 名前
- `is_anonymous` (boolean) 匿名ユーザーかどうか
- `session_id` (string) セッション ID
- `created_at` (string) ISO 8601 日時
- `updated_at` (string) ISO 8601 日時
### エラー
- 404, `end_user_not_found`, エンドユーザーが見つかりません
- 500, 内部サーバーエラー
</Col>
<Col sticky>
### リクエスト例
<CodeGroup
title="Request"
tag="GET"
label="/end-users/:end_user_id"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/end-users/6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13' \\
--header 'Authorization: Bearer {api_key}'`}
/>
### レスポンス例
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"id": "6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13",
"tenant_id": "8c0f3f3a-66b0-4b55-a0bf-8b8e0d6aee7d",
"app_id": "6c8c3f41-2c6f-4e1b-8f4f-7f11c8f2ad2a",
"type": "service_api",
"external_user_id": "abc-123",
"name": "Alice",
"is_anonymous": false,
"session_id": "abc-123",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z"
}
```
</CodeGroup>
</Col>
</Row>
---
<Heading
url='/files/:file_id/preview'
method='GET'

View File

@@ -249,69 +249,6 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
</Row>
---
<Heading
url='/end-users/:end_user_id'
method='GET'
title='获取终端用户'
name='#end-user'
/>
<Row>
<Col>
通过终端用户 ID 获取终端用户信息。
当其他 API 返回终端用户 ID例如上传文件接口返回的 `created_by`)时,可使用该接口查询对应的终端用户信息。
### 路径参数
- `end_user_id` (uuid) 必需
终端用户 ID。
### Response
返回 EndUser 对象。
- `id` (uuid) ID
- `tenant_id` (uuid) 工作空间TenantID
- `app_id` (uuid) 应用 ID
- `type` (string) 终端用户类型
- `external_user_id` (string) 外部用户 ID
- `name` (string) 名称
- `is_anonymous` (boolean) 是否匿名
- `session_id` (string) 会话 ID
- `created_at` (string) ISO 8601 时间
- `updated_at` (string) ISO 8601 时间
### Errors
- 404`end_user_not_found`,终端用户不存在
- 500内部服务器错误
</Col>
<Col sticky>
<CodeGroup
title="Request"
tag="GET"
label="/end-users/:end_user_id"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/end-users/6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13' \\
--header 'Authorization: Bearer {api_key}'`}
/>
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"id": "6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13",
"tenant_id": "8c0f3f3a-66b0-4b55-a0bf-8b8e0d6aee7d",
"app_id": "6c8c3f41-2c6f-4e1b-8f4f-7f11c8f2ad2a",
"type": "service_api",
"external_user_id": "abc-123",
"name": "Alice",
"is_anonymous": false,
"session_id": "abc-123",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z"
}
```
</CodeGroup>
</Col>
</Row>
---
<Heading
url='/files/:file_id/preview'
method='GET'

View File

@@ -392,71 +392,6 @@ Chat applications support session persistence, allowing previous chat history to
</Row>
---
<Heading
url='/end-users/:end_user_id'
method='GET'
title='Get End User'
name='#end-user'
/>
<Row>
<Col>
Retrieve an end user by ID.
This is useful when other APIs return an end-user ID (e.g. `created_by` from File Upload).
### Path Parameters
- `end_user_id` (uuid) Required
End user ID.
### Response
Returns an EndUser object.
- `id` (uuid) ID
- `tenant_id` (uuid) Tenant ID
- `app_id` (uuid) App ID
- `type` (string) End user type
- `external_user_id` (string) External user ID
- `name` (string) Name
- `is_anonymous` (boolean) Whether anonymous
- `session_id` (string) Session ID
- `created_at` (string) ISO 8601 datetime
- `updated_at` (string) ISO 8601 datetime
### Errors
- 404, `end_user_not_found`, end user not found
- 500, internal server error
</Col>
<Col sticky>
### Request Example
<CodeGroup
title="Request"
tag="GET"
label="/end-users/:end_user_id"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/end-users/6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13' \\
--header 'Authorization: Bearer {api_key}'`}
/>
### Response Example
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"id": "6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13",
"tenant_id": "8c0f3f3a-66b0-4b55-a0bf-8b8e0d6aee7d",
"app_id": "6c8c3f41-2c6f-4e1b-8f4f-7f11c8f2ad2a",
"type": "service_api",
"external_user_id": "abc-123",
"name": "Alice",
"is_anonymous": false,
"session_id": "abc-123",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z"
}
```
</CodeGroup>
</Col>
</Row>
---
<Heading
url='/files/:file_id/preview'
method='GET'

View File

@@ -393,71 +393,6 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
</Row>
---
<Heading
url='/end-users/:end_user_id'
method='GET'
title='エンドユーザーを取得'
name='#end-user'
/>
<Row>
<Col>
エンドユーザー ID からエンドユーザー情報を取得します。
他の API がエンドユーザー IDファイルアップロードの `created_by`)を返す場合に利用できます。
### パスパラメータ
- `end_user_id` (uuid) 必須
エンドユーザー ID。
### レスポンス
EndUser オブジェクトを返します。
- `id` (uuid) ID
- `tenant_id` (uuid) テナント ID
- `app_id` (uuid) アプリ ID
- `type` (string) エンドユーザー種別
- `external_user_id` (string) 外部ユーザー ID
- `name` (string) 名前
- `is_anonymous` (boolean) 匿名ユーザーかどうか
- `session_id` (string) セッション ID
- `created_at` (string) ISO 8601 日時
- `updated_at` (string) ISO 8601 日時
### エラー
- 404, `end_user_not_found`, エンドユーザーが見つかりません
- 500, 内部サーバーエラー
</Col>
<Col sticky>
### リクエスト例
<CodeGroup
title="Request"
tag="GET"
label="/end-users/:end_user_id"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/end-users/6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13' \\
--header 'Authorization: Bearer {api_key}'`}
/>
### レスポンス例
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"id": "6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13",
"tenant_id": "8c0f3f3a-66b0-4b55-a0bf-8b8e0d6aee7d",
"app_id": "6c8c3f41-2c6f-4e1b-8f4f-7f11c8f2ad2a",
"type": "service_api",
"external_user_id": "abc-123",
"name": "Alice",
"is_anonymous": false,
"session_id": "abc-123",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z"
}
```
</CodeGroup>
</Col>
</Row>
---
<Heading
url='/files/:file_id/preview'
method='GET'

View File

@@ -390,69 +390,6 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
</Row>
---
<Heading
url='/end-users/:end_user_id'
method='GET'
title='获取终端用户'
name='#end-user'
/>
<Row>
<Col>
通过终端用户 ID 获取终端用户信息。
当其他 API 返回终端用户 ID例如上传文件接口返回的 `created_by`)时,可使用该接口查询对应的终端用户信息。
### 路径参数
- `end_user_id` (uuid) 必需
终端用户 ID。
### Response
返回 EndUser 对象。
- `id` (uuid) ID
- `tenant_id` (uuid) 工作空间TenantID
- `app_id` (uuid) 应用 ID
- `type` (string) 终端用户类型
- `external_user_id` (string) 外部用户 ID
- `name` (string) 名称
- `is_anonymous` (boolean) 是否匿名
- `session_id` (string) 会话 ID
- `created_at` (string) ISO 8601 时间
- `updated_at` (string) ISO 8601 时间
### Errors
- 404`end_user_not_found`,终端用户不存在
- 500内部服务器错误
</Col>
<Col sticky>
<CodeGroup
title="Request"
tag="GET"
label="/end-users/:end_user_id"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/end-users/6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13' \\
--header 'Authorization: Bearer {api_key}'`}
/>
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"id": "6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13",
"tenant_id": "8c0f3f3a-66b0-4b55-a0bf-8b8e0d6aee7d",
"app_id": "6c8c3f41-2c6f-4e1b-8f4f-7f11c8f2ad2a",
"type": "service_api",
"external_user_id": "abc-123",
"name": "Alice",
"is_anonymous": false,
"session_id": "abc-123",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z"
}
```
</CodeGroup>
</Col>
</Row>
---
<Heading
url='/files/:file_id/preview'
method='GET'

View File

@@ -362,71 +362,6 @@ Chat applications support session persistence, allowing previous chat history to
</Row>
---
<Heading
url='/end-users/:end_user_id'
method='GET'
title='Get End User'
name='#end-user'
/>
<Row>
<Col>
Retrieve an end user by ID.
This is useful when other APIs return an end-user ID (e.g. `created_by` from File Upload).
### Path Parameters
- `end_user_id` (uuid) Required
End user ID.
### Response
Returns an EndUser object.
- `id` (uuid) ID
- `tenant_id` (uuid) Tenant ID
- `app_id` (uuid) App ID
- `type` (string) End user type
- `external_user_id` (string) External user ID
- `name` (string) Name
- `is_anonymous` (boolean) Whether anonymous
- `session_id` (string) Session ID
- `created_at` (string) ISO 8601 datetime
- `updated_at` (string) ISO 8601 datetime
### Errors
- 404, `end_user_not_found`, end user not found
- 500, internal server error
</Col>
<Col sticky>
### Request Example
<CodeGroup
title="Request"
tag="GET"
label="/end-users/:end_user_id"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/end-users/6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13' \\
--header 'Authorization: Bearer {api_key}'`}
/>
### Response Example
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"id": "6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13",
"tenant_id": "8c0f3f3a-66b0-4b55-a0bf-8b8e0d6aee7d",
"app_id": "6c8c3f41-2c6f-4e1b-8f4f-7f11c8f2ad2a",
"type": "service_api",
"external_user_id": "abc-123",
"name": "Alice",
"is_anonymous": false,
"session_id": "abc-123",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z"
}
```
</CodeGroup>
</Col>
</Row>
---
<Heading
url='/files/:file_id/preview'
method='GET'

View File

@@ -362,71 +362,6 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
</Row>
---
<Heading
url='/end-users/:end_user_id'
method='GET'
title='エンドユーザーを取得'
name='#end-user'
/>
<Row>
<Col>
エンドユーザー ID からエンドユーザー情報を取得します。
他の API がエンドユーザー IDファイルアップロードの `created_by`)を返す場合に利用できます。
### パスパラメータ
- `end_user_id` (uuid) 必須
エンドユーザー ID。
### レスポンス
EndUser オブジェクトを返します。
- `id` (uuid) ID
- `tenant_id` (uuid) テナント ID
- `app_id` (uuid) アプリ ID
- `type` (string) エンドユーザー種別
- `external_user_id` (string) 外部ユーザー ID
- `name` (string) 名前
- `is_anonymous` (boolean) 匿名ユーザーかどうか
- `session_id` (string) セッション ID
- `created_at` (string) ISO 8601 日時
- `updated_at` (string) ISO 8601 日時
### エラー
- 404, `end_user_not_found`, エンドユーザーが見つかりません
- 500, 内部サーバーエラー
</Col>
<Col sticky>
### リクエスト例
<CodeGroup
title="Request"
tag="GET"
label="/end-users/:end_user_id"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/end-users/6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13' \\
--header 'Authorization: Bearer {api_key}'`}
/>
### レスポンス例
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"id": "6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13",
"tenant_id": "8c0f3f3a-66b0-4b55-a0bf-8b8e0d6aee7d",
"app_id": "6c8c3f41-2c6f-4e1b-8f4f-7f11c8f2ad2a",
"type": "service_api",
"external_user_id": "abc-123",
"name": "Alice",
"is_anonymous": false,
"session_id": "abc-123",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z"
}
```
</CodeGroup>
</Col>
</Row>
---
<Heading
url='/files/:file_id/preview'
method='GET'

View File

@@ -368,69 +368,6 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
</Row>
---
<Heading
url='/end-users/:end_user_id'
method='GET'
title='获取终端用户'
name='#end-user'
/>
<Row>
<Col>
通过终端用户 ID 获取终端用户信息。
当其他 API 返回终端用户 ID例如上传文件接口返回的 `created_by`)时,可使用该接口查询对应的终端用户信息。
### 路径参数
- `end_user_id` (uuid) 必需
终端用户 ID。
### Response
返回 EndUser 对象。
- `id` (uuid) ID
- `tenant_id` (uuid) 工作空间TenantID
- `app_id` (uuid) 应用 ID
- `type` (string) 终端用户类型
- `external_user_id` (string) 外部用户 ID
- `name` (string) 名称
- `is_anonymous` (boolean) 是否匿名
- `session_id` (string) 会话 ID
- `created_at` (string) ISO 8601 时间
- `updated_at` (string) ISO 8601 时间
### Errors
- 404`end_user_not_found`,终端用户不存在
- 500内部服务器错误
</Col>
<Col sticky>
<CodeGroup
title="Request"
tag="GET"
label="/end-users/:end_user_id"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/end-users/6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13' \\
--header 'Authorization: Bearer {api_key}'`}
/>
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"id": "6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13",
"tenant_id": "8c0f3f3a-66b0-4b55-a0bf-8b8e0d6aee7d",
"app_id": "6c8c3f41-2c6f-4e1b-8f4f-7f11c8f2ad2a",
"type": "service_api",
"external_user_id": "abc-123",
"name": "Alice",
"is_anonymous": false,
"session_id": "abc-123",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z"
}
```
</CodeGroup>
</Col>
</Row>
---
<Heading
url='/files/:file_id/preview'
method='GET'

View File

@@ -740,71 +740,6 @@ Workflow applications offers non-session support and is ideal for translation, a
---
<Heading
url='/end-users/:end_user_id'
method='GET'
title='Get End User'
name='#end-user'
/>
<Row>
<Col>
Retrieve an end user by ID.
This is useful when other APIs return an end-user ID (e.g. `created_by` from File Upload).
### Path Parameters
- `end_user_id` (uuid) Required
End user ID.
### Response
Returns an EndUser object.
- `id` (uuid) ID
- `tenant_id` (uuid) Tenant ID
- `app_id` (uuid) App ID
- `type` (string) End user type
- `external_user_id` (string) External user ID
- `name` (string) Name
- `is_anonymous` (boolean) Whether anonymous
- `session_id` (string) Session ID
- `created_at` (string) ISO 8601 datetime
- `updated_at` (string) ISO 8601 datetime
### Errors
- 404, `end_user_not_found`, end user not found
- 500, internal server error
</Col>
<Col sticky>
### Request Example
<CodeGroup
title="Request"
tag="GET"
label="/end-users/:end_user_id"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/end-users/6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13' \\
--header 'Authorization: Bearer {api_key}'`}
/>
### Response Example
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"id": "6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13",
"tenant_id": "8c0f3f3a-66b0-4b55-a0bf-8b8e0d6aee7d",
"app_id": "6c8c3f41-2c6f-4e1b-8f4f-7f11c8f2ad2a",
"type": "service_api",
"external_user_id": "abc-123",
"name": "Alice",
"is_anonymous": false,
"session_id": "abc-123",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z"
}
```
</CodeGroup>
</Col>
</Row>
---
<Heading
url='/workflows/logs'
method='GET'

Some files were not shown because too many files have changed in this diff Show More