Mirror of https://github.com/langgenius/dify.git (synced 2026-04-09 01:29:22 +08:00)

Compare commits: refactor/t ... feat/hitl (12 commits)
Commits:
- 3d0ff9463f
- b893d2df82
- 79b6117d80
- d2ef434dec
- aaf83c2b4c
- d898bcff90
- b4cf146c85
- f21782a9a3
- e4455987e7
- b2ceb41dd6
- f614153f30
- 8ca020e179
.github/workflows/deploy-hitl.yml (vendored, 8 changes)
@@ -4,7 +4,8 @@ on:
  workflow_run:
    workflows: ["Build and Push API & Web"]
    branches:
      - "build/feat/hitl"
      - "feat/hitl-frontend"
      - "feat/hitl-backend"
    types:
      - completed

@@ -13,7 +14,10 @@ jobs:
    runs-on: ubuntu-latest
    if: |
      github.event.workflow_run.conclusion == 'success' &&
      github.event.workflow_run.head_branch == 'build/feat/hitl'
      (
        github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
        github.event.workflow_run.head_branch == 'feat/hitl-backend'
      )
    steps:
      - name: Deploy to server
        uses: appleboy/ssh-action@v1
.vscode/launch.json.template (vendored, 2 changes)
@@ -37,7 +37,7 @@
    "-c",
    "1",
    "-Q",
    "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
    "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
    "--loglevel",
    "INFO"
],
@@ -52,12 +52,14 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
||||
# TODO(QuantumGhost): use DI to avoid depending on global DB.
|
||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
|
||||
@@ -104,6 +106,8 @@ forbidden_modules =
|
||||
core.trigger
|
||||
core.variables
|
||||
ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> core.db.session_factory
|
||||
core.workflow.nodes.agent.agent_node -> models.tools
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.workflow_entry -> core.app.workflow.layers.observability
|
||||
@@ -124,6 +128,11 @@ ignore_imports =
|
||||
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.llm.llm_utils -> configs
|
||||
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.llm.llm_utils -> core.file.models
|
||||
@@ -144,6 +153,7 @@ ignore_imports =
|
||||
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
|
||||
@@ -159,6 +169,9 @@ ignore_imports =
|
||||
core.workflow.workflow_entry -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
@@ -207,6 +220,7 @@ ignore_imports =
|
||||
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
core.workflow.nodes.llm.node -> core.model_manager
|
||||
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
|
||||
@@ -222,6 +236,7 @@ ignore_imports =
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.llm.node -> models.dataset
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
|
||||
core.workflow.nodes.llm.file_saver -> core.tools.signature
|
||||
@@ -280,6 +295,8 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
|
||||
@@ -122,7 +122,7 @@ These commands assume you start from the repository root.

```bash
cd api
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```

1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
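The worker only consumes tasks from the queues named after `-Q`; a task still has to be routed onto one of those queues before any worker picks it up. Below is a minimal Python sketch of that routing with Celery's standard `task_routes` setting. The task module paths are hypothetical; only the queue names come from the command above.

```python
# Illustrative sketch: how tasks map onto the queues named after -Q.
# Task module paths are hypothetical; only the queue names come from the command above.
from celery import Celery

celery_app = Celery("dify_api", broker="redis://localhost:6379/1")

celery_app.conf.task_routes = {
    # A hypothetical batch task pinned to a dedicated "api_token" queue.
    "tasks.update_api_token_last_used_task": {"queue": "api_token"},
    # Everything under a hypothetical tasks.dataset package goes to the "dataset" queue.
    "tasks.dataset.*": {"queue": "dataset"},
}

# A worker started with `-Q api_token,dataset` consumes from both queues; a worker whose
# -Q list omits "api_token" simply never sees the first task.
```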
@@ -1180,16 +1180,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
        default=0,
    )

    # API token last_used_at batch update
    ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: bool = Field(
        description="Enable periodic batch update of API token last_used_at timestamps",
        default=True,
    )
    API_TOKEN_LAST_USED_UPDATE_INTERVAL: int = Field(
        description="Interval in minutes for batch updating API token last_used_at (default 30)",
        default=30,
    )

    # Trigger provider refresh (simple version)
    ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
        description="Enable trigger provider refresh poller",
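`ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK` and `API_TOKEN_LAST_USED_UPDATE_INTERVAL` describe a periodic batch update of `last_used_at`. A rough sketch of how a flag/interval pair like this is typically wired into Celery Beat; the schedule key and task path below are assumptions, not taken from this diff.

```python
# Hypothetical sketch: turning the flag/interval settings above into a Celery Beat entry.
from datetime import timedelta

from celery import Celery


def register_api_token_schedule(app: Celery, enabled: bool, interval_minutes: int) -> None:
    """Register the batch last_used_at update only when the feature flag is on."""
    if not enabled:
        return
    schedule = dict(app.conf.beat_schedule or {})
    schedule["batch-update-api-token-last-used"] = {
        "task": "tasks.update_api_token_last_used_task",  # hypothetical task path
        "schedule": timedelta(minutes=interval_minutes),
    }
    app.conf.beat_schedule = schedule
```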
@@ -5,6 +5,8 @@ from enum import StrEnum
from flask_restx import Namespace
from pydantic import BaseModel, TypeAdapter

from controllers.console import console_ns

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


@@ -22,9 +24,6 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:


def get_or_create_model(model_name: str, field_def):
    # Import lazily to avoid circular imports between console controllers and schema helpers.
    from controllers.console import console_ns

    existing = console_ns.models.get(model_name)
    if existing is None:
        existing = console_ns.model(model_name, field_def)
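`get_or_create_model` makes flask-restx model registration idempotent: the first call defines the model on `console_ns`, later calls with the same name return the already-registered one. A small usage sketch under that assumption; the import path and field definition are invented for illustration.

```python
# Illustrative usage of the get-or-create pattern; import path and fields are assumptions.
from flask_restx import fields

from controllers.common.schema import get_or_create_model  # assumed module path

pagination_fields = {
    "page": fields.Integer(description="Current page"),
    "total": fields.Integer(description="Total number of items"),
}

# The first call registers the model on console_ns ...
model_a = get_or_create_model("Pagination", pagination_fields)
# ... later calls with the same name reuse it, so importing several controller modules
# that need the same schema does not redefine it in the Swagger spec.
model_b = get_or_create_model("Pagination", pagination_fields)
assert model_a is model_b
```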
@@ -10,7 +10,6 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache

from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required

@@ -132,11 +131,6 @@ class BaseApiKeyResource(Resource):
        if key is None:
            flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")

        # Invalidate cache before deleting from database
        # Type assertion: key is guaranteed to be non-None here because abort() raises
        assert key is not None  # nosec - for type checker only
        ApiTokenCache.delete(key.token, key.type)

        db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
        db.session.commit()
@@ -463,9 +463,8 @@ class WorkflowRunNodeExecutionListApi(Resource):
class ConsoleWorkflowPauseDetailsApi(Resource):
    """Console API for getting workflow pause details."""

    @setup_required
    @login_required
    @account_initialization_required
    @login_required
    def get(self, workflow_run_id: str):
        """
        Get workflow pause details.

@@ -478,14 +477,10 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
        # Query WorkflowRun to determine if workflow is suspended
        session_maker = sessionmaker(bind=db.engine)
        workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker)

        workflow_run = db.session.get(WorkflowRun, workflow_run_id)
        if not workflow_run:
            raise NotFoundError("Workflow run not found")

        if workflow_run.tenant_id != current_user.current_tenant_id:
            raise NotFoundError("Workflow run not found")

        # Check if workflow is suspended
        is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
        if not is_paused:
@@ -55,7 +55,6 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
@@ -821,11 +820,6 @@ class DatasetApiDeleteApi(Resource):
|
||||
if key is None:
|
||||
console_ns.abort(404, message="API key not found")
|
||||
|
||||
# Invalidate cache before deleting from database
|
||||
# Type assertion: key is guaranteed to be non-None here because abort() raises
|
||||
assert key is not None # nosec - for type checker only
|
||||
ApiTokenCache.delete(key.token, key.type)
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ class TagUpdateDeleteApi(Resource):
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
|
||||
@@ -34,8 +34,6 @@ from .dataset import (
|
||||
metadata,
|
||||
segment,
|
||||
)
|
||||
from .dataset.rag_pipeline import rag_pipeline_workflow
|
||||
from .end_user import end_user
|
||||
from .workspace import models
|
||||
|
||||
__all__ = [
|
||||
@@ -46,7 +44,6 @@ __all__ = [
|
||||
"conversation",
|
||||
"dataset",
|
||||
"document",
|
||||
"end_user",
|
||||
"file",
|
||||
"file_preview",
|
||||
"hit_testing",
|
||||
@@ -54,7 +51,6 @@ __all__ = [
|
||||
"message",
|
||||
"metadata",
|
||||
"models",
|
||||
"rag_pipeline_workflow",
|
||||
"segment",
|
||||
"site",
|
||||
"workflow",
|
||||
|
||||
@@ -396,7 +396,7 @@ class DatasetApi(DatasetApiResource):
|
||||
try:
|
||||
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
return "", 204
|
||||
return 204
|
||||
else:
|
||||
raise NotFound("Dataset not found.")
|
||||
except services.errors.dataset.DatasetInUseError:
|
||||
@@ -557,7 +557,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag(payload.tag_id)
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags/binding")
|
||||
@@ -581,7 +581,7 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags/unbinding")
|
||||
@@ -605,7 +605,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/tags")
|
||||
|
||||
@@ -746,4 +746,4 @@ class DocumentApi(DatasetApiResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
@@ -128,7 +128,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import string
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
@@ -10,7 +12,6 @@ from controllers.common.errors import FilenameNotExistsError, NoFileUploadedErro
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import PipelineRunError
|
||||
from controllers.service_api.dataset.rag_pipeline.serializers import serialize_upload_file
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@@ -40,7 +41,7 @@ register_schema_model(service_api_ns, DatasourceNodeRunPayload)
|
||||
register_schema_model(service_api_ns, PipelineRunApiEntity)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource-plugins")
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
|
||||
class DatasourcePluginsApi(DatasetApiResource):
|
||||
"""Resource for datasource plugins."""
|
||||
|
||||
@@ -75,7 +76,7 @@ class DatasourcePluginsApi(DatasetApiResource):
|
||||
return datasource_plugins, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
|
||||
class DatasourceNodeRunApi(DatasetApiResource):
|
||||
"""Resource for datasource node run."""
|
||||
|
||||
@@ -130,7 +131,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/run")
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run")
|
||||
class PipelineRunApi(DatasetApiResource):
|
||||
"""Resource for datasource node run."""
|
||||
|
||||
@@ -231,4 +232,12 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return serialize_upload_file(upload_file), 201
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at,
|
||||
}, 201
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
"""
|
||||
Serialization helpers for Service API knowledge pipeline endpoints.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]:
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at.isoformat() if upload_file.created_at else None,
|
||||
}
|
||||
@@ -233,7 +233,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_segment")
|
||||
@@ -499,7 +499,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
except ChildChunkDeleteIndexServiceError as e:
|
||||
raise ChildChunkDeleteIndexError(str(e))
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_child_chunk")
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from . import end_user
|
||||
|
||||
__all__ = ["end_user"]
|
||||
@@ -1,41 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.end_user.error import EndUserNotFoundError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from fields.end_user_fields import EndUserDetail
|
||||
from models.model import App
|
||||
from services.end_user_service import EndUserService
|
||||
|
||||
|
||||
@service_api_ns.route("/end-users/<uuid:end_user_id>")
|
||||
class EndUserApi(Resource):
|
||||
"""Resource for retrieving end user details by ID."""
|
||||
|
||||
@service_api_ns.doc("get_end_user")
|
||||
@service_api_ns.doc(description="Get an end user by ID")
|
||||
@service_api_ns.doc(
|
||||
params={"end_user_id": "End user ID"},
|
||||
responses={
|
||||
200: "End user retrieved successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "End user not found",
|
||||
},
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App, end_user_id: UUID):
|
||||
"""Get end user detail.
|
||||
|
||||
This endpoint is scoped to the current app token's tenant/app to prevent
|
||||
cross-tenant/app access when an end-user ID is known.
|
||||
"""
|
||||
|
||||
end_user = EndUserService.get_end_user_by_id(
|
||||
tenant_id=app_model.tenant_id, app_id=app_model.id, end_user_id=str(end_user_id)
|
||||
)
|
||||
if end_user is None:
|
||||
raise EndUserNotFoundError()
|
||||
|
||||
return EndUserDetail.model_validate(end_user).model_dump(mode="json")
|
||||
@@ -1,7 +0,0 @@
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class EndUserNotFoundError(BaseHTTPException):
|
||||
error_code = "end_user_not_found"
|
||||
description = "End user not found."
|
||||
code = 404
|
||||
@@ -1,24 +1,27 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache, fetch_token_with_single_flight, record_token_usage
|
||||
from services.end_user_service import EndUserService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@@ -217,8 +220,6 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[T, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
|
||||
# get url path dataset_id from positional args or kwargs
|
||||
# Flask passes URL path parameters as positional arguments
|
||||
dataset_id = None
|
||||
@@ -255,18 +256,12 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
||||
# Validate dataset if dataset_id is provided
|
||||
if dataset_id:
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == api_token.tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if not dataset.enable_api:
|
||||
raise Forbidden("Dataset api access is not enabled.")
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
tenant_account_join = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
@@ -301,14 +296,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):

def validate_and_get_api_token(scope: str | None = None):
    """
    Validate and get API token with Redis caching.

    This function uses a two-tier approach:
    1. First checks Redis cache for the token
    2. If not cached, queries database and caches the result

    The last_used_at field is updated asynchronously via Celery task
    to avoid blocking the request.
    Validate and get API token.
    """
    auth_header = request.headers.get("Authorization")
    if auth_header is None or " " not in auth_header:
@@ -320,18 +308,29 @@ def validate_and_get_api_token(scope: str | None = None):
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
||||
|
||||
# Try to get token from cache first
|
||||
# Returns a CachedApiToken (plain Python object), not a SQLAlchemy model
|
||||
cached_token = ApiTokenCache.get(auth_token, scope)
|
||||
if cached_token is not None:
|
||||
logger.debug("Token validation served from cache for scope: %s", scope)
|
||||
# Record usage in Redis for later batch update (no Celery task per request)
|
||||
record_token_usage(auth_token, scope)
|
||||
return cast(ApiToken, cached_token)
|
||||
current_time = naive_utc_now()
|
||||
cutoff_time = current_time - timedelta(minutes=1)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
update_stmt = (
|
||||
update(ApiToken)
|
||||
.where(
|
||||
ApiToken.token == auth_token,
|
||||
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
|
||||
ApiToken.type == scope,
|
||||
)
|
||||
.values(last_used_at=current_time)
|
||||
)
|
||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||
result = session.execute(update_stmt)
|
||||
api_token = session.scalar(stmt)
|
||||
|
||||
# Cache miss - use Redis lock for single-flight mode
|
||||
# This ensures only one request queries DB for the same token concurrently
|
||||
return fetch_token_with_single_flight(auth_token, scope)
|
||||
if hasattr(result, "rowcount") and result.rowcount > 0:
|
||||
session.commit()
|
||||
|
||||
if not api_token:
|
||||
raise Unauthorized("Access token is invalid")
|
||||
|
||||
return api_token
|
||||
|
||||
|
||||
class DatasetApiResource(Resource):
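The cached path above combines a read-through Redis cache (`ApiTokenCache`) with a single-flight lock (`fetch_token_with_single_flight`) so that concurrent cache misses for the same token do not all hit the database. The sketch below shows only that general pattern; key names, TTLs, and the `load_token_from_db` loader are assumptions, not the actual implementation.

```python
# Hedged sketch of a read-through cache with a single-flight lock.
# redis_client and load_token_from_db are assumed; key names and TTLs are illustrative.
import json
import time

CACHE_TTL = 600       # seconds a validated token stays cached
LOCK_TTL = 5          # seconds before a stuck lock expires
WAIT_INTERVAL = 0.05  # polling interval while another request fills the cache


def fetch_token_cached(redis_client, auth_token: str, scope: str, load_token_from_db):
    cache_key = f"api_token:{scope}:{auth_token}"
    cached = redis_client.get(cache_key)
    if cached is not None:
        return json.loads(cached)

    lock_key = f"{cache_key}:lock"
    # Only one request per token acquires the lock and queries the database.
    if redis_client.set(lock_key, "1", nx=True, ex=LOCK_TTL):
        try:
            token = load_token_from_db(auth_token, scope)  # may raise Unauthorized
            redis_client.setex(cache_key, CACHE_TTL, json.dumps(token))
            return token
        finally:
            redis_client.delete(lock_key)

    # Other requests wait briefly for the cache to be filled, then fall back to the DB.
    deadline = time.time() + LOCK_TTL
    while time.time() < deadline:
        cached = redis_client.get(cache_key)
        if cached is not None:
            return json.loads(cached)
        time.sleep(WAIT_INTERVAL)
    return load_token_from_db(auth_token, scope)
```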
@@ -65,12 +65,15 @@ def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Re
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
# TODO(QuantumGhost): disable authorization for web app
|
||||
# form api temporarily
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/<string:form_token>")
|
||||
# class HumanInputFormApi(WebApiResource):
|
||||
class HumanInputFormApi(Resource):
|
||||
"""API for getting and submitting human input forms via the web app."""
|
||||
|
||||
# NOTE(QuantumGhost): this endpoint is unauthenticated on purpose for now.
|
||||
|
||||
# def get(self, _app_model: App, _end_user: EndUser, form_token: str):
|
||||
def get(self, form_token: str):
|
||||
"""
|
||||
|
||||
@@ -346,7 +346,7 @@ class WorkflowResponseConverter:
|
||||
paused_nodes=list(event.paused_nodes),
|
||||
outputs=encoded_outputs,
|
||||
reasons=pause_reasons,
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
status=WorkflowExecutionStatus.PAUSED.value,
|
||||
created_at=int(started_at.timestamp()),
|
||||
elapsed_time=elapsed_time,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
@@ -422,7 +422,7 @@ class WorkflowResponseConverter:
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=run_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
status=workflow_run.status,
|
||||
status=workflow_run.status.value,
|
||||
outputs=encoded_outputs,
|
||||
error=workflow_run.error,
|
||||
elapsed_time=elapsed_time,
|
||||
|
||||
@@ -262,7 +262,7 @@ class WorkflowPauseStreamResponse(StreamResponse):
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
outputs: Mapping[str, Any] = Field(default_factory=dict)
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
|
||||
status: WorkflowExecutionStatus
|
||||
status: str
|
||||
created_at: int
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
|
||||
@@ -8,7 +8,6 @@ from core.file.file_manager import file_manager
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType
|
||||
@@ -17,7 +16,6 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
@@ -77,7 +75,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
self._http_request_http_client = http_request_http_client or ssrf_proxy
|
||||
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
|
||||
self._http_request_file_manager = http_request_file_manager or file_manager
|
||||
self._rag_retrieval = DatasetRetrieval()
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
@@ -143,15 +140,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
file_manager=self._http_request_file_manager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
return KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
rag_retrieval=self._rag_retrieval,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
from sqlalchemy import and_, literal, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
@@ -20,7 +18,6 @@ from core.app.app_config.entities import (
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
@@ -61,30 +58,12 @@ from core.rag.retrieval.template_prompts import (
|
||||
)
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.workflow.nodes.knowledge_retrieval import exc
|
||||
from core.workflow.repositories.rag_retrieval_protocol import (
|
||||
KnowledgeRetrievalRequest,
|
||||
Source,
|
||||
SourceChildChunk,
|
||||
SourceMetadata,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models import UploadFile
|
||||
from models.dataset import (
|
||||
ChildChunk,
|
||||
Dataset,
|
||||
DatasetMetadata,
|
||||
DatasetQuery,
|
||||
DocumentSegment,
|
||||
RateLimitLog,
|
||||
SegmentAttachmentBinding,
|
||||
)
|
||||
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.dataset import Document as DocumentModel
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
@@ -94,8 +73,6 @@ default_retrieval_model: dict[str, Any] = {
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetRetrieval:
|
||||
def __init__(self, application_generate_entity=None):
|
||||
@@ -114,233 +91,6 @@ class DatasetRetrieval:
|
||||
else:
|
||||
self._llm_usage = self._llm_usage.plus(usage)
|
||||
|
||||
def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
|
||||
self._check_knowledge_rate_limit(request.tenant_id)
|
||||
available_datasets = self._get_available_datasets(request.tenant_id, request.dataset_ids)
|
||||
available_datasets_ids = [i.id for i in available_datasets]
|
||||
if not available_datasets_ids:
|
||||
return []
|
||||
|
||||
if not request.query:
|
||||
return []
|
||||
|
||||
metadata_filter_document_ids, metadata_condition = None, None
|
||||
|
||||
if request.metadata_filtering_mode != "disabled":
|
||||
# Convert workflow layer types to app_config layer types
|
||||
if not request.metadata_model_config:
|
||||
raise ValueError("metadata_model_config is required for this method")
|
||||
|
||||
app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump())
|
||||
|
||||
app_metadata_filtering_conditions = None
|
||||
if request.metadata_filtering_conditions is not None:
|
||||
app_metadata_filtering_conditions = MetadataFilteringCondition.model_validate(
|
||||
request.metadata_filtering_conditions.model_dump()
|
||||
)
|
||||
|
||||
query = request.query if request.query is not None else ""
|
||||
|
||||
metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
|
||||
dataset_ids=available_datasets_ids,
|
||||
query=query,
|
||||
tenant_id=request.tenant_id,
|
||||
user_id=request.user_id,
|
||||
metadata_filtering_mode=request.metadata_filtering_mode,
|
||||
metadata_model_config=app_metadata_model_config,
|
||||
metadata_filtering_conditions=app_metadata_filtering_conditions,
|
||||
inputs={},
|
||||
)
|
||||
|
||||
if request.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
# Ensure required fields are not None for single retrieval mode
|
||||
if request.model_provider is None or request.model_name is None or request.query is None:
|
||||
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=request.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=request.model_provider,
|
||||
model=request.model_name,
|
||||
)
|
||||
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_credentials = model_instance.credentials
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=request.model_name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise exc.ModelCredentialsNotInitializedError(
|
||||
f"Model {request.model_name} credentials is not initialized."
|
||||
)
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise exc.ModelNotSupportedError(f"Dify Hosted OpenAI {request.model_name} currently not support.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise exc.ModelQuotaExceededError(f"Model provider {request.model_provider} quota exceeded.")
|
||||
|
||||
stop = []
|
||||
completion_params = (request.completion_params or {}).copy()
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(request.model_name, model_credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")
|
||||
|
||||
model_config = ModelConfigWithCredentialsEntity(
|
||||
provider=request.model_provider,
|
||||
model=request.model_name,
|
||||
model_schema=model_schema,
|
||||
mode=request.model_mode or "chat",
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=model_credentials,
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
all_documents = self.single_retrieve(
|
||||
request.app_id,
|
||||
request.tenant_id,
|
||||
request.user_id,
|
||||
request.user_from,
|
||||
request.query,
|
||||
available_datasets,
|
||||
model_instance,
|
||||
model_config,
|
||||
planning_strategy,
|
||||
None, # message_id
|
||||
metadata_filter_document_ids,
|
||||
metadata_condition,
|
||||
)
|
||||
else:
|
||||
all_documents = self.multiple_retrieve(
|
||||
app_id=request.app_id,
|
||||
tenant_id=request.tenant_id,
|
||||
user_id=request.user_id,
|
||||
user_from=request.user_from,
|
||||
available_datasets=available_datasets,
|
||||
query=request.query,
|
||||
top_k=request.top_k,
|
||||
score_threshold=request.score_threshold,
|
||||
reranking_mode=request.reranking_mode,
|
||||
reranking_model=request.reranking_model,
|
||||
weights=request.weights,
|
||||
reranking_enable=request.reranking_enable,
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
attachment_ids=request.attachment_ids,
|
||||
)
|
||||
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
retrieval_resource_list = []
|
||||
# deal with external documents
|
||||
for item in external_documents:
|
||||
source = Source(
|
||||
metadata=SourceMetadata(
|
||||
source="knowledge",
|
||||
dataset_id=item.metadata.get("dataset_id"),
|
||||
dataset_name=item.metadata.get("dataset_name"),
|
||||
document_id=item.metadata.get("document_id"),
|
||||
document_name=item.metadata.get("title"),
|
||||
data_source_type="external",
|
||||
retriever_from="workflow",
|
||||
score=item.metadata.get("score"),
|
||||
doc_metadata=item.metadata,
|
||||
),
|
||||
title=item.metadata.get("title"),
|
||||
content=item.page_content,
|
||||
)
|
||||
retrieval_resource_list.append(source)
|
||||
# deal with dify documents
|
||||
if dify_documents:
|
||||
records = RetrievalService.format_retrieval_documents(dify_documents)
|
||||
dataset_ids = [i.segment.dataset_id for i in records]
|
||||
document_ids = [i.segment.document_id for i in records]
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
|
||||
documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()
|
||||
|
||||
dataset_map = {i.id: i for i in datasets}
|
||||
document_map = {i.id: i for i in documents}
|
||||
|
||||
if records:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = dataset_map.get(segment.dataset_id)
|
||||
document = document_map.get(segment.document_id)
|
||||
|
||||
if dataset and document:
|
||||
source = Source(
|
||||
metadata=SourceMetadata(
|
||||
source="knowledge",
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id,
|
||||
document_name=document.name,
|
||||
data_source_type=document.data_source_type,
|
||||
segment_id=segment.id,
|
||||
retriever_from="workflow",
|
||||
score=record.score or 0.0,
|
||||
segment_hit_count=segment.hit_count,
|
||||
segment_word_count=segment.word_count,
|
||||
segment_position=segment.position,
|
||||
segment_index_node_hash=segment.index_node_hash,
|
||||
doc_metadata=document.doc_metadata,
|
||||
child_chunks=[
|
||||
SourceChildChunk(
|
||||
id=str(getattr(chunk, "id", "")),
|
||||
content=str(getattr(chunk, "content", "")),
|
||||
position=int(getattr(chunk, "position", 0)),
|
||||
score=float(getattr(chunk, "score", 0.0)),
|
||||
)
|
||||
for chunk in (record.child_chunks or [])
|
||||
],
|
||||
position=None,
|
||||
),
|
||||
title=document.name,
|
||||
files=list(record.files) if record.files else None,
|
||||
content=segment.get_sign_content(),
|
||||
)
|
||||
if segment.answer:
|
||||
source.content = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
|
||||
|
||||
if record.summary:
|
||||
source.summary = record.summary
|
||||
|
||||
retrieval_resource_list.append(source)
|
||||
|
||||
if retrieval_resource_list:
|
||||
|
||||
def _score(item: Source) -> float:
|
||||
meta = item.metadata
|
||||
score = meta.score
|
||||
if isinstance(score, (int, float)):
|
||||
return float(score)
|
||||
return 0.0
|
||||
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list,
|
||||
key=_score, # type: ignore[arg-type, return-value]
|
||||
reverse=True,
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||
item.metadata.position = position # type: ignore[index]
|
||||
return retrieval_resource_list
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
app_id: str,
|
||||
@@ -400,7 +150,14 @@ class DatasetRetrieval:
|
||||
if features:
|
||||
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.ROUTER
|
||||
available_datasets = self._get_available_datasets(tenant_id, dataset_ids)
|
||||
available_datasets = []
|
||||
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
|
||||
datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
|
||||
for dataset in datasets:
|
||||
if dataset.available_document_count == 0 and dataset.provider != "external":
|
||||
continue
|
||||
available_datasets.append(dataset)
|
||||
|
||||
if inputs:
|
||||
inputs = {key: str(value) for key, value in inputs.items()}
|
||||
@@ -1404,6 +1161,7 @@ class DatasetRetrieval:
|
||||
query=query or "",
|
||||
)
|
||||
|
||||
result_text = ""
|
||||
try:
|
||||
# handle invoke result
|
||||
invoke_result = cast(
|
||||
@@ -1434,8 +1192,7 @@ class DatasetRetrieval:
|
||||
"condition": item.get("comparison_operator"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(e, exc_info=True)
|
||||
except Exception:
|
||||
return None
|
||||
return automatic_metadata_filters
|
||||
|
||||
@@ -1649,12 +1406,7 @@ class DatasetRetrieval:
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
text = result.delta.message.content
|
||||
if isinstance(text, str):
|
||||
full_text += text
|
||||
elif isinstance(text, list):
|
||||
for i in text:
|
||||
if i.data:
|
||||
full_text += i.data
|
||||
full_text += text
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
@@ -1772,53 +1524,3 @@ class DatasetRetrieval:
|
||||
cancel_event.set()
|
||||
if thread_exceptions is not None:
|
||||
thread_exceptions.append(e)
|
||||
|
||||
def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
|
||||
with session_factory.create_session() as session:
|
||||
subquery = (
|
||||
session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
|
||||
.where(
|
||||
DocumentModel.indexing_status == "completed",
|
||||
DocumentModel.enabled == True,
|
||||
DocumentModel.archived == False,
|
||||
DocumentModel.dataset_id.in_(dataset_ids),
|
||||
)
|
||||
.group_by(DocumentModel.dataset_id)
|
||||
.having(func.count(DocumentModel.id) > 0)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
results = (
|
||||
session.query(Dataset)
|
||||
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
|
||||
.where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
|
||||
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
|
||||
.all()
|
||||
)
|
||||
|
||||
available_datasets = []
|
||||
for dataset in results:
|
||||
if not dataset:
|
||||
continue
|
||||
available_datasets.append(dataset)
|
||||
return available_datasets
|
||||
|
||||
def _check_knowledge_rate_limit(self, tenant_id: str):
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{tenant_id}"
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
redis_client.zremrangebyscore(key, 0, current_time - 60000)
|
||||
request_count = redis_client.zcard(key)
|
||||
if request_count > knowledge_rate_limit.limit:
|
||||
with session_factory.create_session() as session:
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
session.add(rate_limit_log)
|
||||
raise exc.RateLimitExceededError(
|
||||
"you have reached the knowledge base request rate limit of your subscription."
|
||||
)
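`_check_knowledge_rate_limit` above is a one-minute sliding window on a Redis sorted set: every request is added with its millisecond timestamp as the score, entries older than the window are trimmed, and the remaining count is compared against the plan limit. Condensed into just the window logic, with the key name and limit as placeholders:

```python
# Condensed sliding-window check mirroring the code above; key name and limit are placeholders.
import time


def within_rate_limit(redis_client, tenant_id: str, limit: int, window_ms: int = 60_000) -> bool:
    now_ms = int(time.time() * 1000)
    key = f"rate_limit_{tenant_id}"
    # Record this request, drop entries outside the window, then count what remains.
    redis_client.zadd(key, {now_ms: now_ms})
    redis_client.zremrangebyscore(key, 0, now_ms - window_ms)
    return redis_client.zcard(key) <= limit
```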
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, cast
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
@@ -17,7 +17,6 @@ from core.mcp.types import (
|
||||
TextContent,
|
||||
TextResourceContents,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
@@ -47,7 +46,6 @@ class MCPTool(Tool):
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.MCP
|
||||
@@ -61,10 +59,6 @@ class MCPTool(Tool):
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
result = self.invoke_remote_mcp_tool(tool_parameters)
|
||||
|
||||
# Extract usage metadata from MCP protocol's _meta field
|
||||
self._latest_usage = self._derive_usage_from_result(result)
|
||||
|
||||
# handle dify tool output
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
@@ -126,99 +120,6 @@ class MCPTool(Tool):
|
||||
for item in json_list:
|
||||
yield self.create_json_message(item)
|
||||
|
||||
@property
|
||||
def latest_usage(self) -> LLMUsage:
|
||||
return self._latest_usage
|
||||
|
||||
@classmethod
|
||||
def _derive_usage_from_result(cls, result: CallToolResult) -> LLMUsage:
|
||||
"""
|
||||
Extract usage metadata from MCP tool result's _meta field.
|
||||
|
||||
The MCP protocol's _meta field (aliased as 'meta' in Python) can contain
|
||||
usage information such as token counts, costs, and other metadata.
|
||||
|
||||
Args:
|
||||
result: The CallToolResult from MCP tool invocation
|
||||
|
||||
Returns:
|
||||
LLMUsage instance with values from meta or empty_usage if not found
|
||||
"""
|
||||
# Extract usage from the meta field if present
|
||||
if result.meta:
|
||||
usage_dict = cls._extract_usage_dict(result.meta)
|
||||
if usage_dict is not None:
|
||||
return LLMUsage.from_metadata(cast(LLMUsageMetadata, cast(object, dict(usage_dict))))
|
||||
|
||||
return LLMUsage.empty_usage()
|
||||
|
||||
@classmethod
|
||||
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
|
||||
"""
|
||||
Recursively search for usage dictionary in the payload.
|
||||
|
||||
The MCP protocol's _meta field can contain usage data in various formats:
|
||||
- Direct usage field: {"usage": {...}}
|
||||
- Nested in metadata: {"metadata": {"usage": {...}}}
|
||||
- Or nested within other fields
|
||||
|
||||
Args:
|
||||
payload: The payload to search for usage data
|
||||
|
||||
Returns:
|
||||
The usage dictionary if found, None otherwise
|
||||
"""
|
||||
# Check for direct usage field
|
||||
usage_candidate = payload.get("usage")
|
||||
if isinstance(usage_candidate, Mapping):
|
||||
return usage_candidate
|
||||
|
||||
# Check for metadata nested usage
|
||||
metadata_candidate = payload.get("metadata")
|
||||
if isinstance(metadata_candidate, Mapping):
|
||||
usage_candidate = metadata_candidate.get("usage")
|
||||
if isinstance(usage_candidate, Mapping):
|
||||
return usage_candidate
|
||||
|
||||
# Check for common token counting fields directly in payload
|
||||
# Some MCP servers may include token counts directly
|
||||
if "total_tokens" in payload or "prompt_tokens" in payload or "completion_tokens" in payload:
|
||||
usage_dict: dict[str, Any] = {}
|
||||
for key in (
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
"prompt_unit_price",
|
||||
"completion_unit_price",
|
||||
"total_price",
|
||||
"currency",
|
||||
"prompt_price_unit",
|
||||
"completion_price_unit",
|
||||
"prompt_price",
|
||||
"completion_price",
|
||||
"latency",
|
||||
"time_to_first_token",
|
||||
"time_to_generate",
|
||||
):
|
||||
if key in payload:
|
||||
usage_dict[key] = payload[key]
|
||||
if usage_dict:
|
||||
return usage_dict
|
||||
|
||||
# Recursively search through nested structures
|
||||
for value in payload.values():
|
||||
if isinstance(value, Mapping):
|
||||
found = cls._extract_usage_dict(value)
|
||||
if found is not None:
|
||||
return found
|
||||
elif isinstance(value, list) and not isinstance(value, (str, bytes, bytearray)):
|
||||
for item in value:
|
||||
if isinstance(item, Mapping):
|
||||
found = cls._extract_usage_dict(item)
|
||||
if found is not None:
|
||||
return found
|
||||
return None
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
|
||||
return MCPTool(
|
||||
entity=self.entity,
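The removed `_derive_usage_from_result` / `_extract_usage_dict` helpers walk the MCP result's `_meta` payload looking for usage data in several shapes. The payloads below (values invented) illustrate the shapes that search accepts:

```python
# Invented payloads showing the shapes the recursive usage search accepts.
direct = {"usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}}
nested = {"metadata": {"usage": {"total_tokens": 46}}}
flat = {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}
deep = {"results": [{"metadata": {"usage": {"total_tokens": 46}}}]}

# For the first, second, and fourth shapes the helper returns the inner usage mapping;
# for the flat shape it assembles a dict from the recognized token/price fields.
# _derive_usage_from_result then converts the mapping into an LLMUsage instance.
```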
@@ -10,7 +10,6 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.runtime.graph_runtime_state import GraphExecutionProtocol
|
||||
|
||||
from .node_execution import NodeExecution
|
||||
|
||||
@@ -237,6 +236,3 @@ class GraphExecution:
|
||||
def record_node_failure(self) -> None:
|
||||
"""Increment the count of node failures encountered during execution."""
|
||||
self.exceptions_count += 1
|
||||
|
||||
|
||||
_: GraphExecutionProtocol = GraphExecution(workflow_id="")
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.db.session_factory import session_factory
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
@@ -49,6 +50,12 @@ from factories import file_factory
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
from models import ToolFile
|
||||
from models.model import Conversation
|
||||
from models.tools import (
|
||||
ApiToolProvider,
|
||||
BuiltinToolProvider,
|
||||
MCPToolProvider,
|
||||
WorkflowToolProvider,
|
||||
)
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .exc import (
|
||||
@@ -259,7 +266,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
provider_type = self._infer_tool_provider_type(tool, self.tenant_id)
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
@@ -748,3 +755,34 @@ class AgentNode(Node[AgentNodeData]):
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType:
|
||||
provider_type_str = tool_config.get("type")
|
||||
if provider_type_str:
|
||||
return ToolProviderType(provider_type_str)
|
||||
|
||||
provider_id = tool_config.get("provider_name")
|
||||
if not provider_id:
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
provider_map: dict[
|
||||
type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]],
|
||||
ToolProviderType,
|
||||
] = {
|
||||
WorkflowToolProvider: ToolProviderType.WORKFLOW,
|
||||
MCPToolProvider: ToolProviderType.MCP,
|
||||
ApiToolProvider: ToolProviderType.API,
|
||||
BuiltinToolProvider: ToolProviderType.BUILT_IN,
|
||||
}
|
||||
|
||||
for provider_model, provider_type in provider_map.items():
|
||||
stmt = select(provider_model).where(
|
||||
provider_model.id == provider_id,
|
||||
provider_model.tenant_id == tenant_id,
|
||||
)
|
||||
if session.scalar(stmt):
|
||||
return provider_type
|
||||
|
||||
raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.")
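`_infer_tool_provider_type` keeps older tool configs working: when the config carries an explicit `type` it is used directly, otherwise the provider ID is looked up across the workflow, MCP, API, and builtin provider tables for the tenant. The two invented configs below illustrate the inputs it distinguishes:

```python
# Invented tool configs showing the two shapes _infer_tool_provider_type handles.
explicit = {
    "type": "api",  # explicit type: converted to ToolProviderType directly, no DB lookup
    "provider_name": "00000000-0000-0000-0000-000000000000",
    "tool_name": "get_weather",
}
legacy = {
    # No "type" field: the provider ID is looked up across WorkflowToolProvider,
    # MCPToolProvider, ApiToolProvider and BuiltinToolProvider for the tenant;
    # if none match, AgentNodeError is raised.
    "provider_name": "00000000-0000-0000-0000-000000000000",
    "tool_name": "get_weather",
}
```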
@@ -20,7 +20,3 @@ class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
|
||||
|
||||
class InvalidModelTypeError(KnowledgeRetrievalNodeError):
|
||||
"""Raised when the model is not a Large Language Model."""
|
||||
|
||||
|
||||
class RateLimitExceededError(KnowledgeRetrievalNodeError):
|
||||
"""Raised when the rate limit is exceeded."""
|
||||
|
||||
@@ -1,10 +1,29 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.variables import (
|
||||
ArrayFileSegment,
|
||||
FileSegment,
|
||||
@@ -17,16 +36,35 @@ from core.workflow.enums import (
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_1,
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_2,
|
||||
METADATA_FILTER_COMPLETION_PROMPT,
|
||||
METADATA_FILTER_SYSTEM_PROMPT,
|
||||
METADATA_FILTER_USER_PROMPT_1,
|
||||
METADATA_FILTER_USER_PROMPT_2,
|
||||
METADATA_FILTER_USER_PROMPT_3,
|
||||
)
|
||||
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
|
||||
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .entities import KnowledgeRetrievalNodeData
|
||||
from .exc import (
|
||||
InvalidModelTypeError,
|
||||
KnowledgeRetrievalNodeError,
|
||||
RateLimitExceededError,
|
||||
ModelCredentialsNotInitializedError,
|
||||
ModelNotExistError,
|
||||
ModelNotSupportedError,
|
||||
ModelQuotaExceededError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -35,6 +73,14 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
|
||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
@@ -51,7 +97,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
rag_retrieval: RAGRetrievalProtocol,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@@ -63,7 +108,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs = []
|
||||
self._rag_retrieval = rag_retrieval
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
@@ -77,7 +121,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
usage = LLMUsage.empty_usage()
|
||||
if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@@ -85,7 +128,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
process_data={},
|
||||
outputs={},
|
||||
metadata={},
|
||||
llm_usage=usage,
|
||||
llm_usage=LLMUsage.empty_usage(),
|
||||
)
|
||||
variables: dict[str, Any] = {}
|
||||
# extract variables
|
||||
@@ -113,9 +156,36 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
else:
|
||||
variables["attachments"] = [variable.value]
|
||||
|
||||
# TODO(-LAN-): Move this check outside.
|
||||
# check rate limit
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{self.tenant_id}"
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
redis_client.zremrangebyscore(key, 0, current_time - 60000)
|
||||
request_count = redis_client.zcard(key)
|
||||
if request_count > knowledge_rate_limit.limit:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# add ratelimit record
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=self.tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
session.add(rate_limit_log)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
|
||||
error_type="RateLimitExceeded",
|
||||
)
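The check above implements a one-minute sliding window with a Redis sorted set: each request is added with its millisecond timestamp as the score, entries older than 60 seconds are trimmed, and the remaining cardinality is compared against the plan limit. A standalone sketch of the same pattern (connection, key name and limit are illustrative, not taken from this changeset):

import time

import redis

client = redis.Redis()  # assumed local instance, for illustration only


def within_sliding_window(tenant_id: str, limit: int) -> bool:
    """Record one request and report whether the last 60 seconds stayed under `limit`."""
    now_ms = int(time.time() * 1000)
    key = f"rate_limit_{tenant_id}"
    client.zadd(key, {now_ms: now_ms})  # record this request
    client.zremrangebyscore(key, 0, now_ms - 60_000)  # drop entries older than 60 seconds
    return client.zcard(key) <= limit  # what remains is the request count inside the window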
|
||||
|
||||
# retrieve knowledge
|
||||
usage = LLMUsage.empty_usage()
|
||||
try:
|
||||
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
|
||||
outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
|
||||
outputs = {"result": ArrayObjectSegment(value=results)}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
@@ -128,17 +198,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
except RateLimitExceededError as e:
|
||||
logger.warning(e, exc_info=True)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
except KnowledgeRetrievalNodeError as e:
|
||||
logger.warning("Error when running knowledge retrieval node", exc_info=True)
|
||||
logger.warning("Error when running knowledge retrieval node")
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
@@ -148,7 +210,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
)
|
||||
# Temporarily handle all exceptions from the DatasetRetrieval class here.
|
||||
except Exception as e:
|
||||
logger.warning(e, exc_info=True)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
@@ -156,47 +217,92 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
error_type=type(e).__name__,
|
||||
llm_usage=usage,
|
||||
)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _fetch_dataset_retriever(
|
||||
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
|
||||
) -> tuple[list[Source], LLMUsage]:
|
||||
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
||||
usage = LLMUsage.empty_usage()
|
||||
available_datasets = []
|
||||
dataset_ids = node_data.dataset_ids
|
||||
query = variables.get("query")
|
||||
attachments = variables.get("attachments")
|
||||
retrieval_resource_list = []
|
||||
metadata_filter_document_ids = None
|
||||
metadata_condition = None
|
||||
metadata_usage = LLMUsage.empty_usage()
|
||||
# Subquery: Count the number of available documents for each dataset
|
||||
subquery = (
|
||||
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
|
||||
.where(
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
Document.dataset_id.in_(dataset_ids),
|
||||
)
|
||||
.group_by(Document.dataset_id)
|
||||
.having(func.count(Document.id) > 0)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = "disabled"
|
||||
if node_data.metadata_filtering_mode is not None:
|
||||
metadata_filtering_mode = node_data.metadata_filtering_mode
|
||||
results = (
|
||||
db.session.query(Dataset)
|
||||
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
|
||||
.where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
|
||||
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
|
||||
.all()
|
||||
)
|
||||
|
||||
# avoid blocking at retrieval
|
||||
db.session.close()
|
||||
|
||||
for dataset in results:
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
continue
|
||||
available_datasets.append(dataset)
|
||||
if query:
|
||||
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
|
||||
[dataset.id for dataset in available_datasets], query, node_data
|
||||
)
|
||||
usage = self._merge_usage(usage, metadata_usage)
|
||||
all_documents = []
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
|
||||
# fetch model config
|
||||
if node_data.single_retrieval_config is None:
|
||||
raise ValueError("single_retrieval_config is required for single retrieval mode")
|
||||
model = node_data.single_retrieval_config.model
|
||||
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
|
||||
request=KnowledgeRetrievalRequest(
|
||||
raise ValueError("single_retrieval_config is required")
|
||||
model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
|
||||
# check model is support tool calling
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
# get model schema
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=model_config.model, credentials=model_config.credentials
|
||||
)
|
||||
|
||||
if model_schema:
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
features = model_schema.features
|
||||
if features:
|
||||
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.ROUTER
|
||||
all_documents = dataset_retrieval.single_retrieve(
|
||||
available_datasets=available_datasets,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
user_from=self.user_from.value,
|
||||
dataset_ids=dataset_ids,
|
||||
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
|
||||
completion_params=model.completion_params,
|
||||
model_provider=model.provider,
|
||||
model_mode=model.mode,
|
||||
model_name=model.name,
|
||||
metadata_model_config=node_data.metadata_model_config,
|
||||
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
|
||||
metadata_filtering_mode=metadata_filtering_mode,
|
||||
query=query,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
planning_strategy=planning_strategy,
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
)
|
||||
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
reranking_model = None
|
||||
weights = None
|
||||
match node_data.multiple_retrieval_config.reranking_mode:
|
||||
case "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
@@ -223,36 +329,284 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
},
|
||||
}
|
||||
case _:
|
||||
# Handle any other reranking_mode values
|
||||
reranking_model = None
|
||||
weights = None
|
||||
all_documents = dataset_retrieval.multiple_retrieve(
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
available_datasets=available_datasets,
|
||||
query=query,
|
||||
top_k=node_data.multiple_retrieval_config.top_k,
|
||||
score_threshold=node_data.multiple_retrieval_config.score_threshold
|
||||
if node_data.multiple_retrieval_config.score_threshold is not None
|
||||
else 0.0,
|
||||
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
|
||||
reranking_model=reranking_model,
|
||||
weights=weights,
|
||||
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
|
||||
)
|
||||
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
|
||||
|
||||
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
|
||||
request=KnowledgeRetrievalRequest(
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
dataset_ids=dataset_ids,
|
||||
query=query,
|
||||
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,
|
||||
top_k=node_data.multiple_retrieval_config.top_k,
|
||||
score_threshold=node_data.multiple_retrieval_config.score_threshold
|
||||
if node_data.multiple_retrieval_config.score_threshold is not None
|
||||
else 0.0,
|
||||
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
|
||||
reranking_model=reranking_model,
|
||||
weights=weights,
|
||||
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
||||
metadata_model_config=node_data.metadata_model_config,
|
||||
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
|
||||
metadata_filtering_mode=metadata_filtering_mode,
|
||||
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
retrieval_resource_list = []
|
||||
# deal with external documents
|
||||
for item in external_documents:
|
||||
source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = {
|
||||
"metadata": {
|
||||
"_source": "knowledge",
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": "workflow",
|
||||
"score": item.metadata.get("score"),
|
||||
"doc_metadata": item.metadata,
|
||||
},
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
retrieval_resource_list.append(source)
|
||||
# deal with dify documents
|
||||
if dify_documents:
|
||||
records = RetrievalService.format_retrieval_documents(dify_documents)
|
||||
if records:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
|
||||
stmt = select(Document).where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
document = db.session.scalar(stmt)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"metadata": {
|
||||
"_source": "knowledge",
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": "workflow",
|
||||
"score": record.score or 0.0,
|
||||
"child_chunks": [
|
||||
{
|
||||
"id": str(getattr(chunk, "id", "")),
|
||||
"content": str(getattr(chunk, "content", "")),
|
||||
"position": int(getattr(chunk, "position", 0)),
|
||||
"score": float(getattr(chunk, "score", 0.0)),
|
||||
}
|
||||
for chunk in (record.child_chunks or [])
|
||||
],
|
||||
"segment_hit_count": segment.hit_count,
|
||||
"segment_word_count": segment.word_count,
|
||||
"segment_position": segment.position,
|
||||
"segment_index_node_hash": segment.index_node_hash,
|
||||
"doc_metadata": document.doc_metadata,
|
||||
},
|
||||
"title": document.name,
|
||||
"files": list(record.files) if record.files else None,
|
||||
}
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.get_sign_content()
|
||||
# Add summary if available
|
||||
if record.summary:
|
||||
source["summary"] = record.summary
|
||||
retrieval_resource_list.append(source)
|
||||
if retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list,
|
||||
key=self._score, # type: ignore[arg-type, return-value]
|
||||
reverse=True,
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||
item["metadata"]["position"] = position # type: ignore[index]
|
||||
return retrieval_resource_list, usage
|
||||
|
||||
def _score(self, item: dict[str, Any]) -> float:
|
||||
meta = item.get("metadata")
|
||||
if isinstance(meta, dict):
|
||||
s = meta.get("score")
|
||||
if isinstance(s, (int, float)):
|
||||
return float(s)
|
||||
return 0.0
|
||||
|
||||
def _get_metadata_filter_condition(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
|
||||
usage = LLMUsage.empty_usage()
|
||||
document_query = db.session.query(Document).where(
|
||||
Document.dataset_id.in_(dataset_ids),
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
filters: list[Any] = []
|
||||
metadata_condition = None
|
||||
match node_data.metadata_filtering_mode:
|
||||
case "disabled":
|
||||
return None, None, usage
|
||||
case "automatic":
|
||||
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
|
||||
dataset_ids, query, node_data
|
||||
)
|
||||
usage = self._merge_usage(usage, automatic_usage)
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
filter.get("value"),
|
||||
filters,
|
||||
)
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=filter.get("metadata_name"), # type: ignore
|
||||
comparison_operator=filter.get("condition"), # type: ignore
|
||||
value=filter.get("value"),
|
||||
)
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
||||
if node_data.metadata_filtering_conditions
|
||||
else "or",
|
||||
conditions=conditions,
|
||||
)
|
||||
case "manual":
|
||||
if node_data.metadata_filtering_conditions:
|
||||
conditions = []
|
||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
).value[0]
|
||||
if expected_value.value_type in {"number", "integer", "float"}:
|
||||
expected_value = expected_value.value
|
||||
elif expected_value.value_type == "string":
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
case _:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
if filters:
|
||||
if (
|
||||
node_data.metadata_filtering_conditions
|
||||
and node_data.metadata_filtering_conditions.logical_operator == "and"
|
||||
):
|
||||
document_query = document_query.where(and_(*filters))
|
||||
else:
|
||||
document_query = document_query.where(or_(*filters))
|
||||
documents = document_query.all()
|
||||
# group by dataset_id
|
||||
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
|
||||
for document in documents:
|
||||
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
|
||||
return metadata_filter_document_ids, metadata_condition, usage
|
||||
|
||||
def _automatic_metadata_filter_func(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
||||
usage = LLMUsage.empty_usage()
|
||||
# get all metadata field
|
||||
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
||||
metadata_fields = db.session.scalars(stmt).all()
|
||||
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
|
||||
if node_data.metadata_model_config is None:
|
||||
raise ValueError("metadata_model_config is required")
|
||||
# get metadata model instance and fetch model config
|
||||
model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
|
||||
# fetch prompt messages
|
||||
prompt_template = self._get_prompt_template(
|
||||
node_data=node_data,
|
||||
metadata_fields=all_metadata_fields,
|
||||
query=query or "",
|
||||
)
|
||||
prompt_messages, stop = LLMNode.fetch_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
sys_query=query,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
sys_files=[],
|
||||
vision_enabled=node_data.vision.enabled,
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
jinja2_variables=[],
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
|
||||
result_text = ""
|
||||
try:
|
||||
# handle invoke result
|
||||
generator = LLMNode.invoke_llm(
|
||||
node_data_model=node_data.metadata_model_config,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=self.user_id,
|
||||
structured_output_enabled=self.node_data.structured_output_enabled,
|
||||
structured_output=None,
|
||||
file_saver=self._llm_file_saver,
|
||||
file_outputs=self._file_outputs,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
)
|
||||
|
||||
usage = self._rag_retrieval.llm_usage
|
||||
return retrieval_resource_list, usage
|
||||
for event in generator:
|
||||
if isinstance(event, ModelInvokeCompletedEvent):
|
||||
result_text = event.text
|
||||
usage = self._merge_usage(usage, event.usage)
|
||||
break
|
||||
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
automatic_metadata_filters = []
|
||||
if "metadata_map" in result_text_json:
|
||||
metadata_map = result_text_json["metadata_map"]
|
||||
for item in metadata_map:
|
||||
if item.get("metadata_field_name") in all_metadata_fields:
|
||||
automatic_metadata_filters.append(
|
||||
{
|
||||
"metadata_name": item.get("metadata_field_name"),
|
||||
"value": item.get("metadata_field_value"),
|
||||
"condition": item.get("comparison_operator"),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
return [], usage
|
||||
return automatic_metadata_filters, usage
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
@@ -272,3 +626,107 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
if typed_node_data.query_attachment_selector:
|
||||
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
|
||||
return variable_mapping
|
||||
|
||||
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
model_name = model.name
|
||||
provider_name = model.provider
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
|
||||
)
|
||||
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_credentials = model_instance.credentials
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.")
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||
|
||||
# model config
|
||||
completion_params = model.completion_params
|
||||
stop = []
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = model.mode
|
||||
if not model_mode:
|
||||
raise ModelNotExistError("LLM mode is required.")
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
|
||||
return model_instance, ModelConfigWithCredentialsEntity(
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
model_schema=model_schema,
|
||||
mode=model_mode,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=model_credentials,
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
|
||||
model_mode = ModelMode(node_data.metadata_model_config.mode) # type: ignore
|
||||
input_text = query
|
||||
|
||||
prompt_messages: list[LLMNodeChatModelMessage] = []
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
|
||||
)
|
||||
prompt_messages.append(system_prompt_messages)
|
||||
user_prompt_message_1 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_1)
|
||||
assistant_prompt_message_1 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_1)
|
||||
user_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_2)
|
||||
assistant_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_2)
|
||||
user_prompt_message_3 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=METADATA_FILTER_USER_PROMPT_3.format(
|
||||
input_text=input_text,
|
||||
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_3)
|
||||
return prompt_messages
|
||||
elif model_mode == ModelMode.COMPLETION:
|
||||
return LLMNodeCompletionModelPromptTemplate(
|
||||
text=METADATA_FILTER_COMPLETION_PROMPT.format(
|
||||
input_text=input_text,
|
||||
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
raise InvalidModelTypeError(f"Model mode {model_mode} is not supported.")
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities import LLMUsage
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import MetadataFilteringCondition
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
|
||||
|
||||
class SourceChildChunk(BaseModel):
|
||||
id: str = Field(default="", description="Child chunk ID")
|
||||
content: str = Field(default="", description="Child chunk content")
|
||||
position: int = Field(default=0, description="Child chunk position")
|
||||
score: float = Field(default=0.0, description="Child chunk relevance score")
|
||||
|
||||
|
||||
class SourceMetadata(BaseModel):
|
||||
source: str = Field(
|
||||
default="knowledge",
|
||||
serialization_alias="_source",
|
||||
description="Data source identifier",
|
||||
)
|
||||
dataset_id: str = Field(description="Dataset unique identifier")
|
||||
dataset_name: str = Field(description="Dataset display name")
|
||||
document_id: str = Field(description="Document unique identifier")
|
||||
document_name: str = Field(description="Document display name")
|
||||
data_source_type: str = Field(description="Type of data source")
|
||||
segment_id: str | None = Field(default=None, description="Segment unique identifier")
|
||||
retriever_from: str = Field(default="workflow", description="Retriever source context")
|
||||
score: float = Field(default=0.0, description="Retrieval relevance score")
|
||||
child_chunks: list[SourceChildChunk] = Field(default=[], description="List of child chunks")
|
||||
segment_hit_count: int | None = Field(default=0, description="Number of times segment was retrieved")
|
||||
segment_word_count: int | None = Field(default=0, description="Word count of the segment")
|
||||
segment_position: int | None = Field(default=0, description="Position of segment in document")
|
||||
segment_index_node_hash: str | None = Field(default=None, description="Hash of index node for the segment")
|
||||
doc_metadata: dict[str, Any] | None = Field(default=None, description="Additional document metadata")
|
||||
position: int | None = Field(default=0, description="Position of the document in the dataset")
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
|
||||
|
||||
class Source(BaseModel):
|
||||
metadata: SourceMetadata = Field(description="Source metadata information")
|
||||
title: str = Field(description="Document title")
|
||||
files: list[Any] | None = Field(default=None, description="Associated file references")
|
||||
content: str | None = Field(description="Segment content text")
|
||||
summary: str | None = Field(default=None, description="Content summary if available")
|
||||
|
||||
|
||||
class KnowledgeRetrievalRequest(BaseModel):
|
||||
tenant_id: str = Field(description="Tenant unique identifier")
|
||||
user_id: str = Field(description="User unique identifier")
|
||||
app_id: str = Field(description="Application unique identifier")
|
||||
user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')")
|
||||
dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from")
|
||||
query: str | None = Field(default=None, description="Query text for knowledge retrieval")
|
||||
retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'")
|
||||
model_provider: str | None = Field(default=None, description="Model provider name (e.g., 'openai', 'anthropic')")
|
||||
completion_params: dict[str, Any] | None = Field(
|
||||
default=None, description="Model completion parameters (e.g., temperature, max_tokens)"
|
||||
)
|
||||
model_mode: str | None = Field(default=None, description="Model mode (e.g., 'chat', 'completion')")
|
||||
model_name: str | None = Field(default=None, description="Model name (e.g., 'gpt-4', 'claude-3-opus')")
|
||||
metadata_model_config: ModelConfig | None = Field(
|
||||
default=None, description="Model config for metadata-based filtering"
|
||||
)
|
||||
metadata_filtering_conditions: MetadataFilteringCondition | None = Field(
|
||||
default=None, description="Conditions for filtering by metadata"
|
||||
)
|
||||
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = Field(
|
||||
default="disabled", description="Metadata filtering mode: 'disabled', 'automatic', or 'manual'"
|
||||
)
|
||||
top_k: int = Field(default=0, description="Number of top results to return")
|
||||
score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
|
||||
reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
|
||||
reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
|
||||
weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
|
||||
reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
|
||||
attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")
|
||||
|
||||
|
||||
class RAGRetrievalProtocol(Protocol):
|
||||
"""Protocol for RAG-based knowledge retrieval implementations.
|
||||
|
||||
Implementations of this protocol handle knowledge retrieval from datasets
|
||||
including rate limiting, dataset filtering, and document retrieval.
|
||||
"""
|
||||
|
||||
@property
|
||||
def llm_usage(self) -> LLMUsage:
|
||||
"""Return accumulated LLM usage for retrieval operations."""
|
||||
...
|
||||
|
||||
def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
|
||||
"""Retrieve knowledge from datasets based on the provided request.
|
||||
|
||||
Args:
|
||||
request: Knowledge retrieval request with search parameters
|
||||
|
||||
Returns:
|
||||
List of sources matching the search criteria
|
||||
|
||||
Raises:
|
||||
RateLimitExceededError: If rate limit is exceeded
|
||||
ModelNotExistError: If specified model doesn't exist
|
||||
"""
|
||||
...
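Because `RAGRetrievalProtocol` is a structural `Protocol`, any object exposing these two members satisfies it. A minimal conforming stub, shown only to illustrate the contract (it is not the implementation the workflow actually wires in; it assumes the request and source models defined above are importable):

from core.model_runtime.entities import LLMUsage


class StubRAGRetrieval:
    """Accepted wherever RAGRetrievalProtocol is expected; retrieves nothing and reports zero usage."""

    @property
    def llm_usage(self) -> LLMUsage:
        return LLMUsage.empty_usage()

    def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
        # A real implementation would apply rate limiting, metadata filtering and retrieval here.
        return []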
|
||||
@@ -64,7 +64,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool
error: Exception | None
exceptions_count: int
pause_reasons: list[PauseReason]
pause_reasons: Sequence[PauseReason]

def start(self) -> None:
"""Transition execution into the running state."""
@@ -446,7 +446,7 @@ class GraphRuntimeState:
graph_execution_cls = module.GraphExecution
workflow_id = self._pending_graph_execution_workflow_id or ""
self._pending_graph_execution_workflow_id = None
return graph_execution_cls(workflow_id=workflow_id) # type: ignore[invalid-return-type]
return graph_execution_cls(workflow_id=workflow_id)

def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
# Lazily import to keep the runtime domain decoupled from graph_engine modules.

@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"

@@ -190,14 +190,6 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh",
"schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL),
}

if dify_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK:
imports.append("schedule.update_api_token_last_used_task")
beat_schedule["batch_update_api_token_last_used"] = {
"task": "schedule.update_api_token_last_used_task.batch_update_api_token_last_used",
"schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL),
}

celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)

return celery_app
@@ -1,9 +1,7 @@
from __future__ import annotations

from datetime import datetime

from flask_restx import fields
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict

simple_end_user_fields = {
"id": fields.String,
@@ -12,19 +10,6 @@ simple_end_user_fields = {
"session_id": fields.String,
}

end_user_detail_fields = {
"id": fields.String,
"tenant_id": fields.String,
"app_id": fields.String,
"type": fields.String,
"external_user_id": fields.String,
"name": fields.String,
"is_anonymous": fields.Boolean,
"session_id": fields.String,
"created_at": fields.DateTime,
"updated_at": fields.DateTime,
}


class ResponseModel(BaseModel):
model_config = ConfigDict(
@@ -41,23 +26,3 @@ class SimpleEndUser(ResponseModel):
type: str
is_anonymous: bool
session_id: str | None = None


class EndUserDetail(ResponseModel):
"""Full EndUser record for API responses.

Note: The SQLAlchemy model defines an `is_anonymous` property for Flask-Login semantics
(always False). The database column is exposed as `_is_anonymous`, so this DTO maps
`is_anonymous` from `_is_anonymous` to return the stored value.
"""

id: str
tenant_id: str
app_id: str | None = None
type: str
external_user_id: str | None = None
name: str | None = None
is_anonymous: bool = Field(validation_alias="_is_anonymous")
session_id: str
created_at: datetime
updated_at: datetime
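Since Pydantic resolves the `validation_alias` when validating from attributes, the DTO reads the stored column value from the ORM object's `_is_anonymous` attribute rather than the Flask-Login property. A hedged usage sketch (assumes `ResponseModel` enables attribute validation and `end_user` is an EndUser ORM instance):

detail = EndUserDetail.model_validate(end_user, from_attributes=True)
assert detail.is_anonymous == end_user._is_anonymous  # stored value, not the always-False property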
@@ -33,7 +33,7 @@ class ShardedTopic:
return self

def publish(self, payload: bytes) -> None:
self._client.spublish(self._topic, payload) # type: ignore[attr-defined,union-attr]
self._client.spublish(self._topic, payload) # type: ignore[attr-defined]

def as_subscriber(self) -> Subscriber:
return self
@@ -75,7 +75,7 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
#
# Here we specify the `target_node` to mitigate this problem.
node = self._client.get_node_from_key(self._topic)
return self._pubsub.get_sharded_message( # type: ignore[attr-defined]
return self._pubsub.get_sharded_message(
ignore_subscribe_messages=False,
timeout=1,
target_node=node,

@@ -20,7 +20,7 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('account_trial_app_records',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('account_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('count', sa.Integer(), nullable=False),
@@ -33,17 +33,17 @@ def upgrade():
batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)

op.create_table('exporle_banners',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('content', sa.JSON(), nullable=False),
sa.Column('link', sa.String(length=255), nullable=False),
sa.Column('sort', sa.Integer(), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'"), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'"), nullable=False),
sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
)
op.create_table('trial_apps',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
@@ -1,34 +0,0 @@
"""drop server_default for app trial related tables

Revision ID: c3df22613c99
Revises: e8c3b3c46151
Create Date: 2026-02-09 09:50:46.181969

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'c3df22613c99'
down_revision = 'e8c3b3c46151'
branch_labels = None
depends_on = None


def upgrade():
op.alter_column("account_trial_app_records", "id", server_default=None)
op.alter_column("exporle_banners", "id", server_default=None)
op.alter_column("trial_apps", "id", server_default=None)


def downgrade():
# This migration exists primarily for schema consistency
# between the database and the model definitions.
#
# Dropping a server default is idempotent, and the original
# migration has been updated to reflect the server-side defaults,
# so there is nothing to restore here.
pass
@@ -1,5 +1,4 @@
from datetime import datetime
from uuid import uuid4

from sqlalchemy import DateTime, func
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
@@ -51,16 +50,3 @@ class DefaultFieldsMixin:

def __repr__(self) -> str:
return f"<{self.__class__.__name__}(id={self.id})>"


def gen_uuidv4_string() -> str:
"""gen_uuidv4_string generates a UUIDv4 string.

NOTE: This function exists only for historical reasons. New models should use uuidv7 for primary key generation.
"""
return str(uuid4())


def gen_uuidv7_string() -> str:
"""gen_uuidv7_string generates a UUIDv7 string."""
return str(uuidv7())
@@ -26,7 +26,7 @@ from libs.helper import generate_string # type: ignore[import-not-found]
from libs.uuid_utils import uuidv7

from .account import Account, Tenant
from .base import Base, TypeBase, gen_uuidv4_string
from .base import Base, TypeBase
from .engine import db
from .enums import CreatorUserRole
from .provider_ids import GenericProviderID
@@ -620,7 +620,7 @@ class TrialApp(Base):
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
)

id = mapped_column(StringUUID, default=gen_uuidv4_string)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@@ -640,7 +640,7 @@ class AccountTrialAppRecord(Base):
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
)
id = mapped_column(StringUUID, default=gen_uuidv4_string)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
account_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
count = mapped_column(sa.Integer, nullable=False, default=0)
@@ -660,7 +660,7 @@ class AccountTrialAppRecord(Base):
class ExporleBanner(TypeBase):
__tablename__ = "exporle_banners"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv4_string, init=False)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
link: Mapped[str] = mapped_column(String(255), nullable=False)
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)

@@ -81,7 +81,7 @@ dependencies = [
"starlette==0.49.1",
"tiktoken~=0.9.0",
"transformers~=4.56.1",
"unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
"yarl~=1.18.3",
"webvtt-py~=0.5.1",
"sseclient-py~=1.8.0",
@@ -1,114 +0,0 @@
|
||||
"""
|
||||
Scheduled task to batch-update API token last_used_at timestamps.
|
||||
|
||||
Instead of updating the database on every request, token usage is recorded
|
||||
in Redis as lightweight SET keys (api_token_active:{scope}:{token}).
|
||||
This task runs periodically (default every 30 minutes) to flush those
|
||||
records into the database in a single batch operation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import click
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import ApiToken
|
||||
from services.api_token_service import ACTIVE_TOKEN_KEY_PREFIX
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.celery.task(queue="api_token")
|
||||
def batch_update_api_token_last_used():
|
||||
"""
|
||||
Batch update last_used_at for all recently active API tokens.
|
||||
|
||||
Scans Redis for api_token_active:* keys, parses the token and scope
|
||||
from each key, and performs a batch database update.
|
||||
"""
|
||||
click.echo(click.style("batch_update_api_token_last_used: start.", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
updated_count = 0
|
||||
scanned_count = 0
|
||||
|
||||
try:
|
||||
# Collect all active token keys and their values (the actual usage timestamps)
|
||||
token_entries: list[tuple[str, str | None, datetime]] = [] # (token, scope, usage_time)
|
||||
keys_to_delete: list[str | bytes] = []
|
||||
|
||||
for key in redis_client.scan_iter(match=f"{ACTIVE_TOKEN_KEY_PREFIX}*", count=200):
|
||||
if isinstance(key, bytes):
|
||||
key = key.decode("utf-8")
|
||||
scanned_count += 1
|
||||
|
||||
# Read the value (ISO timestamp recorded at actual request time)
|
||||
value = redis_client.get(key)
|
||||
if not value:
|
||||
keys_to_delete.append(key)
|
||||
continue
|
||||
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("utf-8")
|
||||
|
||||
try:
|
||||
usage_time = datetime.fromisoformat(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning("Invalid timestamp in key %s: %s", key, value)
|
||||
keys_to_delete.append(key)
|
||||
continue
|
||||
|
||||
# Parse token info from key: api_token_active:{scope}:{token}
|
||||
suffix = key[len(ACTIVE_TOKEN_KEY_PREFIX) :]
|
||||
parts = suffix.split(":", 1)
|
||||
if len(parts) == 2:
|
||||
scope_str, token = parts
|
||||
scope = None if scope_str == "None" else scope_str
|
||||
token_entries.append((token, scope, usage_time))
|
||||
keys_to_delete.append(key)
|
||||
|
||||
if not token_entries:
|
||||
click.echo(click.style("batch_update_api_token_last_used: no active tokens found.", fg="yellow"))
|
||||
# Still clean up any invalid keys
|
||||
if keys_to_delete:
|
||||
redis_client.delete(*keys_to_delete)
|
||||
return
|
||||
|
||||
# Update each token in its own short transaction to avoid long transactions
|
||||
for token, scope, usage_time in token_entries:
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
stmt = (
|
||||
update(ApiToken)
|
||||
.where(
|
||||
ApiToken.token == token,
|
||||
ApiToken.type == scope,
|
||||
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < usage_time)),
|
||||
)
|
||||
.values(last_used_at=usage_time)
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
rowcount = getattr(result, "rowcount", 0)
|
||||
if rowcount > 0:
|
||||
updated_count += 1
|
||||
|
||||
# Delete processed keys from Redis
|
||||
if keys_to_delete:
|
||||
redis_client.delete(*keys_to_delete)
|
||||
|
||||
except Exception:
|
||||
logger.exception("batch_update_api_token_last_used failed")
|
||||
|
||||
elapsed = time.perf_counter() - start_at
|
||||
click.echo(
|
||||
click.style(
|
||||
f"batch_update_api_token_last_used: done. "
|
||||
f"scanned={scanned_count}, updated={updated_count}, elapsed={elapsed:.2f}s",
|
||||
fg="green",
|
||||
)
|
||||
)
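The flush task above assumes the request path records activity under `ACTIVE_TOKEN_KEY_PREFIX` with an ISO timestamp as the value; a minimal sketch of that write side of the (now removed) mechanism, where the TTL is illustrative and only the key layout and value format are taken from the code above:

from datetime import UTC, datetime

from extensions.ext_redis import redis_client
from services.api_token_service import ACTIVE_TOKEN_KEY_PREFIX


def record_api_token_usage(token: str, scope: str | None) -> None:
    """Store the latest usage time; the periodic task later flushes it to ApiToken.last_used_at."""
    key = f"{ACTIVE_TOKEN_KEY_PREFIX}{scope}:{token}"
    # Keep the key longer than the flush interval so a missed run does not lose the record.
    redis_client.set(key, datetime.now(UTC).isoformat(), ex=3600)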
|
||||
@@ -1,330 +0,0 @@
|
||||
"""
|
||||
API Token Service
|
||||
|
||||
Handles all API token caching, validation, and usage recording.
|
||||
Includes Redis cache operations, database queries, and single-flight concurrency control.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import ApiToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Pydantic DTO
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
class CachedApiToken(BaseModel):
|
||||
"""
|
||||
Pydantic model for cached API token data.
|
||||
|
||||
This is NOT a SQLAlchemy model instance, but a plain Pydantic model
|
||||
that mimics the ApiToken model interface for read-only access.
|
||||
"""
|
||||
|
||||
id: str
|
||||
app_id: str | None
|
||||
tenant_id: str | None
|
||||
type: str
|
||||
token: str
|
||||
last_used_at: datetime | None
|
||||
created_at: datetime | None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<CachedApiToken id={self.id} type={self.type}>"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Cache configuration
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
CACHE_KEY_PREFIX = "api_token"
|
||||
CACHE_TTL_SECONDS = 600 # 10 minutes
|
||||
CACHE_NULL_TTL_SECONDS = 60 # 1 minute for non-existent tokens
|
||||
ACTIVE_TOKEN_KEY_PREFIX = "api_token_active:"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Cache class
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
class ApiTokenCache:
|
||||
"""
|
||||
Redis cache wrapper for API tokens.
|
||||
Handles serialization, deserialization, and cache invalidation.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def make_active_key(token: str, scope: str | None = None) -> str:
|
||||
"""Generate Redis key for recording token usage."""
|
||||
return f"{ACTIVE_TOKEN_KEY_PREFIX}{scope}:{token}"
|
||||
|
||||
@staticmethod
|
||||
def _make_tenant_index_key(tenant_id: str) -> str:
|
||||
"""Generate Redis key for tenant token index."""
|
||||
return f"tenant_tokens:{tenant_id}"
|
||||
|
||||
@staticmethod
|
||||
def _make_cache_key(token: str, scope: str | None = None) -> str:
|
||||
"""Generate cache key for the given token and scope."""
|
||||
scope_str = scope or "any"
|
||||
return f"{CACHE_KEY_PREFIX}:{scope_str}:{token}"
|
||||
|
||||
@staticmethod
|
||||
def _serialize_token(api_token: Any) -> bytes:
|
||||
"""Serialize ApiToken object to JSON bytes."""
|
||||
if isinstance(api_token, CachedApiToken):
|
||||
return api_token.model_dump_json().encode("utf-8")
|
||||
|
||||
cached = CachedApiToken(
|
||||
id=str(api_token.id),
|
||||
app_id=str(api_token.app_id) if api_token.app_id else None,
|
||||
tenant_id=str(api_token.tenant_id) if api_token.tenant_id else None,
|
||||
type=api_token.type,
|
||||
token=api_token.token,
|
||||
last_used_at=api_token.last_used_at,
|
||||
created_at=api_token.created_at,
|
||||
)
|
||||
return cached.model_dump_json().encode("utf-8")
|
||||
|
||||
    @staticmethod
    def _deserialize_token(cached_data: bytes | str) -> Any:
        """Deserialize JSON bytes/string back to a CachedApiToken Pydantic model."""
        if cached_data in {b"null", "null"}:
            return None

        try:
            if isinstance(cached_data, bytes):
                cached_data = cached_data.decode("utf-8")
            return CachedApiToken.model_validate_json(cached_data)
        except Exception as e:
            logger.warning("Failed to deserialize token from cache: %s", e)
            return None

    @staticmethod
    @redis_fallback(default_return=None)
    def get(token: str, scope: str | None) -> Any | None:
        """Get API token from cache."""
        cache_key = ApiTokenCache._make_cache_key(token, scope)
        cached_data = redis_client.get(cache_key)

        if cached_data is None:
            logger.debug("Cache miss for token key: %s", cache_key)
            return None

        logger.debug("Cache hit for token key: %s", cache_key)
        return ApiTokenCache._deserialize_token(cached_data)

    @staticmethod
    def _add_to_tenant_index(tenant_id: str | None, cache_key: str) -> None:
        """Add cache key to tenant index for efficient invalidation."""
        if not tenant_id:
            return

        try:
            index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
            redis_client.sadd(index_key, cache_key)
            redis_client.expire(index_key, CACHE_TTL_SECONDS + 60)
        except Exception as e:
            logger.warning("Failed to update tenant index: %s", e)

    @staticmethod
    def _remove_from_tenant_index(tenant_id: str | None, cache_key: str) -> None:
        """Remove cache key from tenant index."""
        if not tenant_id:
            return

        try:
            index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
            redis_client.srem(index_key, cache_key)
        except Exception as e:
            logger.warning("Failed to remove from tenant index: %s", e)

    @staticmethod
    @redis_fallback(default_return=False)
    def set(token: str, scope: str | None, api_token: Any | None, ttl: int = CACHE_TTL_SECONDS) -> bool:
        """Set API token in cache."""
        cache_key = ApiTokenCache._make_cache_key(token, scope)

        if api_token is None:
            cached_value = b"null"
            ttl = CACHE_NULL_TTL_SECONDS
        else:
            cached_value = ApiTokenCache._serialize_token(api_token)

        try:
            redis_client.setex(cache_key, ttl, cached_value)

            if api_token is not None and hasattr(api_token, "tenant_id"):
                ApiTokenCache._add_to_tenant_index(api_token.tenant_id, cache_key)

            logger.debug("Cached token with key: %s, ttl: %ss", cache_key, ttl)
            return True
        except Exception as e:
            logger.warning("Failed to cache token: %s", e)
            return False

    @staticmethod
    @redis_fallback(default_return=False)
    def delete(token: str, scope: str | None = None) -> bool:
        """Delete API token from cache."""
        if scope is None:
            pattern = f"{CACHE_KEY_PREFIX}:*:{token}"
            try:
                keys_to_delete = list(redis_client.scan_iter(match=pattern))
                if keys_to_delete:
                    redis_client.delete(*keys_to_delete)
                    logger.info("Deleted %d cache entries for token", len(keys_to_delete))
                return True
            except Exception as e:
                logger.warning("Failed to delete token cache with pattern: %s", e)
                return False
        else:
            cache_key = ApiTokenCache._make_cache_key(token, scope)
            try:
                tenant_id = None
                try:
                    cached_data = redis_client.get(cache_key)
                    if cached_data and cached_data != b"null":
                        cached_token = ApiTokenCache._deserialize_token(cached_data)
                        if cached_token:
                            tenant_id = cached_token.tenant_id
                except Exception as e:
                    logger.debug("Failed to get tenant_id for cache cleanup: %s", e)

                redis_client.delete(cache_key)

                if tenant_id:
                    ApiTokenCache._remove_from_tenant_index(tenant_id, cache_key)

                logger.info("Deleted cache for key: %s", cache_key)
                return True
            except Exception as e:
                logger.warning("Failed to delete token cache: %s", e)
                return False

    @staticmethod
    @redis_fallback(default_return=False)
    def invalidate_by_tenant(tenant_id: str) -> bool:
        """Invalidate all API token caches for a specific tenant via tenant index."""
        try:
            index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
            cache_keys = redis_client.smembers(index_key)

            if cache_keys:
                deleted_count = 0
                for cache_key in cache_keys:
                    if isinstance(cache_key, bytes):
                        cache_key = cache_key.decode("utf-8")
                    redis_client.delete(cache_key)
                    deleted_count += 1

                redis_client.delete(index_key)

                logger.info(
                    "Invalidated %d token cache entries for tenant: %s",
                    deleted_count,
                    tenant_id,
                )
            else:
                logger.info(
                    "No tenant index found for %s, relying on TTL expiration",
                    tenant_id,
                )

            return True

        except Exception as e:
            logger.warning("Failed to invalidate tenant token cache: %s", e)
            return False


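# Illustrative sketch (an assumption, not code from this module): how a write path
# is expected to keep the cache coherent. The database mutation happens first; the
# cached entries are then dropped so the next read repopulates them from the DB.
def revoke_api_token_sketch(token: str, tenant_id: str) -> None:
    from sqlalchemy import delete as sa_delete  # local import: only needed by this sketch

    with Session(db.engine) as session:
        session.execute(sa_delete(ApiToken).where(ApiToken.token == token, ApiToken.tenant_id == tenant_id))
        session.commit()
    # No scope is given, so the SCAN-based pattern delete removes the entry for every scope.
    ApiTokenCache.delete(token)

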
# ---------------------------------------------------------------------
# Token usage recording (for batch update)
# ---------------------------------------------------------------------


def record_token_usage(auth_token: str, scope: str | None) -> None:
    """
    Record token usage in Redis for later batch update by a scheduled job.

    Instead of dispatching a Celery task per request, we simply SET a key in Redis.
    A Celery Beat scheduled task will periodically scan these keys and batch-update
    last_used_at in the database.
    """
    try:
        key = ApiTokenCache.make_active_key(auth_token, scope)
        redis_client.set(key, naive_utc_now().isoformat(), ex=3600)
    except Exception as e:
        logger.warning("Failed to record token usage: %s", e)


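# Illustrative sketch only (an assumption, not the actual scheduled task in this
# codebase): the Celery Beat consumer described in the docstring above would look
# roughly like this. It scans the "api_token_active:*" keys written by
# record_token_usage() and batch-updates last_used_at, so the hot request path
# never performs a per-request database write.
def _sync_token_last_used_sketch() -> None:
    from sqlalchemy import update  # local import: only needed by this sketch

    for key in redis_client.scan_iter(match=f"{ACTIVE_TOKEN_KEY_PREFIX}*"):
        key_str = key.decode("utf-8") if isinstance(key, bytes) else key
        raw = redis_client.get(key_str)
        if raw is None:
            continue
        used_at = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        # Key layout is "api_token_active:<scope>:<token>" (see make_active_key);
        # tokens are assumed not to contain ":".
        token = key_str.removeprefix(ACTIVE_TOKEN_KEY_PREFIX).split(":", 1)[1]
        with Session(db.engine) as session:
            session.execute(
                update(ApiToken).where(ApiToken.token == token).values(last_used_at=datetime.fromisoformat(used_at))
            )
            session.commit()
        redis_client.delete(key_str)

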
# ---------------------------------------------------------------------
# Database query + single-flight
# ---------------------------------------------------------------------


def query_token_from_db(auth_token: str, scope: str | None) -> ApiToken:
    """
    Query API token from database and cache the result.

    Raises Unauthorized if token is invalid.
    """
    with Session(db.engine, expire_on_commit=False) as session:
        stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
        api_token = session.scalar(stmt)

        if not api_token:
            ApiTokenCache.set(auth_token, scope, None)
            raise Unauthorized("Access token is invalid")

        ApiTokenCache.set(auth_token, scope, api_token)
        record_token_usage(auth_token, scope)
        return api_token


def fetch_token_with_single_flight(auth_token: str, scope: str | None) -> ApiToken | Any:
    """
    Fetch token from DB with single-flight pattern using Redis lock.

    Ensures only one concurrent request queries the database for the same token.
    Falls back to direct query if lock acquisition fails.
    """
    logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope)

    lock_key = f"api_token_query_lock:{scope}:{auth_token}"
    lock = redis_client.lock(lock_key, timeout=10, blocking_timeout=5)

    try:
        if lock.acquire(blocking=True):
            try:
                cached_token = ApiTokenCache.get(auth_token, scope)
                if cached_token is not None:
                    logger.debug("Token cached by concurrent request, using cached version")
                    return cached_token

                return query_token_from_db(auth_token, scope)
            finally:
                lock.release()
        else:
            logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10])
            return query_token_from_db(auth_token, scope)
    except Unauthorized:
        raise
    except Exception as e:
        logger.warning("Redis lock failed for token query: %s, proceeding anyway", e)
        return query_token_from_db(auth_token, scope)
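

# Illustrative sketch (an assumption, not the request-layer code itself): how the
# pieces above are meant to compose on the hot path. A request first tries the
# cache; only on a miss does it fall into the single-flight database lookup, and
# every successful validation records usage for the batch updater.
def validate_api_token_sketch(auth_token: str, scope: str | None) -> Any:
    cached_token = ApiTokenCache.get(auth_token, scope)
    if cached_token is not None:
        # Cache hit: no database round-trip, just remember that the token was used.
        record_token_usage(auth_token, scope)
        return cached_token
    # Cache miss (or negatively cached "null"): at most one concurrent request per
    # token queries the database; everyone else reuses the freshly cached result.
    return fetch_token_with_single_flight(auth_token, scope)
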
@@ -155,11 +155,11 @@ class AsyncWorkflowService:

task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
task = execute_workflow_professional.delay(task_data_dict) # type: ignore
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
task = execute_workflow_team.delay(task_data_dict) # type: ignore
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
task = execute_workflow_sandbox.delay(task_data_dict) # type: ignore

# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED
@@ -170,7 +170,7 @@ class AsyncWorkflowService:

return AsyncTriggerResponse(
workflow_trigger_log_id=trigger_log.id,
task_id=task.id,
task_id=task.id, # type: ignore
status="queued",
queue=queue_name,
)

@@ -1696,18 +1696,13 @@ class DocumentService:
for document in documents
if document.data_source_type == "upload_file" and document.data_source_info_dict
]
if dataset.doc_form is not None:
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)

# Delete documents first, then dispatch cleanup task after commit
# to avoid deadlock between main transaction and async task
for document in documents:
db.session.delete(document)
db.session.commit()

# Dispatch cleanup task after commit to avoid lock contention
# Task cleans up segments, files, and vector indexes
if dataset.doc_form is not None:
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)

@staticmethod
def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
assert isinstance(current_user, Account)

@@ -16,25 +16,6 @@ class EndUserService:
Service for managing end users.
"""

@classmethod
def get_end_user_by_id(cls, *, tenant_id: str, app_id: str, end_user_id: str) -> EndUser | None:
"""Get an end user by primary key.

This is scoped to the provided tenant and app to prevent cross-tenant/app access
when an end-user ID is known.
"""

with Session(db.engine, expire_on_commit=False) as session:
return (
session.query(EndUser)
.where(
EndUser.id == end_user_id,
EndUser.tenant_id == tenant_id,
EndUser.app_id == app_id,
)
.first()
)

@classmethod
def get_or_create_end_user(cls, app_model: App, user_id: str | None = None) -> EndUser:
"""

@@ -1329,24 +1329,10 @@ class RagPipelineService:
"""
Get datasource plugins
"""
dataset: Dataset | None = (
db.session.query(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
)
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
)
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")

@@ -1427,24 +1413,10 @@ class RagPipelineService:
"""
Get pipeline
"""
dataset: Dataset | None = (
db.session.query(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
)
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
)
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
return pipeline

@@ -10,7 +10,6 @@ from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db

@@ -335,7 +335,7 @@ def _build_node_finished_event(
inputs=None,
process_data=None,
outputs=None,
status=WorkflowNodeExecutionStatus(snapshot.status),
status=snapshot.status,
error=None,
elapsed_time=snapshot.elapsed_time,
execution_metadata=None,
@@ -373,7 +373,7 @@ def _build_pause_event(
paused_nodes=paused_nodes,
outputs=outputs,
reasons=reasons,
status=workflow_run.status,
status=workflow_run.status.value,
created_at=int(workflow_run.created_at.timestamp()),
elapsed_time=float(workflow_run.elapsed_time or 0.0),
total_tokens=int(workflow_run.total_tokens or 0),

@@ -6,6 +6,7 @@ from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

@@ -57,3 +58,5 @@ def add_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()

@@ -5,6 +5,7 @@ import click
from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

@@ -39,3 +40,5 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Annotation deleted index failed")
finally:
db.session.close()

@@ -6,6 +6,7 @@ from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

@@ -58,3 +59,5 @@ def update_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()

@@ -14,9 +14,6 @@ from models.model import UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Batch size for database operations to keep transactions short
|
||||
BATCH_SIZE = 1000
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
|
||||
@@ -34,179 +31,63 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
|
||||
if not doc_form:
|
||||
raise ValueError("doc_form is required")
|
||||
|
||||
storage_keys_to_delete: list[str] = []
|
||||
index_node_ids: list[str] = []
|
||||
segment_ids: list[str] = []
|
||||
total_image_upload_file_ids: list[str] = []
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
|
||||
session.query(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id.in_(document_ids),
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
try:
|
||||
# ============ Step 1: Query segment and file data (short read-only transaction) ============
|
||||
with session_factory.create_session() as session:
|
||||
# Get segments info
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
|
||||
).all()
|
||||
|
||||
# check segment is exist
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(
|
||||
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
|
||||
# Collect image file IDs from segment content
|
||||
for segment in segments:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||
total_image_upload_file_ids.extend(image_upload_file_ids)
|
||||
|
||||
# Query storage keys for image files
|
||||
if total_image_upload_file_ids:
|
||||
image_files = session.scalars(
|
||||
select(UploadFile).where(UploadFile.id.in_(total_image_upload_file_ids))
|
||||
).all()
|
||||
storage_keys_to_delete.extend([f.key for f in image_files if f and f.key])
|
||||
|
||||
# Query storage keys for document files
|
||||
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
|
||||
for image_file in image_files:
|
||||
try:
|
||||
if image_file and image_file.key:
|
||||
storage.delete(image_file.key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete image_files failed when storage deleted, \
|
||||
image_upload_file_is: %s",
|
||||
image_file.id,
|
||||
)
|
||||
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||
session.execute(stmt)
|
||||
session.delete(segment)
|
||||
if file_ids:
|
||||
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
|
||||
storage_keys_to_delete.extend([f.key for f in files if f and f.key])
|
||||
for file in files:
|
||||
try:
|
||||
storage.delete(file.key)
|
||||
except Exception:
|
||||
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
|
||||
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
|
||||
session.execute(stmt)
|
||||
|
||||
# ============ Step 2: Clean vector index (external service, fresh session for dataset) ============
|
||||
if index_node_ids:
|
||||
try:
|
||||
# Fetch dataset in a fresh session to avoid DetachedInstanceError
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id)
|
||||
else:
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(
|
||||
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to clean vector index for dataset_id: %s, document_ids: %s, index_node_ids count: %d",
|
||||
dataset_id,
|
||||
document_ids,
|
||||
len(index_node_ids),
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# ============ Step 3: Delete metadata binding (separate short transaction) ============
|
||||
try:
|
||||
with session_factory.create_session() as session:
|
||||
deleted_count = (
|
||||
session.query(DatasetMetadataBinding)
|
||||
.where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id.in_(document_ids),
|
||||
)
|
||||
.delete(synchronize_session=False)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
session.commit()
|
||||
logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to delete metadata bindings for dataset_id: %s, document_ids: %s",
|
||||
dataset_id,
|
||||
document_ids,
|
||||
)
|
||||
|
||||
# ============ Step 4: Batch delete UploadFile records (multiple short transactions) ============
|
||||
if total_image_upload_file_ids:
|
||||
failed_batches = 0
|
||||
total_batches = (len(total_image_upload_file_ids) + BATCH_SIZE - 1) // BATCH_SIZE
|
||||
for i in range(0, len(total_image_upload_file_ids), BATCH_SIZE):
|
||||
batch = total_image_upload_file_ids[i : i + BATCH_SIZE]
|
||||
try:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = delete(UploadFile).where(UploadFile.id.in_(batch))
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
except Exception:
|
||||
failed_batches += 1
|
||||
logger.exception(
|
||||
"Failed to delete image UploadFile batch %d-%d for dataset_id: %s",
|
||||
i,
|
||||
i + len(batch),
|
||||
dataset_id,
|
||||
)
|
||||
if failed_batches > 0:
|
||||
logger.warning(
|
||||
"Image UploadFile deletion: %d/%d batches failed for dataset_id: %s",
|
||||
failed_batches,
|
||||
total_batches,
|
||||
dataset_id,
|
||||
)
|
||||
|
||||
# ============ Step 5: Batch delete DocumentSegment records (multiple short transactions) ============
|
||||
if segment_ids:
|
||||
failed_batches = 0
|
||||
total_batches = (len(segment_ids) + BATCH_SIZE - 1) // BATCH_SIZE
|
||||
for i in range(0, len(segment_ids), BATCH_SIZE):
|
||||
batch = segment_ids[i : i + BATCH_SIZE]
|
||||
try:
|
||||
with session_factory.create_session() as session:
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(batch))
|
||||
session.execute(segment_delete_stmt)
|
||||
session.commit()
|
||||
except Exception:
|
||||
failed_batches += 1
|
||||
logger.exception(
|
||||
"Failed to delete DocumentSegment batch %d-%d for dataset_id: %s, document_ids: %s",
|
||||
i,
|
||||
i + len(batch),
|
||||
dataset_id,
|
||||
document_ids,
|
||||
)
|
||||
if failed_batches > 0:
|
||||
logger.warning(
|
||||
"DocumentSegment deletion: %d/%d batches failed, document_ids: %s",
|
||||
failed_batches,
|
||||
total_batches,
|
||||
document_ids,
|
||||
)
|
||||
|
||||
# ============ Step 6: Delete document-associated files (separate short transaction) ============
|
||||
if file_ids:
|
||||
try:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to delete document UploadFile records for dataset_id: %s, file_ids: %s",
|
||||
dataset_id,
|
||||
file_ids,
|
||||
)
|
||||
|
||||
# ============ Step 7: Delete storage files (I/O operations, no DB transaction) ============
|
||||
storage_delete_failures = 0
|
||||
for storage_key in storage_keys_to_delete:
|
||||
try:
|
||||
storage.delete(storage_key)
|
||||
except Exception:
|
||||
storage_delete_failures += 1
|
||||
logger.exception("Failed to delete file from storage, key: %s", storage_key)
|
||||
if storage_delete_failures > 0:
|
||||
logger.warning(
|
||||
"Storage file deletion completed with %d failures out of %d total files for dataset_id: %s",
|
||||
storage_delete_failures,
|
||||
len(storage_keys_to_delete),
|
||||
dataset_id,
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Cleaned documents when documents deleted latency: {end_at - start_at:.2f}s, "
|
||||
f"dataset_id: {dataset_id}, document_ids: {document_ids}, "
|
||||
f"segments: {len(segment_ids)}, image_files: {len(total_image_upload_file_ids)}, "
|
||||
f"storage_files: {len(storage_keys_to_delete)}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Batch clean documents failed for dataset_id: %s, document_ids: %s",
|
||||
dataset_id,
|
||||
document_ids,
|
||||
)
|
||||
logger.exception("Cleaned documents when documents deleted failed")
|
||||
|
||||
@@ -48,11 +48,6 @@ def batch_create_segment_to_index_task(
|
||||
|
||||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||
|
||||
# Initialize variables with default values
|
||||
upload_file_key: str | None = None
|
||||
dataset_config: dict | None = None
|
||||
document_config: dict | None = None
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
@@ -74,115 +69,86 @@ def batch_create_segment_to_index_task(
|
||||
if not upload_file:
|
||||
raise ValueError("UploadFile not found.")
|
||||
|
||||
dataset_config = {
|
||||
"id": dataset.id,
|
||||
"indexing_technique": dataset.indexing_technique,
|
||||
"tenant_id": dataset.tenant_id,
|
||||
"embedding_model_provider": dataset.embedding_model_provider,
|
||||
"embedding_model": dataset.embedding_model,
|
||||
}
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
storage.download(upload_file.key, file_path)
|
||||
|
||||
document_config = {
|
||||
"id": dataset_document.id,
|
||||
"doc_form": dataset_document.doc_form,
|
||||
"word_count": dataset_document.word_count or 0,
|
||||
}
|
||||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
content.append(data)
|
||||
if len(content) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
|
||||
upload_file_key = upload_file.key
|
||||
document_segments = []
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Segments batch created index failed")
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
return
|
||||
|
||||
# Ensure required variables are set before proceeding
|
||||
if upload_file_key is None or dataset_config is None or document_config is None:
|
||||
logger.error("Required configuration not set due to session error")
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
return
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file_key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
storage.download(upload_file_key, file_path)
|
||||
|
||||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if document_config["doc_form"] == "qa_model":
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
word_count_change = 0
|
||||
if embedding_model:
|
||||
tokens_list = embedding_model.get_text_embedding_num_tokens(
|
||||
texts=[segment["content"] for segment in content]
|
||||
)
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
content.append(data)
|
||||
if len(content) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
tokens_list = [0] * len(content)
|
||||
|
||||
document_segments = []
|
||||
embedding_model = None
|
||||
if dataset_config["indexing_technique"] == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset_config["tenant_id"],
|
||||
provider=dataset_config["embedding_model_provider"],
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset_config["embedding_model"],
|
||||
)
|
||||
for segment, tokens in zip(content, tokens_list):
|
||||
content = segment["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
max_position = (
|
||||
session.query(func.max(DocumentSegment.position))
|
||||
.where(DocumentSegment.document_id == dataset_document.id)
|
||||
.scalar()
|
||||
)
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
index_node_id=doc_id,
|
||||
index_node_hash=segment_hash,
|
||||
position=max_position + 1 if max_position else 1,
|
||||
content=content,
|
||||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
created_by=user_id,
|
||||
indexing_at=naive_utc_now(),
|
||||
status="completed",
|
||||
completed_at=naive_utc_now(),
|
||||
)
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
segment_document.answer = segment["answer"]
|
||||
segment_document.word_count += len(segment["answer"])
|
||||
word_count_change += segment_document.word_count
|
||||
session.add(segment_document)
|
||||
document_segments.append(segment_document)
|
||||
|
||||
word_count_change = 0
|
||||
if embedding_model:
|
||||
tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
|
||||
else:
|
||||
tokens_list = [0] * len(content)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
for segment, tokens in zip(content, tokens_list):
|
||||
content = segment["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
max_position = (
|
||||
session.query(func.max(DocumentSegment.position))
|
||||
.where(DocumentSegment.document_id == document_config["id"])
|
||||
.scalar()
|
||||
)
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
index_node_id=doc_id,
|
||||
index_node_hash=segment_hash,
|
||||
position=max_position + 1 if max_position else 1,
|
||||
content=content,
|
||||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
created_by=user_id,
|
||||
indexing_at=naive_utc_now(),
|
||||
status="completed",
|
||||
completed_at=naive_utc_now(),
|
||||
)
|
||||
if document_config["doc_form"] == "qa_model":
|
||||
segment_document.answer = segment["answer"]
|
||||
segment_document.word_count += len(segment["answer"])
|
||||
word_count_change += segment_document.word_count
|
||||
session.add(segment_document)
|
||||
document_segments.append(segment_document)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
dataset_document = session.get(Document, document_id)
|
||||
if dataset_document:
|
||||
assert dataset_document.word_count is not None
|
||||
dataset_document.word_count += word_count_change
|
||||
session.add(dataset_document)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
if dataset:
|
||||
VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"])
|
||||
|
||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
|
||||
session.commit()
|
||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Segments batch created index failed")
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
|
||||
@@ -28,7 +28,6 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
||||
"""
|
||||
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
total_attachment_files = []
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
@@ -48,91 +47,78 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
||||
SegmentAttachmentBinding.document_id == document_id,
|
||||
)
|
||||
).all()
|
||||
|
||||
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
|
||||
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
|
||||
total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])
|
||||
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
segment_contents = [segment.content for segment in segments]
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document deleted failed")
|
||||
return
|
||||
|
||||
# check segment is exist
|
||||
if index_node_ids:
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset:
|
||||
# check segment is exist
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(
|
||||
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
|
||||
total_image_files = []
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
for segment_content in segment_contents:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment_content)
|
||||
image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all()
|
||||
total_image_files.extend([image_file.key for image_file in image_files])
|
||||
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||
session.execute(image_file_delete_stmt)
|
||||
for segment in segments:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||
image_files = session.scalars(
|
||||
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||
).all()
|
||||
for image_file in image_files:
|
||||
if image_file is None:
|
||||
continue
|
||||
try:
|
||||
storage.delete(image_file.key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete image_files failed when storage deleted, \
|
||||
image_upload_file_is: %s",
|
||||
image_file.id,
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
session.execute(segment_delete_stmt)
|
||||
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||
session.execute(image_file_delete_stmt)
|
||||
session.delete(segment)
|
||||
|
||||
for image_file_key in total_image_files:
|
||||
try:
|
||||
storage.delete(image_file_key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete image_files failed when storage deleted, \
|
||||
image_upload_file_is: %s",
|
||||
image_file_key,
|
||||
session.commit()
|
||||
if file_id:
|
||||
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
if file:
|
||||
try:
|
||||
storage.delete(file.key)
|
||||
except Exception:
|
||||
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
|
||||
session.delete(file)
|
||||
# delete segment attachments
|
||||
if attachments_with_bindings:
|
||||
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
|
||||
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
|
||||
for binding, attachment_file in attachments_with_bindings:
|
||||
try:
|
||||
storage.delete(attachment_file.key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete attachment_file failed when storage deleted, \
|
||||
attachment_file_id: %s",
|
||||
binding.attachment_id,
|
||||
)
|
||||
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
|
||||
session.execute(attachment_file_delete_stmt)
|
||||
|
||||
binding_delete_stmt = delete(SegmentAttachmentBinding).where(
|
||||
SegmentAttachmentBinding.id.in_(binding_ids)
|
||||
)
|
||||
session.execute(binding_delete_stmt)
|
||||
|
||||
# delete dataset metadata binding
|
||||
session.query(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id == document_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
if file_id:
|
||||
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
if file:
|
||||
try:
|
||||
storage.delete(file.key)
|
||||
except Exception:
|
||||
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
|
||||
session.delete(file)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
# delete segment attachments
|
||||
if attachment_ids:
|
||||
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
|
||||
session.execute(attachment_file_delete_stmt)
|
||||
|
||||
if binding_ids:
|
||||
binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids))
|
||||
session.execute(binding_delete_stmt)
|
||||
|
||||
for attachment_file_key in total_attachment_files:
|
||||
try:
|
||||
storage.delete(attachment_file_key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete attachment_file failed when storage deleted, \
|
||||
attachment_file_id: %s",
|
||||
attachment_file_key,
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
# delete dataset metadata binding
|
||||
session.query(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id == document_id,
|
||||
).delete()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
logger.exception("Cleaned document when document deleted failed")
|
||||
|
||||
@@ -23,40 +23,40 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
|
||||
"""
|
||||
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
total_index_node_ids = []
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
index_type = dataset.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
index_type = dataset.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
|
||||
session.execute(document_delete_stmt)
|
||||
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
|
||||
session.execute(document_delete_stmt)
|
||||
|
||||
for document_id in document_ids:
|
||||
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
total_index_node_ids.extend([segment.index_node_id for segment in segments])
|
||||
for document_id in document_ids:
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset:
|
||||
index_processor.clean(
|
||||
dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
index_processor.clean(
|
||||
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
|
||||
session.execute(segment_delete_stmt)
|
||||
session.commit()
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Clean document when import form notion document deleted end :: {} latency: {}".format(
|
||||
dataset_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
|
||||
session.execute(segment_delete_stmt)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Clean document when import form notion document deleted end :: {} latency: {}".format(
|
||||
dataset_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when import form notion document deleted failed")
|
||||
|
||||
@@ -3,7 +3,6 @@ import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
@@ -68,14 +67,8 @@ def delete_segment_from_index_task(
|
||||
if segment_attachment_bindings:
|
||||
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
|
||||
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
|
||||
segment_attachment_bind_ids = [i.id for i in segment_attachment_bindings]
|
||||
|
||||
for i in range(0, len(segment_attachment_bind_ids), 1000):
|
||||
segment_attachment_bind_delete_stmt = delete(SegmentAttachmentBinding).where(
|
||||
SegmentAttachmentBinding.id.in_(segment_attachment_bind_ids[i : i + 1000])
|
||||
)
|
||||
session.execute(segment_attachment_bind_delete_stmt)
|
||||
|
||||
for binding in segment_attachment_bindings:
|
||||
session.delete(binding)
|
||||
# delete upload file
|
||||
session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
|
||||
session.commit()
|
||||
|
||||
@@ -27,129 +27,104 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
"""
|
||||
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
tenant_id = None
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
return
|
||||
|
||||
if document.indexing_status == "parsing":
|
||||
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
|
||||
return
|
||||
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
|
||||
data_source_info = document.data_source_info_dict
|
||||
if document.data_source_type != "notion_import":
|
||||
logger.info(click.style(f"Document {document_id} is not a notion_import, skipping", fg="yellow"))
|
||||
return
|
||||
if document.data_source_type == "notion_import":
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_page_id" not in data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no notion page found")
|
||||
workspace_id = data_source_info["notion_workspace_id"]
|
||||
page_id = data_source_info["notion_page_id"]
|
||||
page_type = data_source_info["type"]
|
||||
page_edited_time = data_source_info["last_edited_time"]
|
||||
credential_id = data_source_info.get("credential_id")
|
||||
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_page_id" not in data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no notion page found")
|
||||
# Get credentials from datasource provider
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=document.tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
workspace_id = data_source_info["notion_workspace_id"]
|
||||
page_id = data_source_info["notion_page_id"]
|
||||
page_type = data_source_info["type"]
|
||||
page_edited_time = data_source_info["last_edited_time"]
|
||||
credential_id = data_source_info.get("credential_id")
|
||||
tenant_id = document.tenant_id
|
||||
index_type = document.doc_form
|
||||
|
||||
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# Get credentials from datasource provider
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
if not credential:
|
||||
logger.error(
|
||||
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
|
||||
document_id,
|
||||
tenant_id,
|
||||
credential_id,
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if document:
|
||||
if not credential:
|
||||
logger.error(
|
||||
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
|
||||
document_id,
|
||||
document.tenant_id,
|
||||
credential_id,
|
||||
)
|
||||
document.indexing_status = "error"
|
||||
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
|
||||
document.stopped_at = naive_utc_now()
|
||||
return
|
||||
session.commit()
|
||||
return
|
||||
|
||||
loader = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
loader = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=document.tenant_id,
|
||||
)
|
||||
|
||||
last_edited_time = loader.get_notion_last_edited_time()
|
||||
if last_edited_time == page_edited_time:
|
||||
logger.info(click.style(f"Document {document_id} content unchanged, skipping sync", fg="yellow"))
|
||||
return
|
||||
last_edited_time = loader.get_notion_last_edited_time()
|
||||
|
||||
logger.info(click.style(f"Document {document_id} content changed, starting sync", fg="green"))
|
||||
# check the page is updated
|
||||
if last_edited_time != page_edited_time:
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
session.commit()
|
||||
|
||||
try:
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset:
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception("Failed to clean vector index for document %s", document_id)
|
||||
# delete all document segment and index
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if not document:
|
||||
logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
|
||||
return
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
data_source_info = document.data_source_info_dict
|
||||
data_source_info["last_edited_time"] = last_edited_time
|
||||
document.data_source_info = data_source_info
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
|
||||
session.execute(segment_delete_stmt)
|
||||
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
session.execute(segment_delete_stmt)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Cleaned document when document update data source or process rule: {} latency: {}".format(
|
||||
document_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document update data source or process rule failed")
|
||||
|
||||
logger.info(click.style(f"Deleted segments for document {document_id}", fg="green"))
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if document:
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Sync completed for document {document_id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception as e:
|
||||
logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
|
||||
|
||||
@@ -81,35 +81,26 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
session.commit()
|
||||
return
|
||||
|
||||
# Phase 1: Update status to parsing (short transaction)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
documents = (
|
||||
session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
|
||||
)
|
||||
for document_id in document_ids:
|
||||
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
|
||||
|
||||
document = (
|
||||
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
|
||||
for document in documents:
|
||||
if document:
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
documents.append(document)
|
||||
session.add(document)
|
||||
# Transaction committed and closed
|
||||
session.commit()
|
||||
|
||||
# Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions)
|
||||
has_error = False
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
has_error = True
|
||||
except Exception:
|
||||
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
|
||||
has_error = True
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
|
||||
if not has_error:
|
||||
with session_factory.create_session() as session:
|
||||
# Trigger summary index generation for completed documents if enabled
|
||||
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
|
||||
# Re-query dataset to get latest summary_index_setting (in case it was updated)
|
||||
@@ -124,18 +115,17 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
# expire all session to get latest document's indexing status
|
||||
session.expire_all()
|
||||
# Check each document's indexing status and trigger summary generation if completed
|
||||
|
||||
documents = (
|
||||
session.query(Document)
|
||||
.where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
for document in documents:
|
||||
for document_id in document_ids:
|
||||
# Re-query document to get latest status (IndexingRunner may have updated it)
|
||||
document = (
|
||||
session.query(Document)
|
||||
.where(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
if document:
|
||||
logger.info(
|
||||
"Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
|
||||
document.id,
|
||||
document_id,
|
||||
document.indexing_status,
|
||||
document.doc_form,
|
||||
document.need_summary,
|
||||
@@ -146,36 +136,46 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
and document.need_summary is True
|
||||
):
|
||||
try:
|
||||
generate_summary_index_task.delay(dataset.id, document.id, None)
|
||||
generate_summary_index_task.delay(dataset.id, document_id, None)
|
||||
logger.info(
|
||||
"Queued summary index generation task for document %s in dataset %s "
|
||||
"after indexing completed",
|
||||
document.id,
|
||||
document_id,
|
||||
dataset.id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to queue summary index generation task for document %s",
|
||||
document.id,
|
||||
document_id,
|
||||
)
|
||||
# Don't fail the entire indexing process if summary task queuing fails
|
||||
else:
|
||||
logger.info(
|
||||
"Skipping summary generation for document %s: "
|
||||
"status=%s, doc_form=%s, need_summary=%s",
|
||||
document.id,
|
||||
document_id,
|
||||
document.indexing_status,
|
||||
document.doc_form,
|
||||
document.need_summary,
|
||||
)
|
||||
else:
|
||||
logger.warning("Document %s not found after indexing", document.id)
|
||||
logger.warning("Document %s not found after indexing", document_id)
|
||||
else:
|
||||
logger.info(
|
||||
"Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
|
||||
dataset.id,
|
||||
summary_index_setting.get("enable") if summary_index_setting else None,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
|
||||
dataset.id,
|
||||
dataset.indexing_technique,
|
||||
)
|
||||
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)

|
||||
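The hunks above split _document_indexing into two phases: a short transaction that only marks documents as "parsing", then the long-running IndexingRunner work with no transaction held open. A minimal, self-contained sketch of that pattern follows; the names (Doc, run_indexing) are illustrative stand-ins, not the repository's actual API.

from dataclasses import dataclass

@dataclass
class Doc:
    id: str
    indexing_status: str = "waiting"

def index_documents(docs, run_indexing):
    # Phase 1: short critical section - only flip status, then commit.
    for doc in docs:
        doc.indexing_status = "parsing"
    # (in the real task the DB transaction closes here)

    # Phase 2: the expensive work runs outside any transaction, so a slow
    # or failing run cannot pin row locks or hold a connection open.
    try:
        run_indexing(docs)
        for doc in docs:
            doc.indexing_status = "completed"
    except Exception:
        for doc in docs:
            doc.indexing_status = "error"

if __name__ == "__main__":
    batch = [Doc("a"), Doc("b")]
    index_documents(batch, run_indexing=lambda ds: None)
    print([d.indexing_status for d in batch])  # ['completed', 'completed']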
def _document_indexing_with_tenant_queue(
|
||||
|
||||
@@ -36,19 +36,25 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return
|
||||
# delete all document segment and index
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
|
||||
index_type = document.doc_form
|
||||
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
|
||||
session.execute(segment_delete_stmt)
|
||||
|
||||
clean_success = False
|
||||
try:
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
if index_node_ids:
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
@@ -58,21 +64,15 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
clean_success = True
|
||||
except Exception:
|
||||
logger.exception("Failed to clean document index during update, document_id: %s", document_id)
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document update data source or process rule failed")
|
||||
|
||||
if clean_success:
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
session.execute(segment_delete_stmt)
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
|
||||
|
||||
@@ -48,7 +48,6 @@ from models.workflow import (
|
||||
WorkflowArchiveLog,
|
||||
)
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.api_token_service import ApiTokenCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -135,12 +134,6 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_app_api_tokens(tenant_id: str, app_id: str):
|
||||
def del_api_token(session, api_token_id: str):
|
||||
# Fetch token details for cache invalidation
|
||||
token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
|
||||
if token_obj:
|
||||
# Invalidate cache before deletion
|
||||
ApiTokenCache.delete(token_obj.token, token_obj.type)
|
||||
|
||||
session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
|
||||
@@ -6,8 +6,9 @@ improving performance by offloading storage operations to background workers.
|
||||
"""
|
||||
|
||||
from celery import shared_task # type: ignore[import-untyped]
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
|
||||
|
||||
|
||||
@@ -16,6 +17,6 @@ def save_workflow_execution_task(
|
||||
self,
|
||||
deletions: list[DraftVarFileDeletion],
|
||||
):
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
srv = WorkflowDraftVariableService(session=session)
|
||||
srv.delete_workflow_draft_variable_file(deletions=deletions)
|
||||
|
||||
@@ -1,375 +0,0 @@
|
||||
"""
|
||||
Integration tests for API Token Cache with Redis.
|
||||
|
||||
These tests require:
|
||||
- Redis server running
|
||||
- Test database configured
|
||||
"""
|
||||
|
||||
import time
from datetime import datetime, timedelta
from unittest.mock import patch

import pytest

from extensions.ext_redis import redis_client
from models.model import ApiToken
from services.api_token_service import ApiTokenCache, CachedApiToken


class TestApiTokenCacheRedisIntegration:
"""Integration tests with real Redis."""
|
||||
def setup_method(self):
|
||||
"""Setup test fixtures and clean Redis."""
|
||||
self.test_token = "test-integration-token-123"
|
||||
self.test_scope = "app"
|
||||
self.cache_key = f"api_token:{self.test_scope}:{self.test_token}"
|
||||
|
||||
# Clean up any existing test data
|
||||
self._cleanup()
|
||||
|
||||
def teardown_method(self):
|
||||
"""Cleanup test data from Redis."""
|
||||
self._cleanup()
|
||||
|
||||
def _cleanup(self):
|
||||
"""Remove test data from Redis."""
|
||||
try:
|
||||
redis_client.delete(self.cache_key)
|
||||
redis_client.delete(ApiTokenCache._make_tenant_index_key("test-tenant-id"))
|
||||
redis_client.delete(ApiTokenCache.make_active_key(self.test_token, self.test_scope))
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
def test_cache_set_and_get_with_real_redis(self):
|
||||
"""Test cache set and get operations with real Redis."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_token = MagicMock()
|
||||
mock_token.id = "test-id-123"
|
||||
mock_token.app_id = "test-app-456"
|
||||
mock_token.tenant_id = "test-tenant-789"
|
||||
mock_token.type = "app"
|
||||
mock_token.token = self.test_token
|
||||
mock_token.last_used_at = datetime.now()
|
||||
mock_token.created_at = datetime.now() - timedelta(days=30)
|
||||
|
||||
# Set in cache
|
||||
result = ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
|
||||
assert result is True
|
||||
|
||||
# Verify in Redis
|
||||
cached_data = redis_client.get(self.cache_key)
|
||||
assert cached_data is not None
|
||||
|
||||
# Get from cache
|
||||
cached_token = ApiTokenCache.get(self.test_token, self.test_scope)
|
||||
assert cached_token is not None
|
||||
assert isinstance(cached_token, CachedApiToken)
|
||||
assert cached_token.id == "test-id-123"
|
||||
assert cached_token.app_id == "test-app-456"
|
||||
assert cached_token.tenant_id == "test-tenant-789"
|
||||
assert cached_token.type == "app"
|
||||
assert cached_token.token == self.test_token
|
||||
|
||||
def test_cache_ttl_with_real_redis(self):
|
||||
"""Test cache TTL is set correctly."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_token = MagicMock()
|
||||
mock_token.id = "test-id"
|
||||
mock_token.app_id = "test-app"
|
||||
mock_token.tenant_id = "test-tenant"
|
||||
mock_token.type = "app"
|
||||
mock_token.token = self.test_token
|
||||
mock_token.last_used_at = None
|
||||
mock_token.created_at = datetime.now()
|
||||
|
||||
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
|
||||
|
||||
ttl = redis_client.ttl(self.cache_key)
|
||||
assert 595 <= ttl <= 600 # Should be around 600 seconds (10 minutes)
|
||||
|
||||
def test_cache_null_value_for_invalid_token(self):
|
||||
"""Test caching null value for invalid tokens."""
|
||||
result = ApiTokenCache.set(self.test_token, self.test_scope, None)
|
||||
assert result is True
|
||||
|
||||
cached_data = redis_client.get(self.cache_key)
|
||||
assert cached_data == b"null"
|
||||
|
||||
cached_token = ApiTokenCache.get(self.test_token, self.test_scope)
|
||||
assert cached_token is None
|
||||
|
||||
ttl = redis_client.ttl(self.cache_key)
|
||||
assert 55 <= ttl <= 60
|
||||
|
||||
def test_cache_delete_with_real_redis(self):
|
||||
"""Test cache deletion with real Redis."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_token = MagicMock()
|
||||
mock_token.id = "test-id"
|
||||
mock_token.app_id = "test-app"
|
||||
mock_token.tenant_id = "test-tenant"
|
||||
mock_token.type = "app"
|
||||
mock_token.token = self.test_token
|
||||
mock_token.last_used_at = None
|
||||
mock_token.created_at = datetime.now()
|
||||
|
||||
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
|
||||
assert redis_client.exists(self.cache_key) == 1
|
||||
|
||||
result = ApiTokenCache.delete(self.test_token, self.test_scope)
|
||||
assert result is True
|
||||
assert redis_client.exists(self.cache_key) == 0
|
||||
|
||||
def test_tenant_index_creation(self):
|
||||
"""Test tenant index is created when caching token."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
tenant_id = "test-tenant-id"
|
||||
mock_token = MagicMock()
|
||||
mock_token.id = "test-id"
|
||||
mock_token.app_id = "test-app"
|
||||
mock_token.tenant_id = tenant_id
|
||||
mock_token.type = "app"
|
||||
mock_token.token = self.test_token
|
||||
mock_token.last_used_at = None
|
||||
mock_token.created_at = datetime.now()
|
||||
|
||||
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
|
||||
|
||||
index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
|
||||
assert redis_client.exists(index_key) == 1
|
||||
|
||||
members = redis_client.smembers(index_key)
|
||||
cache_keys = [m.decode("utf-8") if isinstance(m, bytes) else m for m in members]
|
||||
assert self.cache_key in cache_keys
|
||||
|
||||
def test_invalidate_by_tenant_via_index(self):
|
||||
"""Test tenant-wide cache invalidation using index (fast path)."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
tenant_id = "test-tenant-id"
|
||||
|
||||
for i in range(3):
|
||||
token_value = f"test-token-{i}"
|
||||
mock_token = MagicMock()
|
||||
mock_token.id = f"test-id-{i}"
|
||||
mock_token.app_id = "test-app"
|
||||
mock_token.tenant_id = tenant_id
|
||||
mock_token.type = "app"
|
||||
mock_token.token = token_value
|
||||
mock_token.last_used_at = None
|
||||
mock_token.created_at = datetime.now()
|
||||
|
||||
ApiTokenCache.set(token_value, "app", mock_token)
|
||||
|
||||
for i in range(3):
|
||||
key = f"api_token:app:test-token-{i}"
|
||||
assert redis_client.exists(key) == 1
|
||||
|
||||
result = ApiTokenCache.invalidate_by_tenant(tenant_id)
|
||||
assert result is True
|
||||
|
||||
for i in range(3):
|
||||
key = f"api_token:app:test-token-{i}"
|
||||
assert redis_client.exists(key) == 0
|
||||
|
||||
assert redis_client.exists(ApiTokenCache._make_tenant_index_key(tenant_id)) == 0
|
||||
|
||||
def test_concurrent_cache_access(self):
|
||||
"""Test concurrent cache access doesn't cause issues."""
|
||||
import concurrent.futures
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_token = MagicMock()
|
||||
mock_token.id = "test-id"
|
||||
mock_token.app_id = "test-app"
|
||||
mock_token.tenant_id = "test-tenant"
|
||||
mock_token.type = "app"
|
||||
mock_token.token = self.test_token
|
||||
mock_token.last_used_at = None
|
||||
mock_token.created_at = datetime.now()
|
||||
|
||||
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
|
||||
|
||||
def get_from_cache():
|
||||
return ApiTokenCache.get(self.test_token, self.test_scope)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(get_from_cache) for _ in range(50)]
|
||||
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
||||
|
||||
assert len(results) == 50
|
||||
assert all(r is not None for r in results)
|
||||
assert all(isinstance(r, CachedApiToken) for r in results)
|
||||
|
||||
|
||||
class TestTokenUsageRecording:
"""Tests for recording token usage in Redis (batch update approach)."""

def setup_method(self):
|
||||
self.test_token = "test-usage-token"
|
||||
self.test_scope = "app"
|
||||
self.active_key = ApiTokenCache.make_active_key(self.test_token, self.test_scope)
|
||||
|
||||
def teardown_method(self):
|
||||
try:
|
||||
redis_client.delete(self.active_key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_record_token_usage_sets_redis_key(self):
|
||||
"""Test that record_token_usage writes an active key to Redis."""
|
||||
from services.api_token_service import record_token_usage
|
||||
|
||||
record_token_usage(self.test_token, self.test_scope)
|
||||
|
||||
# Key should exist
|
||||
assert redis_client.exists(self.active_key) == 1
|
||||
|
||||
# Value should be an ISO timestamp
|
||||
value = redis_client.get(self.active_key)
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("utf-8")
|
||||
datetime.fromisoformat(value) # Should not raise
|
||||
|
||||
def test_record_token_usage_has_ttl(self):
|
||||
"""Test that active keys have a TTL as safety net."""
|
||||
from services.api_token_service import record_token_usage
|
||||
|
||||
record_token_usage(self.test_token, self.test_scope)
|
||||
|
||||
ttl = redis_client.ttl(self.active_key)
|
||||
assert 3595 <= ttl <= 3600 # ~1 hour
|
||||
|
||||
def test_record_token_usage_overwrites(self):
"""Test that repeated calls overwrite the same key (no accumulation)."""
from services.api_token_service import record_token_usage

record_token_usage(self.test_token, self.test_scope)
first_value = redis_client.get(self.active_key)

time.sleep(0.01) # Tiny delay so timestamp differs

record_token_usage(self.test_token, self.test_scope)
second_value = redis_client.get(self.active_key)

# Key count should still be 1 (overwritten, not accumulated)
assert redis_client.exists(self.active_key) == 1
|
||||
|
||||
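The three tests above pin down the write path of the batch-update approach: each call stores a single "active" key per (token, scope) whose value is an ISO timestamp, the key carries roughly a one-hour TTL as a safety net, and repeated calls overwrite rather than accumulate. A sketch consistent with those assertions; the key layout and TTL constant are assumptions, not necessarily the values used by ApiTokenCache.make_active_key.

from datetime import datetime, timezone

ACTIVE_KEY_TTL_SECONDS = 3600  # assumed ~1 hour, matching the TTL range asserted above

def record_token_usage_sketch(redis_client, token: str, scope: str) -> None:
    key = f"api_token_active:{scope}:{token}"  # assumed key layout
    now = datetime.now(timezone.utc).isoformat()
    # SET with EX writes the latest timestamp and refreshes the TTL in one call,
    # so a frequently used token always occupies exactly one key.
    redis_client.set(key, now, ex=ACTIVE_KEY_TTL_SECONDS)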
class TestEndToEndCacheFlow:
"""End-to-end integration test for complete cache flow."""

@pytest.mark.usefixtures("db_session")
def test_complete_flow_cache_miss_then_hit(self, db_session):
"""
Test complete flow:
1. First request (cache miss) -> query DB -> cache result
2. Second request (cache hit) -> return from cache
3. Verify Redis state
"""
test_token_value = "test-e2e-token"
|
||||
test_scope = "app"
|
||||
|
||||
test_token = ApiToken()
|
||||
test_token.id = "test-e2e-id"
|
||||
test_token.token = test_token_value
|
||||
test_token.type = test_scope
|
||||
test_token.app_id = "test-app"
|
||||
test_token.tenant_id = "test-tenant"
|
||||
test_token.last_used_at = None
|
||||
test_token.created_at = datetime.now()
|
||||
|
||||
db_session.add(test_token)
|
||||
db_session.commit()
|
||||
|
||||
try:
|
||||
# Step 1: Cache miss - set token in cache
|
||||
ApiTokenCache.set(test_token_value, test_scope, test_token)
|
||||
|
||||
cache_key = f"api_token:{test_scope}:{test_token_value}"
|
||||
assert redis_client.exists(cache_key) == 1
|
||||
|
||||
# Step 2: Cache hit - get from cache
|
||||
cached_token = ApiTokenCache.get(test_token_value, test_scope)
|
||||
assert cached_token is not None
|
||||
assert cached_token.id == test_token.id
|
||||
assert cached_token.token == test_token_value
|
||||
|
||||
# Step 3: Verify tenant index
|
||||
index_key = ApiTokenCache._make_tenant_index_key(test_token.tenant_id)
|
||||
assert redis_client.exists(index_key) == 1
|
||||
assert cache_key.encode() in redis_client.smembers(index_key)
|
||||
|
||||
# Step 4: Delete and verify cleanup
|
||||
ApiTokenCache.delete(test_token_value, test_scope)
|
||||
assert redis_client.exists(cache_key) == 0
|
||||
assert cache_key.encode() not in redis_client.smembers(index_key)
|
||||
|
||||
finally:
|
||||
db_session.delete(test_token)
|
||||
db_session.commit()
|
||||
redis_client.delete(f"api_token:{test_scope}:{test_token_value}")
|
||||
redis_client.delete(ApiTokenCache._make_tenant_index_key(test_token.tenant_id))
|
||||
|
||||
def test_high_concurrency_simulation(self):
|
||||
"""Simulate high concurrency access to cache."""
|
||||
import concurrent.futures
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
test_token_value = "test-concurrent-token"
|
||||
test_scope = "app"
|
||||
|
||||
mock_token = MagicMock()
|
||||
mock_token.id = "concurrent-id"
|
||||
mock_token.app_id = "test-app"
|
||||
mock_token.tenant_id = "test-tenant"
|
||||
mock_token.type = test_scope
|
||||
mock_token.token = test_token_value
|
||||
mock_token.last_used_at = datetime.now()
|
||||
mock_token.created_at = datetime.now()
|
||||
|
||||
ApiTokenCache.set(test_token_value, test_scope, mock_token)
|
||||
|
||||
try:
|
||||
|
||||
def read_cache():
|
||||
return ApiTokenCache.get(test_token_value, test_scope)
|
||||
|
||||
start_time = time.time()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
|
||||
futures = [executor.submit(read_cache) for _ in range(100)]
|
||||
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
assert len(results) == 100
|
||||
assert all(r is not None for r in results)
|
||||
|
||||
assert elapsed < 1.0, f"Too slow: {elapsed}s for 100 cache reads"
|
||||
|
||||
finally:
|
||||
ApiTokenCache.delete(test_token_value, test_scope)
|
||||
redis_client.delete(ApiTokenCache._make_tenant_index_key(mock_token.tenant_id))
|
||||
|
||||
|
||||
class TestRedisFailover:
"""Test behavior when Redis is unavailable."""

@patch("services.api_token_service.redis_client")
def test_graceful_degradation_when_redis_fails(self, mock_redis):
"""Test system degrades gracefully when Redis is unavailable."""
from redis import RedisError

mock_redis.get.side_effect = RedisError("Connection failed")
mock_redis.setex.side_effect = RedisError("Connection failed")

result_get = ApiTokenCache.get("test-token", "app")
assert result_get is None

result_set = ApiTokenCache.set("test-token", "app", None)
assert result_set is False
|
||||
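Taken together, this test module describes a cache-aside lookup with negative caching and graceful degradation: read Redis first, fall back to the database on a miss or a Redis error, and cache the answer (including "token not found") under a TTL. A compact sketch of that flow under those assumptions; the helper name, key format, and TTL values are illustrative, not the service's actual implementation.

import json

from redis import RedisError

VALID_TTL = 600  # assumed: ~10 minutes for real tokens
NULL_TTL = 60    # assumed: ~1 minute for "token does not exist"

def lookup_token_sketch(redis_client, db_lookup, token: str, scope: str):
    key = f"api_token:{scope}:{token}"
    try:
        cached = redis_client.get(key)
    except RedisError:
        cached = None  # degrade gracefully: treat a Redis failure as a miss
    if cached is not None:
        return None if cached in (b"null", "null") else json.loads(cached)

    record = db_lookup(token, scope)  # authoritative source on a miss
    try:
        if record is None:
            redis_client.setex(key, NULL_TTL, "null")  # negative caching
        else:
            redis_client.setex(key, VALID_TTL, json.dumps(record))
    except RedisError:
        pass  # caching is best-effort; still return the DB answer
    return record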
@@ -1,29 +0,0 @@
|
||||
"""
|
||||
Integration tests for KnowledgeRetrievalNode.
|
||||
|
||||
This module provides integration tests for KnowledgeRetrievalNode with real database interactions.
|
||||
|
||||
Note: These tests require database setup and are more complex than unit tests.
|
||||
For now, we focus on unit tests which provide better coverage for the node logic.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestKnowledgeRetrievalNodeIntegration:
|
||||
"""
|
||||
Integration test suite for KnowledgeRetrievalNode.
|
||||
|
||||
Note: Full integration tests require:
|
||||
- Database setup with datasets and documents
|
||||
- Vector store for embeddings
|
||||
- Model providers for retrieval
|
||||
|
||||
For now, unit tests provide comprehensive coverage of the node logic.
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason="Integration tests require full database and vector store setup")
|
||||
def test_end_to_end_knowledge_retrieval(self):
|
||||
"""Test end-to-end knowledge retrieval workflow."""
|
||||
# TODO: Implement with real database
|
||||
pass
|
||||
@@ -1,614 +0,0 @@
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest
|
||||
from models.dataset import Dataset, Document
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
|
||||
class TestGetAvailableDatasetsIntegration:
|
||||
def test_returns_datasets_with_available_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset
|
||||
dataset = Dataset(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
created_by=account.id,
|
||||
indexing_technique="high_quality",
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create documents with completed status, enabled, not archived
|
||||
for i in range(3):
|
||||
document = Document(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="upload_file",
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
name=f"Document {i}",
|
||||
created_from="web",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
enabled=True,
|
||||
archived=False,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0].id == dataset.id
|
||||
assert result[0].tenant_id == tenant.id
|
||||
assert result[0].name == dataset.name
|
||||
|
||||
def test_filters_out_datasets_with_only_archived_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
||||
# Create only archived documents
|
||||
for i in range(2):
|
||||
document = Document(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="upload_file",
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
name=f"Archived Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
indexing_status="completed",
|
||||
enabled=True,
|
||||
archived=True, # Archived
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
|
||||
|
||||
# Assert
|
||||
assert len(result) == 0
|
||||
|
||||
def test_filters_out_datasets_with_only_disabled_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
||||
# Create only disabled documents
|
||||
for i in range(2):
|
||||
document = Document(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="upload_file",
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
name=f"Disabled Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
indexing_status="completed",
|
||||
enabled=False, # Disabled
|
||||
archived=False,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
|
||||
|
||||
# Assert
|
||||
assert len(result) == 0
|
||||
|
||||
def test_filters_out_datasets_with_non_completed_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
||||
# Create documents with non-completed status
|
||||
for i, status in enumerate(["indexing", "parsing", "splitting"]):
|
||||
document = Document(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="upload_file",
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
name=f"Document {status}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
indexing_status=status, # Not completed
|
||||
enabled=True,
|
||||
archived=False,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
|
||||
|
||||
# Assert
|
||||
assert len(result) == 0
|
||||
|
||||
def test_includes_external_datasets_without_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that external datasets are returned even with no available documents.
|
||||
|
||||
External datasets (e.g., from external knowledge bases) don't have
|
||||
documents stored in Dify's database, so they should always be available.
|
||||
|
||||
Verifies:
|
||||
- External datasets are included in results
|
||||
- No document count check for external datasets
|
||||
"""
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="external", # External provider
|
||||
data_source_type="external",
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0].id == dataset.id
|
||||
assert result[0].provider == "external"
|
||||
|
||||
def test_filters_by_tenant_id(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
# Create two accounts/tenants
|
||||
account1 = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company())
|
||||
tenant1 = account1.current_tenant
|
||||
|
||||
account2 = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company())
|
||||
tenant2 = account2.current_tenant
|
||||
|
||||
# Create dataset for tenant1
|
||||
dataset1 = Dataset(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant1.id,
|
||||
name="Tenant 1 Dataset",
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
created_by=account1.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset1)
|
||||
|
||||
# Create dataset for tenant2
|
||||
dataset2 = Dataset(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant2.id,
|
||||
name="Tenant 2 Dataset",
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
created_by=account2.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset2)
|
||||
|
||||
# Add documents to both datasets
|
||||
for dataset, account in [(dataset1, account1), (dataset2, account2)]:
|
||||
document = Document(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="upload_file",
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
name=f"Document for {dataset.name}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
indexing_status="completed",
|
||||
enabled=True,
|
||||
archived=False,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act - request from tenant1, should only get tenant1's dataset
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
result = dataset_retrieval._get_available_datasets(tenant1.id, [dataset1.id, dataset2.id])
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0].id == dataset1.id
|
||||
assert result[0].tenant_id == tenant1.id
|
||||
|
||||
def test_returns_empty_list_when_no_datasets_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Don't create any datasets
|
||||
|
||||
# Act
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
result = dataset_retrieval._get_available_datasets(tenant.id, [str(uuid.uuid4())])
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_returns_only_requested_dataset_ids(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create multiple datasets
|
||||
datasets = []
|
||||
for i in range(3):
|
||||
dataset = Dataset(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
name=f"Dataset {i}",
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
datasets.append(dataset)
|
||||
|
||||
# Add document
|
||||
document = Document(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="upload_file",
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
name=f"Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
indexing_status="completed",
|
||||
enabled=True,
|
||||
archived=False,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act - request only dataset 0 and 2, not dataset 1
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
requested_ids = [datasets[0].id, datasets[2].id]
|
||||
result = dataset_retrieval._get_available_datasets(tenant.id, requested_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
returned_ids = {d.id for d in result}
|
||||
assert returned_ids == {datasets[0].id, datasets[2].id}
|
||||
|
||||
|
||||
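The class above collectively pins down when a dataset counts as available for retrieval: it must belong to the requesting tenant and be among the requested ids, and it needs at least one document that is enabled, not archived, and fully indexed, except for external-provider datasets, which keep no local documents and are always eligible. A plain-Python restatement of that rule as exercised by these tests; the real _get_available_datasets runs as a database query, so this is only a behavioural sketch.

def available_datasets_sketch(datasets, documents, tenant_id, dataset_ids):
    wanted = set(dataset_ids)
    available = []
    for ds in datasets:
        if ds.tenant_id != tenant_id or ds.id not in wanted:
            continue
        if ds.provider == "external":
            available.append(ds)  # external KBs carry no local documents
            continue
        has_ready_doc = any(
            doc.dataset_id == ds.id
            and doc.enabled
            and not doc.archived
            and doc.indexing_status == "completed"
            for doc in documents
        )
        if has_ready_doc:
            available.append(ds)
    return available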
class TestKnowledgeRetrievalIntegration:
|
||||
def test_knowledge_retrieval_with_available_datasets(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
created_by=account.id,
|
||||
indexing_technique="high_quality",
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
||||
document = Document(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="upload_file",
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
name=fake.sentence(),
|
||||
created_by=account.id,
|
||||
indexing_status="completed",
|
||||
enabled=True,
|
||||
archived=False,
|
||||
doc_form="text_model",
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create request
|
||||
request = KnowledgeRetrievalRequest(
|
||||
tenant_id=tenant.id,
|
||||
user_id=account.id,
|
||||
app_id=str(uuid.uuid4()),
|
||||
user_from="web",
|
||||
dataset_ids=[dataset.id],
|
||||
query="test query",
|
||||
retrieval_mode="multiple",
|
||||
top_k=5,
|
||||
)
|
||||
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Mock rate limit check and retrieval
|
||||
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
|
||||
with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
|
||||
with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]):
|
||||
# Act
|
||||
result = dataset_retrieval.knowledge_retrieval(request)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_knowledge_retrieval_no_available_datasets(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset but no documents
|
||||
dataset = Dataset(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
request = KnowledgeRetrievalRequest(
|
||||
tenant_id=tenant.id,
|
||||
user_id=account.id,
|
||||
app_id=str(uuid.uuid4()),
|
||||
user_from="web",
|
||||
dataset_ids=[dataset.id],
|
||||
query="test query",
|
||||
retrieval_mode="multiple",
|
||||
top_k=5,
|
||||
)
|
||||
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Mock rate limit check
|
||||
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
|
||||
# Act
|
||||
result = dataset_retrieval.knowledge_retrieval(request)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_knowledge_retrieval_rate_limit_exceeded(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
request = KnowledgeRetrievalRequest(
|
||||
tenant_id=tenant.id,
|
||||
user_id=account.id,
|
||||
app_id=str(uuid.uuid4()),
|
||||
user_from="web",
|
||||
dataset_ids=[dataset.id],
|
||||
query="test query",
|
||||
retrieval_mode="multiple",
|
||||
top_k=5,
|
||||
)
|
||||
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Mock rate limit check to raise exception
|
||||
with patch.object(
|
||||
dataset_retrieval,
|
||||
"_check_knowledge_rate_limit",
|
||||
side_effect=Exception("Rate limit exceeded"),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Rate limit exceeded"):
|
||||
dataset_retrieval.knowledge_retrieval(request)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies():
|
||||
with (
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
):
|
||||
# Setup default mock returns for account service
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
yield {
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
@@ -605,20 +605,26 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
|
||||
mock_storage.download.side_effect = mock_download
|
||||
|
||||
# Execute the task - should raise ValueError for empty CSV
|
||||
# Execute the task
|
||||
job_id = str(uuid.uuid4())
|
||||
with pytest.raises(ValueError, match="The CSV file is empty"):
|
||||
batch_create_segment_to_index_task(
|
||||
job_id=job_id,
|
||||
upload_file_id=upload_file.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
tenant_id=tenant.id,
|
||||
user_id=account.id,
|
||||
)
|
||||
batch_create_segment_to_index_task(
|
||||
job_id=job_id,
|
||||
upload_file_id=upload_file.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
tenant_id=tenant.id,
|
||||
user_id=account.id,
|
||||
)
|
||||
|
||||
# Verify error handling
|
||||
# Since exception was raised, no segments should be created
|
||||
# Check Redis cache was set to error status
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
cache_key = f"segment_batch_import_{job_id}"
|
||||
cache_value = redis_client.get(cache_key)
|
||||
assert cache_value == b"error"
|
||||
|
||||
# Verify no segments were created
|
||||
from extensions.ext_database import db
|
||||
|
||||
segments = db.session.query(DocumentSegment).all()
|
||||
|
||||
@@ -153,7 +153,8 @@ class TestCleanNotionDocumentTask:
|
||||
# Execute cleanup task
|
||||
clean_notion_document_task(document_ids, dataset.id)
|
||||
|
||||
# Verify segments are deleted
|
||||
# Verify documents and segments are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 0
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment)
|
||||
.filter(DocumentSegment.document_id.in_(document_ids))
|
||||
@@ -161,9 +162,9 @@ class TestCleanNotionDocumentTask:
|
||||
== 0
|
||||
)
|
||||
|
||||
# Verify index processor was called
|
||||
# Verify index processor was called for each document
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_processor.clean.assert_called_once()
|
||||
assert mock_processor.clean.call_count == len(document_ids)
|
||||
|
||||
# This test successfully verifies:
|
||||
# 1. Document records are properly deleted from the database
|
||||
@@ -185,12 +186,12 @@ class TestCleanNotionDocumentTask:
|
||||
non_existent_dataset_id = str(uuid.uuid4())
|
||||
document_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
|
||||
|
||||
# Execute cleanup task with non-existent dataset - expect exception
|
||||
with pytest.raises(Exception, match="Document has no dataset"):
|
||||
clean_notion_document_task(document_ids, non_existent_dataset_id)
|
||||
# Execute cleanup task with non-existent dataset
|
||||
clean_notion_document_task(document_ids, non_existent_dataset_id)
|
||||
|
||||
# Verify that the index processor factory was not used
|
||||
mock_index_processor_factory.return_value.init_index_processor.assert_not_called()
|
||||
# Verify that the index processor was not called
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_processor.clean.assert_not_called()
|
||||
|
||||
def test_clean_notion_document_task_empty_document_list(
|
||||
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
|
||||
@@ -228,13 +229,9 @@ class TestCleanNotionDocumentTask:
|
||||
# Execute cleanup task with empty document list
|
||||
clean_notion_document_task([], dataset.id)
|
||||
|
||||
# Verify that the index processor was called once with empty node list
|
||||
# Verify that the index processor was not called
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
assert mock_processor.clean.call_count == 1
|
||||
args, kwargs = mock_processor.clean.call_args
|
||||
# args: (dataset, total_index_node_ids)
|
||||
assert isinstance(args[0], Dataset)
|
||||
assert args[1] == []
|
||||
mock_processor.clean.assert_not_called()
|
||||
|
||||
def test_clean_notion_document_task_with_different_index_types(
|
||||
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
|
||||
@@ -318,7 +315,8 @@ class TestCleanNotionDocumentTask:
|
||||
# Note: This test successfully verifies cleanup with different document types.
|
||||
# The task properly handles various index types and document configurations.
|
||||
|
||||
# Verify segments are deleted
|
||||
# Verify documents and segments are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment)
|
||||
.filter(DocumentSegment.document_id == document.id)
|
||||
@@ -406,7 +404,8 @@ class TestCleanNotionDocumentTask:
|
||||
# Execute cleanup task
|
||||
clean_notion_document_task([document.id], dataset.id)
|
||||
|
||||
# Verify segments are deleted
|
||||
# Verify documents and segments are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
|
||||
== 0
|
||||
@@ -509,7 +508,8 @@ class TestCleanNotionDocumentTask:
|
||||
|
||||
clean_notion_document_task(documents_to_clean, dataset.id)
|
||||
|
||||
# Verify only specified documents' segments are deleted
|
||||
# Verify only specified documents and segments are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.id.in_(documents_to_clean)).count() == 0
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment)
|
||||
.filter(DocumentSegment.document_id.in_(documents_to_clean))
|
||||
@@ -697,12 +697,11 @@ class TestCleanNotionDocumentTask:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock index processor to raise an exception
|
||||
mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_index_processor = mock_index_processor_factory.init_index_processor.return_value
|
||||
mock_index_processor.clean.side_effect = Exception("Index processor error")
|
||||
|
||||
# Execute cleanup task - current implementation propagates the exception
|
||||
with pytest.raises(Exception, match="Index processor error"):
|
||||
clean_notion_document_task([document.id], dataset.id)
|
||||
# Execute cleanup task - it should handle the exception gracefully
|
||||
clean_notion_document_task([document.id], dataset.id)
|
||||
|
||||
# Note: This test demonstrates the task's error handling capability.
|
||||
# Even with external service errors, the database operations complete successfully.
|
||||
@@ -804,7 +803,8 @@ class TestCleanNotionDocumentTask:
|
||||
all_document_ids = [doc.id for doc in documents]
|
||||
clean_notion_document_task(all_document_ids, dataset.id)
|
||||
|
||||
# Verify all segments are deleted
|
||||
# Verify all documents and segments are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
|
||||
== 0
|
||||
@@ -914,7 +914,8 @@ class TestCleanNotionDocumentTask:
|
||||
|
||||
clean_notion_document_task([target_document.id], target_dataset.id)
|
||||
|
||||
# Verify only documents' segments from target dataset are deleted
|
||||
# Verify only documents from target dataset are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.id == target_document.id).count() == 0
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment)
|
||||
.filter(DocumentSegment.document_id == target_document.id)
|
||||
@@ -1029,7 +1030,8 @@ class TestCleanNotionDocumentTask:
|
||||
all_document_ids = [doc.id for doc in documents]
|
||||
clean_notion_document_task(all_document_ids, dataset.id)
|
||||
|
||||
# Verify all segments are deleted regardless of status
|
||||
# Verify all documents and segments are deleted regardless of status
|
||||
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
|
||||
== 0
|
||||
@@ -1140,7 +1142,8 @@ class TestCleanNotionDocumentTask:
|
||||
# Execute cleanup task
|
||||
clean_notion_document_task([document.id], dataset.id)
|
||||
|
||||
# Verify segments are deleted
|
||||
# Verify documents and segments are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
|
||||
== 0
|
||||
|
||||
@@ -9,7 +9,6 @@ from flask import Flask
|
||||
|
||||
from controllers.console import wraps as console_wraps
|
||||
from controllers.console.app import workflow_run as workflow_run_module
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
@@ -54,7 +53,6 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte
|
||||
monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com")
|
||||
|
||||
workflow_run = Mock(spec=WorkflowRun)
|
||||
workflow_run.tenant_id = "tenant-123"
|
||||
workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0)
|
||||
fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run))
|
||||
@@ -91,20 +89,3 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte
|
||||
== "https://web.example.com/form/backstage-token"
|
||||
)
|
||||
assert "pending_human_inputs" not in response
|
||||
|
||||
|
||||
def test_pause_details_tenant_isolation(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = _make_account()
|
||||
_patch_console_guards(monkeypatch, account)
|
||||
monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com")
|
||||
|
||||
workflow_run = Mock(spec=WorkflowRun)
|
||||
workflow_run.tenant_id = "tenant-456"
|
||||
workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0)
|
||||
fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run))
|
||||
monkeypatch.setattr(workflow_run_module, "db", fake_db)
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"):
|
||||
response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1")
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
"""
|
||||
Unit tests for Service API knowledge pipeline file-upload serialization.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class FakeUploadFile:
|
||||
id: str
|
||||
name: str
|
||||
size: int
|
||||
extension: str
|
||||
mime_type: str
|
||||
created_by: str
|
||||
created_at: datetime | None
|
||||
|
||||
|
||||
def _load_serialize_upload_file():
|
||||
api_dir = Path(__file__).resolve().parents[5]
|
||||
serializers_path = api_dir / "controllers" / "service_api" / "dataset" / "rag_pipeline" / "serializers.py"
|
||||
|
||||
spec = importlib.util.spec_from_file_location("rag_pipeline_serializers", serializers_path)
|
||||
assert spec
|
||||
assert spec.loader
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore[attr-defined]
|
||||
return module.serialize_upload_file
|
||||
|
||||
|
||||
def test_file_upload_created_at_is_isoformat_string():
|
||||
serialize_upload_file = _load_serialize_upload_file()
|
||||
|
||||
created_at = datetime(2026, 2, 8, 12, 0, 0, tzinfo=UTC)
|
||||
upload_file = FakeUploadFile()
|
||||
upload_file.id = "file-1"
|
||||
upload_file.name = "test.pdf"
|
||||
upload_file.size = 123
|
||||
upload_file.extension = "pdf"
|
||||
upload_file.mime_type = "application/pdf"
|
||||
upload_file.created_by = "account-1"
|
||||
upload_file.created_at = created_at
|
||||
|
||||
result = serialize_upload_file(upload_file)
|
||||
assert result["created_at"] == created_at.isoformat()
|
||||
|
||||
|
||||
def test_file_upload_created_at_none_serializes_to_null():
|
||||
serialize_upload_file = _load_serialize_upload_file()
|
||||
|
||||
upload_file = FakeUploadFile()
|
||||
upload_file.id = "file-1"
|
||||
upload_file.name = "test.pdf"
|
||||
upload_file.size = 123
|
||||
upload_file.extension = "pdf"
|
||||
upload_file.mime_type = "application/pdf"
|
||||
upload_file.created_by = "account-1"
|
||||
upload_file.created_at = None
|
||||
|
||||
result = serialize_upload_file(upload_file)
|
||||
assert result["created_at"] is None
|
||||
@@ -1,54 +0,0 @@
|
||||
"""
|
||||
Unit tests for Service API knowledge pipeline route registration.
|
||||
"""
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_rag_pipeline_routes_registered():
|
||||
api_dir = Path(__file__).resolve().parents[5]
|
||||
|
||||
service_api_init = api_dir / "controllers" / "service_api" / "__init__.py"
|
||||
rag_pipeline_workflow = (
|
||||
api_dir / "controllers" / "service_api" / "dataset" / "rag_pipeline" / "rag_pipeline_workflow.py"
|
||||
)
|
||||
|
||||
assert service_api_init.exists()
|
||||
assert rag_pipeline_workflow.exists()
|
||||
|
||||
init_tree = ast.parse(service_api_init.read_text(encoding="utf-8"))
|
||||
import_found = False
|
||||
for node in ast.walk(init_tree):
|
||||
if not isinstance(node, ast.ImportFrom):
|
||||
continue
|
||||
if node.module != "dataset.rag_pipeline" or node.level != 1:
|
||||
continue
|
||||
if any(alias.name == "rag_pipeline_workflow" for alias in node.names):
|
||||
import_found = True
|
||||
break
|
||||
assert import_found, "from .dataset.rag_pipeline import rag_pipeline_workflow not found in service_api/__init__.py"
|
||||
|
||||
workflow_tree = ast.parse(rag_pipeline_workflow.read_text(encoding="utf-8"))
|
||||
route_paths: set[str] = set()
|
||||
|
||||
for node in ast.walk(workflow_tree):
|
||||
if not isinstance(node, ast.ClassDef):
|
||||
continue
|
||||
for decorator in node.decorator_list:
|
||||
if not isinstance(decorator, ast.Call):
|
||||
continue
|
||||
if not isinstance(decorator.func, ast.Attribute):
|
||||
continue
|
||||
if decorator.func.attr != "route":
|
||||
continue
|
||||
if not decorator.args:
|
||||
continue
|
||||
first_arg = decorator.args[0]
|
||||
if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str):
|
||||
route_paths.add(first_arg.value)
|
||||
|
||||
assert "/datasets/<uuid:dataset_id>/pipeline/datasource-plugins" in route_paths
|
||||
assert "/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run" in route_paths
|
||||
assert "/datasets/<uuid:dataset_id>/pipeline/run" in route_paths
|
||||
assert "/datasets/pipeline/file-upload" in route_paths
|
||||
@@ -1,61 +0,0 @@
from datetime import UTC, datetime
from unittest.mock import Mock
from uuid import UUID, uuid4

import pytest

from controllers.service_api.end_user.end_user import EndUserApi
from controllers.service_api.end_user.error import EndUserNotFoundError
from models.model import App, EndUser


class TestEndUserApi:
    @pytest.fixture
    def resource(self) -> EndUserApi:
        return EndUserApi()

    @pytest.fixture
    def app_model(self) -> App:
        app = Mock(spec=App)
        app.id = str(uuid4())
        app.tenant_id = str(uuid4())
        return app

    def test_get_end_user_returns_all_attributes(self, mocker, resource: EndUserApi, app_model: App) -> None:
        end_user = Mock(spec=EndUser)
        end_user.id = str(uuid4())
        end_user.tenant_id = app_model.tenant_id
        end_user.app_id = app_model.id
        end_user.type = "service_api"
        end_user.external_user_id = "external-123"
        end_user.name = "Alice"
        end_user._is_anonymous = True
        end_user.session_id = "session-xyz"
        end_user.created_at = datetime(2024, 1, 1, tzinfo=UTC)
        end_user.updated_at = datetime(2024, 1, 2, tzinfo=UTC)

        get_end_user_by_id = mocker.patch(
            "controllers.service_api.end_user.end_user.EndUserService.get_end_user_by_id", return_value=end_user
        )

        result = EndUserApi.get.__wrapped__(resource, app_model=app_model, end_user_id=UUID(end_user.id))

        get_end_user_by_id.assert_called_once_with(
            tenant_id=app_model.tenant_id, app_id=app_model.id, end_user_id=end_user.id
        )
        assert result["id"] == end_user.id
        assert result["tenant_id"] == end_user.tenant_id
        assert result["app_id"] == end_user.app_id
        assert result["type"] == end_user.type
        assert result["external_user_id"] == end_user.external_user_id
        assert result["name"] == end_user.name
        assert result["is_anonymous"] is True
        assert result["session_id"] == end_user.session_id
        assert result["created_at"].startswith("2024-01-01T00:00:00")
        assert result["updated_at"].startswith("2024-01-02T00:00:00")

    def test_get_end_user_not_found(self, mocker, resource: EndUserApi, app_model: App) -> None:
        mocker.patch("controllers.service_api.end_user.end_user.EndUserService.get_end_user_by_id", return_value=None)

        with pytest.raises(EndUserNotFoundError):
            EndUserApi.get.__wrapped__(resource, app_model=app_model, end_user_id=uuid4())
@@ -1,715 +0,0 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.nodes.knowledge_retrieval import exc
|
||||
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest
|
||||
from models.dataset import Dataset
|
||||
|
||||
# ==================== Helper Functions ====================
|
||||
|
||||
|
||||
def create_mock_dataset(
|
||||
dataset_id: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
provider: str = "dify",
|
||||
indexing_technique: str = "high_quality",
|
||||
available_document_count: int = 10,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock Dataset object for testing.
|
||||
|
||||
Args:
|
||||
dataset_id: Unique identifier for the dataset
|
||||
tenant_id: Tenant ID for the dataset
|
||||
provider: Provider type ("dify" or "external")
|
||||
indexing_technique: Indexing technique ("high_quality" or "economy")
|
||||
available_document_count: Number of available documents
|
||||
|
||||
Returns:
|
||||
Mock: A properly configured Dataset mock
|
||||
"""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id or str(uuid4())
|
||||
dataset.tenant_id = tenant_id or str(uuid4())
|
||||
dataset.name = "test_dataset"
|
||||
dataset.provider = provider
|
||||
dataset.indexing_technique = indexing_technique
|
||||
dataset.available_document_count = available_document_count
|
||||
dataset.embedding_model = "text-embedding-ada-002"
|
||||
dataset.embedding_model_provider = "openai"
|
||||
dataset.retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
"reranking_enable": False,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
return dataset
|
||||
|
||||
|
||||
def create_mock_document(
|
||||
content: str,
|
||||
doc_id: str,
|
||||
score: float = 0.8,
|
||||
provider: str = "dify",
|
||||
additional_metadata: dict | None = None,
|
||||
) -> Document:
|
||||
"""
|
||||
Create a mock Document object for testing.
|
||||
|
||||
Args:
|
||||
content: The text content of the document
|
||||
doc_id: Unique identifier for the document chunk
|
||||
score: Relevance score (0.0 to 1.0)
|
||||
provider: Document provider ("dify" or "external")
|
||||
additional_metadata: Optional extra metadata fields
|
||||
|
||||
Returns:
|
||||
Document: A properly structured Document object
|
||||
"""
|
||||
metadata = {
|
||||
"doc_id": doc_id,
|
||||
"document_id": str(uuid4()),
|
||||
"dataset_id": str(uuid4()),
|
||||
"score": score,
|
||||
}
|
||||
|
||||
if additional_metadata:
|
||||
metadata.update(additional_metadata)
|
||||
|
||||
return Document(
|
||||
page_content=content,
|
||||
metadata=metadata,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
|
||||
# ==================== Test _check_knowledge_rate_limit ====================
|
||||
|
||||
|
||||
class TestCheckKnowledgeRateLimit:
|
||||
"""
|
||||
Test suite for _check_knowledge_rate_limit method.
|
||||
|
||||
The _check_knowledge_rate_limit method validates whether a tenant has
|
||||
exceeded their knowledge retrieval rate limit. This is important for:
|
||||
- Preventing abuse of the knowledge retrieval system
|
||||
- Enforcing subscription plan limits
|
||||
- Tracking usage for billing purposes
|
||||
|
||||
Test Cases:
|
||||
============
|
||||
1. Rate limit disabled - no exception raised
|
||||
2. Rate limit enabled but not exceeded - no exception raised
|
||||
3. Rate limit enabled and exceeded - RateLimitExceededError raised
|
||||
4. Redis operations are performed correctly
|
||||
5. RateLimitLog is created when limit is exceeded
|
||||
"""
|
||||
|
||||
@patch("core.rag.retrieval.dataset_retrieval.FeatureService")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.redis_client")
|
||||
def test_rate_limit_disabled_no_exception(self, mock_redis, mock_feature_service):
|
||||
"""
|
||||
Test that when rate limit is disabled, no exception is raised.
|
||||
|
||||
This test verifies the behavior when the tenant's subscription
|
||||
does not have rate limiting enabled.
|
||||
|
||||
Verifies:
|
||||
- FeatureService.get_knowledge_rate_limit is called
|
||||
- No Redis operations are performed
|
||||
- No exception is raised
|
||||
- Retrieval proceeds normally
|
||||
"""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Mock rate limit disabled
|
||||
mock_limit = Mock()
|
||||
mock_limit.enabled = False
|
||||
mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit
|
||||
|
||||
# Act & Assert - should not raise any exception
|
||||
dataset_retrieval._check_knowledge_rate_limit(tenant_id)
|
||||
|
||||
# Verify FeatureService was called
|
||||
mock_feature_service.get_knowledge_rate_limit.assert_called_once_with(tenant_id)
|
||||
|
||||
# Verify no Redis operations were performed
|
||||
assert not mock_redis.zadd.called
|
||||
assert not mock_redis.zremrangebyscore.called
|
||||
assert not mock_redis.zcard.called
|
||||
|
||||
@patch("core.rag.retrieval.dataset_retrieval.session_factory")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.FeatureService")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.redis_client")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.time")
|
||||
def test_rate_limit_enabled_not_exceeded(self, mock_time, mock_redis, mock_feature_service, mock_session_factory):
|
||||
"""
|
||||
Test that when rate limit is enabled but not exceeded, no exception is raised.
|
||||
|
||||
This test simulates a tenant making requests within their rate limit.
|
||||
The Redis sorted set stores timestamps of recent requests, and old
|
||||
requests (older than 60 seconds) are removed.
|
||||
|
||||
Verifies:
|
||||
- Redis zadd is called to track the request
|
||||
- Redis zremrangebyscore removes old entries
|
||||
- Redis zcard returns count within limit
|
||||
- No exception is raised
|
||||
"""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Mock rate limit enabled with limit of 100 requests per minute
|
||||
mock_limit = Mock()
|
||||
mock_limit.enabled = True
|
||||
mock_limit.limit = 100
|
||||
mock_limit.subscription_plan = "professional"
|
||||
mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit
|
||||
|
||||
# Mock time
|
||||
current_time = 1234567890000 # Current time in milliseconds
|
||||
mock_time.time.return_value = current_time / 1000 # Return seconds
|
||||
mock_time.time.__mul__ = lambda self, x: int(self * x) # Multiply to get milliseconds
|
||||
|
||||
# Mock Redis operations
|
||||
# zcard returns 50 (within limit of 100)
|
||||
mock_redis.zcard.return_value = 50
|
||||
|
||||
# Mock session_factory.create_session
|
||||
mock_session = MagicMock()
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
mock_session_factory.create_session.return_value.__exit__.return_value = None
|
||||
|
||||
# Act & Assert - should not raise any exception
|
||||
dataset_retrieval._check_knowledge_rate_limit(tenant_id)
|
||||
|
||||
# Verify Redis operations
|
||||
expected_key = f"rate_limit_{tenant_id}"
|
||||
mock_redis.zadd.assert_called_once_with(expected_key, {current_time: current_time})
|
||||
mock_redis.zremrangebyscore.assert_called_once_with(expected_key, 0, current_time - 60000)
|
||||
mock_redis.zcard.assert_called_once_with(expected_key)
|
||||
|
||||
@patch("core.rag.retrieval.dataset_retrieval.session_factory")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.FeatureService")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.redis_client")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.time")
|
||||
def test_rate_limit_enabled_exceeded_raises_exception(
|
||||
self, mock_time, mock_redis, mock_feature_service, mock_session_factory
|
||||
):
|
||||
"""
|
||||
Test that when rate limit is enabled and exceeded, RateLimitExceededError is raised.
|
||||
|
||||
This test simulates a tenant exceeding their rate limit. When the count
|
||||
of recent requests exceeds the limit, an exception should be raised and
|
||||
a RateLimitLog should be created.
|
||||
|
||||
Verifies:
|
||||
- Redis zcard returns count exceeding limit
|
||||
- RateLimitExceededError is raised with correct message
|
||||
- RateLimitLog is created in database
|
||||
- Session operations are performed correctly
|
||||
"""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Mock rate limit enabled with limit of 100 requests per minute
|
||||
mock_limit = Mock()
|
||||
mock_limit.enabled = True
|
||||
mock_limit.limit = 100
|
||||
mock_limit.subscription_plan = "professional"
|
||||
mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit
|
||||
|
||||
# Mock time
|
||||
current_time = 1234567890000
|
||||
mock_time.time.return_value = current_time / 1000
|
||||
|
||||
# Mock Redis operations - return count exceeding limit
|
||||
mock_redis.zcard.return_value = 150 # Exceeds limit of 100
|
||||
|
||||
# Mock session_factory.create_session
|
||||
mock_session = MagicMock()
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
mock_session_factory.create_session.return_value.__exit__.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(exc.RateLimitExceededError) as exc_info:
|
||||
dataset_retrieval._check_knowledge_rate_limit(tenant_id)
|
||||
|
||||
# Verify exception message
|
||||
assert "knowledge base request rate limit" in str(exc_info.value)
|
||||
|
||||
# Verify RateLimitLog was created
|
||||
mock_session.add.assert_called_once()
|
||||
added_log = mock_session.add.call_args[0][0]
|
||||
assert added_log.tenant_id == tenant_id
|
||||
assert added_log.subscription_plan == "professional"
|
||||
assert added_log.operation == "knowledge"
|
||||
|
||||
|
||||
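# Note (context, not part of the diff): a minimal sketch of the sliding-window check
# the mocks above exercise, reconstructed from the patched redis_client, FeatureService
# and session_factory calls; the real _check_knowledge_rate_limit in dataset_retrieval.py
# may differ in detail.
import time

def check_knowledge_rate_limit(tenant_id: str) -> None:
    limit_info = FeatureService.get_knowledge_rate_limit(tenant_id)
    if not limit_info.enabled:
        return
    key = f"rate_limit_{tenant_id}"
    now_ms = int(time.time() * 1000)
    redis_client.zadd(key, {now_ms: now_ms})                # record this request
    redis_client.zremrangebyscore(key, 0, now_ms - 60000)   # drop entries older than 60s
    if redis_client.zcard(key) > limit_info.limit:
        with session_factory.create_session() as session:
            session.add(RateLimitLog(
                tenant_id=tenant_id,
                subscription_plan=limit_info.subscription_plan,
                operation="knowledge",
            ))
            session.commit()
        raise RateLimitExceededError("knowledge base request rate limit exceeded")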
# ==================== Test _get_available_datasets ====================
|
||||
|
||||
|
||||
class TestGetAvailableDatasets:
|
||||
"""
|
||||
Test suite for _get_available_datasets method.
|
||||
|
||||
The _get_available_datasets method retrieves datasets that are available
|
||||
for retrieval. A dataset is considered available if:
|
||||
- It belongs to the specified tenant
|
||||
- It's in the list of requested dataset_ids
|
||||
- It has at least one completed, enabled, non-archived document OR
|
||||
- It's an external provider dataset
|
||||
|
||||
Note: Due to SQLAlchemy subquery complexity, full testing is done in
|
||||
integration tests. Unit tests here verify basic behavior.
|
||||
"""
|
||||
|
||||
def test_method_exists_and_has_correct_signature(self):
|
||||
"""
|
||||
Test that the method exists and has the correct signature.
|
||||
|
||||
Verifies:
|
||||
- Method exists on DatasetRetrieval class
|
||||
- Accepts tenant_id and dataset_ids parameters
|
||||
"""
|
||||
# Arrange
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Assert - method exists
|
||||
assert hasattr(dataset_retrieval, "_get_available_datasets")
|
||||
# Assert - method is callable
|
||||
assert callable(dataset_retrieval._get_available_datasets)
|
||||
|
||||
|
||||
# ==================== Test knowledge_retrieval ====================
|
||||
|
||||
|
||||
class TestDatasetRetrievalKnowledgeRetrieval:
|
||||
"""
|
||||
Test suite for knowledge_retrieval method.
|
||||
|
||||
The knowledge_retrieval method is the main entry point for retrieving
|
||||
knowledge from datasets. It orchestrates the entire retrieval process:
|
||||
1. Checks rate limits
|
||||
2. Gets available datasets
|
||||
3. Applies metadata filtering if enabled
|
||||
4. Performs retrieval (single or multiple mode)
|
||||
5. Formats and returns results
|
||||
|
||||
Test Cases:
|
||||
============
|
||||
1. Single mode retrieval
|
||||
2. Multiple mode retrieval
|
||||
3. Metadata filtering disabled
|
||||
4. Metadata filtering automatic
|
||||
5. Metadata filtering manual
|
||||
6. External documents handling
|
||||
7. Dify documents handling
|
||||
8. Empty results handling
|
||||
9. Rate limit exceeded
|
||||
10. No available datasets
|
||||
"""
|
||||
|
||||
def test_knowledge_retrieval_single_mode_basic(self):
|
||||
"""
|
||||
Test knowledge_retrieval in single retrieval mode - basic check.
|
||||
|
||||
Note: Full single mode testing requires complex model mocking and
|
||||
is better suited for integration tests. This test verifies the
|
||||
method accepts single mode requests.
|
||||
|
||||
Verifies:
|
||||
- Method can accept single mode request
|
||||
- Request parameters are correctly structured
|
||||
"""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
app_id = str(uuid4())
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
request = KnowledgeRetrievalRequest(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
user_from="web",
|
||||
dataset_ids=[dataset_id],
|
||||
query="What is Python?",
|
||||
retrieval_mode="single",
|
||||
model_provider="openai",
|
||||
model_name="gpt-4",
|
||||
model_mode="chat",
|
||||
completion_params={"temperature": 0.7},
|
||||
)
|
||||
|
||||
# Assert - request is properly structured
|
||||
assert request.retrieval_mode == "single"
|
||||
assert request.model_provider == "openai"
|
||||
assert request.model_name == "gpt-4"
|
||||
assert request.model_mode == "chat"
|
||||
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.session_factory")
|
||||
def test_knowledge_retrieval_multiple_mode(self, mock_session_factory, mock_data_processor):
|
||||
"""
|
||||
Test knowledge_retrieval in multiple retrieval mode.
|
||||
|
||||
In multiple mode, retrieval is performed across all datasets and
|
||||
results are combined and reranked.
|
||||
|
||||
Verifies:
|
||||
- Rate limit is checked
|
||||
- Available datasets are retrieved
|
||||
- Multiple retrieval is performed
|
||||
- Results are combined and reranked
|
||||
- Results are formatted correctly
|
||||
"""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
app_id = str(uuid4())
|
||||
dataset_id1 = str(uuid4())
|
||||
dataset_id2 = str(uuid4())
|
||||
|
||||
request = KnowledgeRetrievalRequest(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
user_from="web",
|
||||
dataset_ids=[dataset_id1, dataset_id2],
|
||||
query="What is Python?",
|
||||
retrieval_mode="multiple",
|
||||
top_k=5,
|
||||
score_threshold=0.7,
|
||||
reranking_enable=True,
|
||||
reranking_mode="reranking_model",
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||
)
|
||||
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Mock _check_knowledge_rate_limit
|
||||
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
|
||||
# Mock _get_available_datasets
|
||||
mock_dataset1 = create_mock_dataset(dataset_id=dataset_id1, tenant_id=tenant_id)
|
||||
mock_dataset2 = create_mock_dataset(dataset_id=dataset_id2, tenant_id=tenant_id)
|
||||
with patch.object(
|
||||
dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset1, mock_dataset2]
|
||||
):
|
||||
# Mock get_metadata_filter_condition
|
||||
with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
|
||||
# Mock multiple_retrieve to return documents
|
||||
doc1 = create_mock_document("Python is great", "doc1", score=0.9)
|
||||
doc2 = create_mock_document("Python is awesome", "doc2", score=0.8)
|
||||
with patch.object(
|
||||
dataset_retrieval, "multiple_retrieve", return_value=[doc1, doc2]
|
||||
) as mock_multiple_retrieve:
|
||||
# Mock format_retrieval_documents
|
||||
mock_record = Mock()
|
||||
mock_record.segment = Mock()
|
||||
mock_record.segment.dataset_id = dataset_id1
|
||||
mock_record.segment.document_id = str(uuid4())
|
||||
mock_record.segment.index_node_hash = "hash123"
|
||||
mock_record.segment.hit_count = 5
|
||||
mock_record.segment.word_count = 100
|
||||
mock_record.segment.position = 1
|
||||
mock_record.segment.get_sign_content.return_value = "Python is great"
|
||||
mock_record.segment.answer = None
|
||||
mock_record.score = 0.9
|
||||
mock_record.child_chunks = []
|
||||
mock_record.summary = None
|
||||
mock_record.files = None
|
||||
|
||||
mock_retrieval_service = Mock()
|
||||
mock_retrieval_service.format_retrieval_documents.return_value = [mock_record]
|
||||
|
||||
with patch(
|
||||
"core.rag.retrieval.dataset_retrieval.RetrievalService",
|
||||
return_value=mock_retrieval_service,
|
||||
):
|
||||
# Mock database queries
|
||||
mock_session = MagicMock()
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
mock_session_factory.create_session.return_value.__exit__.return_value = None
|
||||
|
||||
mock_dataset_from_db = Mock()
|
||||
mock_dataset_from_db.id = dataset_id1
|
||||
mock_dataset_from_db.name = "test_dataset"
|
||||
|
||||
mock_document = Mock()
|
||||
mock_document.id = str(uuid4())
|
||||
mock_document.name = "test_doc"
|
||||
mock_document.data_source_type = "upload_file"
|
||||
mock_document.doc_metadata = {}
|
||||
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [
|
||||
mock_dataset_from_db
|
||||
]
|
||||
mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter(
|
||||
[mock_dataset_from_db, mock_document]
|
||||
)
|
||||
|
||||
# Act
|
||||
result = dataset_retrieval.knowledge_retrieval(request)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, list)
|
||||
mock_multiple_retrieve.assert_called_once()
|
||||
|
||||
def test_knowledge_retrieval_metadata_filtering_disabled(self):
|
||||
"""
|
||||
Test knowledge_retrieval with metadata filtering disabled.
|
||||
|
||||
When metadata filtering is disabled, get_metadata_filter_condition is
|
||||
NOT called (the method checks metadata_filtering_mode != "disabled").
|
||||
|
||||
Verifies:
|
||||
- get_metadata_filter_condition is NOT called when mode is "disabled"
|
||||
- Retrieval proceeds without metadata filters
|
||||
"""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
app_id = str(uuid4())
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
request = KnowledgeRetrievalRequest(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
user_from="web",
|
||||
dataset_ids=[dataset_id],
|
||||
query="What is Python?",
|
||||
retrieval_mode="multiple",
|
||||
metadata_filtering_mode="disabled",
|
||||
top_k=5,
|
||||
)
|
||||
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Mock dependencies
|
||||
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
|
||||
mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id)
|
||||
with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]):
|
||||
# Mock get_metadata_filter_condition - should NOT be called when disabled
|
||||
with patch.object(
|
||||
dataset_retrieval,
|
||||
"get_metadata_filter_condition",
|
||||
return_value=(None, None),
|
||||
) as mock_get_metadata:
|
||||
with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]):
|
||||
# Act
|
||||
result = dataset_retrieval.knowledge_retrieval(request)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, list)
|
||||
# get_metadata_filter_condition should NOT be called when mode is "disabled"
|
||||
mock_get_metadata.assert_not_called()
|
||||
|
||||
def test_knowledge_retrieval_with_external_documents(self):
|
||||
"""
|
||||
Test knowledge_retrieval with external documents.
|
||||
|
||||
External documents come from external knowledge bases and should
|
||||
be formatted differently than Dify documents.
|
||||
|
||||
Verifies:
|
||||
- External documents are handled correctly
|
||||
- Provider is set to "external"
|
||||
- Metadata includes external-specific fields
|
||||
"""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
app_id = str(uuid4())
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
request = KnowledgeRetrievalRequest(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
user_from="web",
|
||||
dataset_ids=[dataset_id],
|
||||
query="What is Python?",
|
||||
retrieval_mode="multiple",
|
||||
top_k=5,
|
||||
)
|
||||
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Mock dependencies
|
||||
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
|
||||
mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id, provider="external")
|
||||
with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]):
|
||||
with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
|
||||
# Create external document
|
||||
external_doc = create_mock_document(
|
||||
"External knowledge",
|
||||
"doc1",
|
||||
score=0.9,
|
||||
provider="external",
|
||||
additional_metadata={
|
||||
"dataset_id": dataset_id,
|
||||
"dataset_name": "external_kb",
|
||||
"document_id": "ext_doc1",
|
||||
"title": "External Document",
|
||||
},
|
||||
)
|
||||
with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[external_doc]):
|
||||
# Act
|
||||
result = dataset_retrieval.knowledge_retrieval(request)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, list)
|
||||
if result:
|
||||
assert result[0].metadata.data_source_type == "external"
|
||||
|
||||
def test_knowledge_retrieval_empty_results(self):
|
||||
"""
|
||||
Test knowledge_retrieval when no documents are found.
|
||||
|
||||
Verifies:
|
||||
- Empty list is returned
|
||||
- No errors are raised
|
||||
- All dependencies are still called
|
||||
"""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
app_id = str(uuid4())
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
request = KnowledgeRetrievalRequest(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
user_from="web",
|
||||
dataset_ids=[dataset_id],
|
||||
query="What is Python?",
|
||||
retrieval_mode="multiple",
|
||||
top_k=5,
|
||||
)
|
||||
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Mock dependencies
|
||||
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
|
||||
mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id)
|
||||
with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]):
|
||||
with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
|
||||
# Mock multiple_retrieve to return empty list
|
||||
with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]):
|
||||
# Act
|
||||
result = dataset_retrieval.knowledge_retrieval(request)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_knowledge_retrieval_rate_limit_exceeded(self):
|
||||
"""
|
||||
Test knowledge_retrieval when rate limit is exceeded.
|
||||
|
||||
Verifies:
|
||||
- RateLimitExceededError is raised
|
||||
- No further processing occurs
|
||||
"""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
app_id = str(uuid4())
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
request = KnowledgeRetrievalRequest(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
user_from="web",
|
||||
dataset_ids=[dataset_id],
|
||||
query="What is Python?",
|
||||
retrieval_mode="multiple",
|
||||
top_k=5,
|
||||
)
|
||||
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Mock _check_knowledge_rate_limit to raise exception
|
||||
with patch.object(
|
||||
dataset_retrieval,
|
||||
"_check_knowledge_rate_limit",
|
||||
side_effect=exc.RateLimitExceededError("Rate limit exceeded"),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(exc.RateLimitExceededError):
|
||||
dataset_retrieval.knowledge_retrieval(request)
|
||||
|
||||
def test_knowledge_retrieval_no_available_datasets(self):
|
||||
"""
|
||||
Test knowledge_retrieval when no datasets are available.
|
||||
|
||||
Verifies:
|
||||
- Empty list is returned
|
||||
- No retrieval is attempted
|
||||
"""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
app_id = str(uuid4())
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
request = KnowledgeRetrievalRequest(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
user_from="web",
|
||||
dataset_ids=[dataset_id],
|
||||
query="What is Python?",
|
||||
retrieval_mode="multiple",
|
||||
top_k=5,
|
||||
)
|
||||
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
# Mock dependencies
|
||||
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
|
||||
# Mock _get_available_datasets to return empty list
|
||||
with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[]):
|
||||
# Act
|
||||
result = dataset_retrieval.knowledge_retrieval(request)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_knowledge_retrieval_handles_multiple_documents_with_different_scores(self):
|
||||
"""
|
||||
Test that knowledge_retrieval processes multiple documents with different scores.
|
||||
|
||||
Note: Full sorting and position testing requires complex SQLAlchemy mocking
|
||||
which is better suited for integration tests. This test verifies documents
|
||||
with different scores can be created and have their metadata.
|
||||
|
||||
Verifies:
|
||||
- Documents can be created with different scores
|
||||
- Score metadata is properly set
|
||||
"""
|
||||
# Create documents with different scores
|
||||
doc1 = create_mock_document("Low score", "doc1", score=0.6)
|
||||
doc2 = create_mock_document("High score", "doc2", score=0.95)
|
||||
doc3 = create_mock_document("Medium score", "doc3", score=0.8)
|
||||
|
||||
# Assert - each document has the correct score
|
||||
assert doc1.metadata["score"] == 0.6
|
||||
assert doc2.metadata["score"] == 0.95
|
||||
assert doc3.metadata["score"] == 0.8
|
||||
|
||||
# Assert - documents are correctly sorted (not the retrieval result, just the list)
|
||||
unsorted = [doc1, doc2, doc3]
|
||||
sorted_docs = sorted(unsorted, key=lambda d: d.metadata["score"], reverse=True)
|
||||
assert [d.metadata["score"] for d in sorted_docs] == [0.95, 0.8, 0.6]
|
||||
@@ -0,0 +1,197 @@
from unittest.mock import MagicMock, patch

import pytest

from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.agent.agent_node import AgentNode


class TestInferToolProviderType:
    """Test cases for AgentNode._infer_tool_provider_type method."""

    def test_infer_type_from_config_workflow(self):
        """Test inferring workflow provider type from config."""
        tool_config = {
            "type": "workflow",
            "provider_name": "workflow-provider-id",
        }
        tenant_id = "test-tenant"

        result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)

        assert result == ToolProviderType.WORKFLOW

    def test_infer_type_from_config_builtin(self):
        """Test inferring builtin provider type from config."""
        tool_config = {
            "type": "builtin",
            "provider_name": "builtin-provider-id",
        }
        tenant_id = "test-tenant"

        result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)

        assert result == ToolProviderType.BUILT_IN

    def test_infer_type_from_config_api(self):
        """Test inferring API provider type from config."""
        tool_config = {
            "type": "api",
            "provider_name": "api-provider-id",
        }
        tenant_id = "test-tenant"

        result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)

        assert result == ToolProviderType.API

    def test_infer_type_from_config_mcp(self):
        """Test inferring MCP provider type from config."""
        tool_config = {
            "type": "mcp",
            "provider_name": "mcp-provider-id",
        }
        tenant_id = "test-tenant"

        result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)

        assert result == ToolProviderType.MCP

    def test_infer_type_invalid_config_value_raises_error(self):
        """Test that invalid type value in config raises ValueError."""
        tool_config = {
            "type": "invalid-type",
            "provider_name": "workflow-provider-id",
        }
        tenant_id = "test-tenant"

        with pytest.raises(ValueError):
            AgentNode._infer_tool_provider_type(tool_config, tenant_id)

    def test_infer_workflow_type_from_database(self):
        """Test inferring workflow provider type from database."""
        tool_config = {
            "provider_name": "workflow-provider-id",
        }
        tenant_id = "test-tenant"

        with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
            mock_session = MagicMock()
            mock_create_session.return_value.__enter__.return_value = mock_session

            # First query (WorkflowToolProvider) returns a result
            mock_session.scalar.return_value = True

            result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)

            assert result == ToolProviderType.WORKFLOW
            # Should only query once (after finding WorkflowToolProvider)
            assert mock_session.scalar.call_count == 1

    def test_infer_mcp_type_from_database(self):
        """Test inferring MCP provider type from database."""
        tool_config = {
            "provider_name": "mcp-provider-id",
        }
        tenant_id = "test-tenant"

        with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
            mock_session = MagicMock()
            mock_create_session.return_value.__enter__.return_value = mock_session

            # First query (WorkflowToolProvider) returns None
            # Second query (MCPToolProvider) returns a result
            mock_session.scalar.side_effect = [None, True]

            result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)

            assert result == ToolProviderType.MCP
            assert mock_session.scalar.call_count == 2

    def test_infer_api_type_from_database(self):
        """Test inferring API provider type from database."""
        tool_config = {
            "provider_name": "api-provider-id",
        }
        tenant_id = "test-tenant"

        with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
            mock_session = MagicMock()
            mock_create_session.return_value.__enter__.return_value = mock_session

            # First query (WorkflowToolProvider) returns None
            # Second query (MCPToolProvider) returns None
            # Third query (ApiToolProvider) returns a result
            mock_session.scalar.side_effect = [None, None, True]

            result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)

            assert result == ToolProviderType.API
            assert mock_session.scalar.call_count == 3

    def test_infer_builtin_type_from_database(self):
        """Test inferring builtin provider type from database."""
        tool_config = {
            "provider_name": "builtin-provider-id",
        }
        tenant_id = "test-tenant"

        with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
            mock_session = MagicMock()
            mock_create_session.return_value.__enter__.return_value = mock_session

            # First three queries return None
            # Fourth query (BuiltinToolProvider) returns a result
            mock_session.scalar.side_effect = [None, None, None, True]

            result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)

            assert result == ToolProviderType.BUILT_IN
            assert mock_session.scalar.call_count == 4

    def test_infer_type_default_when_not_found(self):
        """Test raising AgentNodeError when provider is not found in database."""
        tool_config = {
            "provider_name": "unknown-provider-id",
        }
        tenant_id = "test-tenant"

        with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
            mock_session = MagicMock()
            mock_create_session.return_value.__enter__.return_value = mock_session

            # All queries return None
            mock_session.scalar.return_value = None

            # Current implementation raises AgentNodeError when provider not found
            from core.workflow.nodes.agent.exc import AgentNodeError

            with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"):
                AgentNode._infer_tool_provider_type(tool_config, tenant_id)

    def test_infer_type_default_when_no_provider_name(self):
        """Test defaulting to BUILT_IN when provider_name is missing."""
        tool_config = {}
        tenant_id = "test-tenant"

        result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)

        assert result == ToolProviderType.BUILT_IN

    def test_infer_type_database_exception_propagates(self):
        """Test that database exception propagates (current implementation doesn't catch it)."""
        tool_config = {
            "provider_name": "provider-id",
        }
        tenant_id = "test-tenant"

        with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
            mock_session = MagicMock()
            mock_create_session.return_value.__enter__.return_value = mock_session

            # Database query raises exception
            mock_session.scalar.side_effect = Exception("Database error")

            # Current implementation doesn't catch exceptions, so it propagates
            with pytest.raises(Exception, match="Database error"):
                AgentNode._infer_tool_provider_type(tool_config, tenant_id)
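# Note (context, not part of the diff): the lookup order these database tests encode,
# as a sketch - an explicit "type" in the config wins, otherwise the provider id is
# probed against each provider table in turn. Model names follow the comments above;
# column names and the exact query shape are assumptions.
from sqlalchemy import select

def infer_tool_provider_type(tool_config: dict, tenant_id: str) -> ToolProviderType:
    if "type" in tool_config:
        # unknown values raise ValueError, as test_infer_type_invalid_config_value expects
        return ToolProviderType(tool_config["type"])
    provider_id = tool_config.get("provider_name")
    if not provider_id:
        return ToolProviderType.BUILT_IN
    probes = [
        (WorkflowToolProvider, ToolProviderType.WORKFLOW),
        (MCPToolProvider, ToolProviderType.MCP),
        (ApiToolProvider, ToolProviderType.API),
        (BuiltinToolProvider, ToolProviderType.BUILT_IN),
    ]
    with session_factory.create_session() as session:
        for model, provider_type in probes:
            # one session.scalar() per probe, stopping at the first hit
            if session.scalar(select(model.id).where(model.id == provider_id, model.tenant_id == tenant_id)):
                return provider_type
    raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found")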
@@ -1,595 +0,0 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables import StringSegment
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import (
|
||||
KnowledgeRetrievalNodeData,
|
||||
MultipleRetrievalConfig,
|
||||
RerankingModelConfig,
|
||||
SingleRetrievalConfig,
|
||||
)
|
||||
from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.repositories.rag_retrieval_protocol import RAGRetrievalProtocol, Source
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_init_params():
|
||||
"""Create mock GraphInitParams."""
|
||||
return GraphInitParams(
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
app_id=str(uuid.uuid4()),
|
||||
workflow_id=str(uuid.uuid4()),
|
||||
graph_config={},
|
||||
user_id=str(uuid.uuid4()),
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_runtime_state():
|
||||
"""Create mock GraphRuntimeState."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_retrieval():
|
||||
"""Create mock RAGRetrievalProtocol."""
|
||||
mock_retrieval = Mock(spec=RAGRetrievalProtocol)
|
||||
mock_retrieval.knowledge_retrieval.return_value = []
|
||||
mock_retrieval.llm_usage = LLMUsage.empty_usage()
|
||||
return mock_retrieval
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_node_data():
|
||||
"""Create sample KnowledgeRetrievalNodeData."""
|
||||
return KnowledgeRetrievalNodeData(
|
||||
title="Knowledge Retrieval",
|
||||
type="knowledge-retrieval",
|
||||
dataset_ids=[str(uuid.uuid4())],
|
||||
retrieval_mode="multiple",
|
||||
multiple_retrieval_config=MultipleRetrievalConfig(
|
||||
top_k=5,
|
||||
score_threshold=0.7,
|
||||
reranking_mode="reranking_model",
|
||||
reranking_enable=True,
|
||||
reranking_model=RerankingModelConfig(
|
||||
provider="cohere",
|
||||
model="rerank-v2",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestKnowledgeRetrievalNode:
|
||||
"""
|
||||
Test suite for KnowledgeRetrievalNode.
|
||||
"""
|
||||
|
||||
def test_node_initialization(self, mock_graph_init_params, mock_graph_runtime_state, mock_rag_retrieval):
|
||||
"""Test KnowledgeRetrievalNode initialization."""
|
||||
# Arrange
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": {
|
||||
"title": "Knowledge Retrieval",
|
||||
"type": "knowledge-retrieval",
|
||||
"dataset_ids": [str(uuid.uuid4())],
|
||||
"retrieval_mode": "multiple",
|
||||
},
|
||||
}
|
||||
|
||||
# Act
|
||||
node = KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
rag_retrieval=mock_rag_retrieval,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert node.id == node_id
|
||||
assert node._rag_retrieval == mock_rag_retrieval
|
||||
assert node._llm_file_saver is not None
|
||||
|
||||
def test_run_with_no_query_or_attachment(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_rag_retrieval,
|
||||
sample_node_data,
|
||||
):
|
||||
"""Test _run returns success when no query or attachment is provided."""
|
||||
# Arrange
|
||||
sample_node_data.query_variable_selector = None
|
||||
sample_node_data.query_attachment_selector = None
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
rag_retrieval=mock_rag_retrieval,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._run()
|
||||
|
||||
# Assert
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs == {}
|
||||
assert mock_rag_retrieval.knowledge_retrieval.call_count == 0
|
||||
|
||||
def test_run_with_query_variable_single_mode(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_rag_retrieval,
|
||||
):
|
||||
"""Test _run with query variable in single mode."""
|
||||
# Arrange
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
|
||||
query = "What is Python?"
|
||||
query_selector = ["start", "query"]
|
||||
|
||||
# Add query to variable pool
|
||||
mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
|
||||
|
||||
node_data = KnowledgeRetrievalNodeData(
|
||||
title="Knowledge Retrieval",
|
||||
type="knowledge-retrieval",
|
||||
dataset_ids=[str(uuid.uuid4())],
|
||||
retrieval_mode="single",
|
||||
query_variable_selector=query_selector,
|
||||
single_retrieval_config=SingleRetrievalConfig(
|
||||
model=ModelConfig(
|
||||
provider="openai",
|
||||
name="gpt-4",
|
||||
mode="chat",
|
||||
completion_params={"temperature": 0.7},
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": node_data.model_dump(),
|
||||
}
|
||||
|
||||
# Mock retrieval response
|
||||
mock_source = Mock(spec=Source)
|
||||
mock_source.model_dump.return_value = {"content": "Python is a programming language"}
|
||||
mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source]
|
||||
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
rag_retrieval=mock_rag_retrieval,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._run()
|
||||
|
||||
# Assert
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert "result" in result.outputs
|
||||
assert mock_rag_retrieval.knowledge_retrieval.called
|
||||
|
||||
def test_run_with_query_variable_multiple_mode(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_rag_retrieval,
|
||||
sample_node_data,
|
||||
):
|
||||
"""Test _run with query variable in multiple mode."""
|
||||
# Arrange
|
||||
query = "What is Python?"
|
||||
query_selector = ["start", "query"]
|
||||
|
||||
# Add query to variable pool
|
||||
mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
|
||||
sample_node_data.query_variable_selector = query_selector
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
# Mock retrieval response
|
||||
mock_source = Mock(spec=Source)
|
||||
mock_source.model_dump.return_value = {"content": "Python is a programming language"}
|
||||
mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source]
|
||||
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
rag_retrieval=mock_rag_retrieval,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._run()
|
||||
|
||||
# Assert
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert "result" in result.outputs
|
||||
assert mock_rag_retrieval.knowledge_retrieval.called
|
||||
|
||||
def test_run_with_invalid_query_variable_type(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_rag_retrieval,
|
||||
sample_node_data,
|
||||
):
|
||||
"""Test _run fails when query variable is not StringSegment."""
|
||||
# Arrange
|
||||
query_selector = ["start", "query"]
|
||||
|
||||
# Add non-string variable to variable pool
|
||||
mock_graph_runtime_state.variable_pool.add(query_selector, [1, 2, 3])
|
||||
sample_node_data.query_variable_selector = query_selector
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
rag_retrieval=mock_rag_retrieval,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._run()
|
||||
|
||||
# Assert
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "Query variable is not string type" in result.error
|
||||
|
||||
def test_run_with_invalid_attachment_variable_type(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_rag_retrieval,
|
||||
sample_node_data,
|
||||
):
|
||||
"""Test _run fails when attachment variable is not FileSegment or ArrayFileSegment."""
|
||||
# Arrange
|
||||
attachment_selector = ["start", "attachments"]
|
||||
|
||||
# Add non-file variable to variable pool
|
||||
mock_graph_runtime_state.variable_pool.add(attachment_selector, "not a file")
|
||||
sample_node_data.query_attachment_selector = attachment_selector
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
rag_retrieval=mock_rag_retrieval,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._run()
|
||||
|
||||
# Assert
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "Attachments variable is not array file or file type" in result.error
|
||||
|
||||
def test_run_with_rate_limit_exceeded(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_rag_retrieval,
|
||||
sample_node_data,
|
||||
):
|
||||
"""Test _run handles RateLimitExceededError properly."""
|
||||
# Arrange
|
||||
query = "What is Python?"
|
||||
query_selector = ["start", "query"]
|
||||
|
||||
mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
|
||||
sample_node_data.query_variable_selector = query_selector
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
# Mock retrieval to raise RateLimitExceededError
|
||||
mock_rag_retrieval.knowledge_retrieval.side_effect = RateLimitExceededError(
|
||||
"knowledge base request rate limit exceeded"
|
||||
)
|
||||
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
rag_retrieval=mock_rag_retrieval,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._run()
|
||||
|
||||
# Assert
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "rate limit" in result.error.lower()
|
||||
|
||||
def test_run_with_generic_exception(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_rag_retrieval,
|
||||
sample_node_data,
|
||||
):
|
||||
"""Test _run handles generic exceptions properly."""
|
||||
# Arrange
|
||||
query = "What is Python?"
|
||||
query_selector = ["start", "query"]
|
||||
|
||||
mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
|
||||
sample_node_data.query_variable_selector = query_selector
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
# Mock retrieval to raise generic exception
|
||||
mock_rag_retrieval.knowledge_retrieval.side_effect = Exception("Unexpected error")
|
||||
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
rag_retrieval=mock_rag_retrieval,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._run()
|
||||
|
||||
# Assert
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "Unexpected error" in result.error
|
||||
|
||||
def test_extract_variable_selector_to_variable_mapping(self):
|
||||
"""Test _extract_variable_selector_to_variable_mapping class method."""
|
||||
# Arrange
|
||||
node_id = "knowledge_node_1"
|
||||
node_data = {
|
||||
"type": "knowledge-retrieval",
|
||||
"title": "Knowledge Retrieval",
|
||||
"dataset_ids": [str(uuid.uuid4())],
|
||||
"retrieval_mode": "multiple",
|
||||
"query_variable_selector": ["start", "query"],
|
||||
"query_attachment_selector": ["start", "attachments"],
|
||||
}
|
||||
graph_config = {}
|
||||
|
||||
# Act
|
||||
mapping = KnowledgeRetrievalNode._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config,
|
||||
node_id=node_id,
|
||||
node_data=node_data,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert mapping[f"{node_id}.query"] == ["start", "query"]
|
||||
assert mapping[f"{node_id}.queryAttachment"] == ["start", "attachments"]
|
||||
|
||||
|
||||
class TestFetchDatasetRetriever:
|
||||
"""
|
||||
Test suite for _fetch_dataset_retriever method.
|
||||
"""
|
||||
|
||||
def test_fetch_dataset_retriever_single_mode(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_rag_retrieval,
|
||||
):
|
||||
"""Test _fetch_dataset_retriever in single mode."""
|
||||
# Arrange
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
|
||||
query = "What is Python?"
|
||||
variables = {"query": query}
|
||||
|
||||
node_data = KnowledgeRetrievalNodeData(
|
||||
title="Knowledge Retrieval",
|
||||
type="knowledge-retrieval",
|
||||
dataset_ids=[str(uuid.uuid4())],
|
||||
retrieval_mode="single",
|
||||
single_retrieval_config=SingleRetrievalConfig(
|
||||
model=ModelConfig(
|
||||
provider="openai",
|
||||
name="gpt-4",
|
||||
mode="chat",
|
||||
completion_params={"temperature": 0.7},
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# Mock retrieval response
|
||||
mock_source = Mock(spec=Source)
|
||||
mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source]
|
||||
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {"id": node_id, "data": node_data.model_dump()}
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
rag_retrieval=mock_rag_retrieval,
|
||||
)
|
||||
|
||||
# Act
|
||||
results, usage = node._fetch_dataset_retriever(node_data=node_data, variables=variables)
|
||||
|
||||
# Assert
|
||||
assert len(results) == 1
|
||||
assert isinstance(usage, LLMUsage)
|
||||
assert mock_rag_retrieval.knowledge_retrieval.called
|
||||
|
||||
def test_fetch_dataset_retriever_multiple_mode_with_reranking(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_rag_retrieval,
|
||||
sample_node_data,
|
||||
):
|
||||
"""Test _fetch_dataset_retriever in multiple mode with reranking."""
|
||||
# Arrange
|
||||
query = "What is Python?"
|
||||
variables = {"query": query}
|
||||
|
||||
# Mock retrieval response
|
||||
mock_rag_retrieval.knowledge_retrieval.return_value = []
|
||||
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
rag_retrieval=mock_rag_retrieval,
|
||||
)
|
||||
|
||||
# Act
|
||||
results, usage = node._fetch_dataset_retriever(node_data=sample_node_data, variables=variables)
|
||||
|
||||
# Assert
|
||||
assert isinstance(results, list)
|
||||
assert isinstance(usage, LLMUsage)
|
||||
assert mock_rag_retrieval.knowledge_retrieval.called
|
||||
|
||||
# Verify reranking parameters via request object
|
||||
call_args = mock_rag_retrieval.knowledge_retrieval.call_args
|
||||
request = call_args[1]["request"]
|
||||
assert request.reranking_enable is True
|
||||
assert request.reranking_mode == "reranking_model"
|
||||
|
||||
def test_fetch_dataset_retriever_multiple_mode_without_reranking(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_rag_retrieval,
|
||||
):
|
||||
"""Test _fetch_dataset_retriever in multiple mode without reranking."""
|
||||
# Arrange
|
||||
query = "What is Python?"
|
||||
variables = {"query": query}
|
||||
|
||||
node_data = KnowledgeRetrievalNodeData(
|
||||
title="Knowledge Retrieval",
|
||||
type="knowledge-retrieval",
|
||||
dataset_ids=[str(uuid.uuid4())],
|
||||
retrieval_mode="multiple",
|
||||
multiple_retrieval_config=MultipleRetrievalConfig(
|
||||
top_k=5,
|
||||
score_threshold=0.7,
|
||||
reranking_enable=False,
|
||||
reranking_mode="reranking_model",
|
||||
),
|
||||
)
|
||||
|
||||
# Mock retrieval response
|
||||
mock_rag_retrieval.knowledge_retrieval.return_value = []
|
||||
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
rag_retrieval=mock_rag_retrieval,
|
||||
)
|
||||
|
||||
# Act
|
||||
results, usage = node._fetch_dataset_retriever(node_data=node_data, variables=variables)
|
||||
|
||||
# Assert
|
||||
assert isinstance(results, list)
|
||||
assert mock_rag_retrieval.knowledge_retrieval.called
|
||||
|
||||
# Verify reranking is disabled
|
||||
call_args = mock_rag_retrieval.knowledge_retrieval.call_args
|
||||
request = call_args[1]["request"]
|
||||
assert request.reranking_enable is False
|
||||
|
||||
def test_version_method(self):
|
||||
"""Test version class method."""
|
||||
# Act
|
||||
version = KnowledgeRetrievalNode.version()
|
||||
|
||||
# Assert
|
||||
assert version == "1"
|
||||
@@ -133,8 +133,6 @@ class TestCelerySSLConfiguration:
        mock_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK = 0
        mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False
        mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15
        mock_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK = False
        mock_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL = 30

        with patch("extensions.ext_celery.dify_config", mock_config):
            from dify_app import DifyApp
@@ -4,7 +4,7 @@ from typing import Any
from uuid import uuid4

import pytest
from hypothesis import HealthCheck, given, settings
from hypothesis import given, settings
from hypothesis import strategies as st

from core.file import File, FileTransferMethod, FileType
@@ -493,7 +493,7 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]:
)


@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@settings(max_examples=50)
@given(_scalar_value())
def test_build_segment_and_extract_values_for_scalar_types(value):
seg = variable_factory.build_segment(value)
@@ -504,7 +504,7 @@ def test_build_segment_and_extract_values_for_scalar_types(value):
assert seg.value == value


@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@settings(max_examples=50)
@given(values=st.lists(_scalar_value(), max_size=20))
def test_build_segment_and_extract_values_for_array_types(values):
seg = variable_factory.build_segment(values)

@@ -859,7 +859,7 @@ class TestRedisShardedSubscription:
client.get_node_from_key.assert_called_once_with("test-sharded-topic")
mock_pubsub.get_sharded_message.assert_called_once_with(
ignore_subscribe_messages=False,
timeout=1,
timeout=0.1,
target_node="node-1",
)
assert result == mock_pubsub.get_sharded_message.return_value

@@ -1,250 +0,0 @@
"""
Unit tests for API Token Cache module.
"""

import json
from datetime import datetime
from unittest.mock import MagicMock, patch

from services.api_token_service import (
CACHE_KEY_PREFIX,
CACHE_NULL_TTL_SECONDS,
CACHE_TTL_SECONDS,
ApiTokenCache,
CachedApiToken,
)


class TestApiTokenCache:
"""Test cases for ApiTokenCache class."""

def setup_method(self):
"""Setup test fixtures."""
self.mock_token = MagicMock()
self.mock_token.id = "test-token-id-123"
self.mock_token.app_id = "test-app-id-456"
self.mock_token.tenant_id = "test-tenant-id-789"
self.mock_token.type = "app"
self.mock_token.token = "test-token-value-abc"
self.mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0)
self.mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)

def test_make_cache_key(self):
"""Test cache key generation."""
# Test with scope
key = ApiTokenCache._make_cache_key("my-token", "app")
assert key == f"{CACHE_KEY_PREFIX}:app:my-token"

# Test without scope
key = ApiTokenCache._make_cache_key("my-token", None)
assert key == f"{CACHE_KEY_PREFIX}:any:my-token"

def test_serialize_token(self):
"""Test token serialization."""
serialized = ApiTokenCache._serialize_token(self.mock_token)
data = json.loads(serialized)

assert data["id"] == "test-token-id-123"
assert data["app_id"] == "test-app-id-456"
assert data["tenant_id"] == "test-tenant-id-789"
assert data["type"] == "app"
assert data["token"] == "test-token-value-abc"
assert data["last_used_at"] == "2026-02-03T10:00:00"
assert data["created_at"] == "2026-01-01T00:00:00"

def test_serialize_token_with_nulls(self):
"""Test token serialization with None values."""
mock_token = MagicMock()
mock_token.id = "test-id"
mock_token.app_id = None
mock_token.tenant_id = None
mock_token.type = "dataset"
mock_token.token = "test-token"
mock_token.last_used_at = None
mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)

serialized = ApiTokenCache._serialize_token(mock_token)
data = json.loads(serialized)

assert data["app_id"] is None
assert data["tenant_id"] is None
assert data["last_used_at"] is None

def test_deserialize_token(self):
"""Test token deserialization."""
cached_data = json.dumps(
{
"id": "test-id",
"app_id": "test-app",
"tenant_id": "test-tenant",
"type": "app",
"token": "test-token",
"last_used_at": "2026-02-03T10:00:00",
"created_at": "2026-01-01T00:00:00",
}
)

result = ApiTokenCache._deserialize_token(cached_data)

assert isinstance(result, CachedApiToken)
assert result.id == "test-id"
assert result.app_id == "test-app"
assert result.tenant_id == "test-tenant"
assert result.type == "app"
assert result.token == "test-token"
assert result.last_used_at == datetime(2026, 2, 3, 10, 0, 0)
assert result.created_at == datetime(2026, 1, 1, 0, 0, 0)

def test_deserialize_null_token(self):
"""Test deserialization of null token (cached miss)."""
result = ApiTokenCache._deserialize_token("null")
assert result is None

def test_deserialize_invalid_json(self):
"""Test deserialization with invalid JSON."""
result = ApiTokenCache._deserialize_token("invalid-json{")
assert result is None

@patch("services.api_token_service.redis_client")
def test_get_cache_hit(self, mock_redis):
"""Test cache hit scenario."""
cached_data = json.dumps(
{
"id": "test-id",
"app_id": "test-app",
"tenant_id": "test-tenant",
"type": "app",
"token": "test-token",
"last_used_at": "2026-02-03T10:00:00",
"created_at": "2026-01-01T00:00:00",
}
).encode("utf-8")
mock_redis.get.return_value = cached_data

result = ApiTokenCache.get("test-token", "app")

assert result is not None
assert isinstance(result, CachedApiToken)
assert result.app_id == "test-app"
mock_redis.get.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token")

@patch("services.api_token_service.redis_client")
def test_get_cache_miss(self, mock_redis):
"""Test cache miss scenario."""
mock_redis.get.return_value = None

result = ApiTokenCache.get("test-token", "app")

assert result is None
mock_redis.get.assert_called_once()

@patch("services.api_token_service.redis_client")
def test_set_valid_token(self, mock_redis):
"""Test setting a valid token in cache."""
result = ApiTokenCache.set("test-token", "app", self.mock_token)

assert result is True
mock_redis.setex.assert_called_once()
args = mock_redis.setex.call_args[0]
assert args[0] == f"{CACHE_KEY_PREFIX}:app:test-token"
assert args[1] == CACHE_TTL_SECONDS

@patch("services.api_token_service.redis_client")
def test_set_null_token(self, mock_redis):
"""Test setting a null token (cache penetration prevention)."""
result = ApiTokenCache.set("invalid-token", "app", None)

assert result is True
mock_redis.setex.assert_called_once()
args = mock_redis.setex.call_args[0]
assert args[0] == f"{CACHE_KEY_PREFIX}:app:invalid-token"
assert args[1] == CACHE_NULL_TTL_SECONDS
assert args[2] == b"null"

@patch("services.api_token_service.redis_client")
def test_delete_with_scope(self, mock_redis):
"""Test deleting token cache with specific scope."""
result = ApiTokenCache.delete("test-token", "app")

assert result is True
mock_redis.delete.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token")

@patch("services.api_token_service.redis_client")
def test_delete_without_scope(self, mock_redis):
"""Test deleting token cache without scope (delete all)."""
# Mock scan_iter to return an iterator of keys
mock_redis.scan_iter.return_value = iter(
[
b"api_token:app:test-token",
b"api_token:dataset:test-token",
]
)

result = ApiTokenCache.delete("test-token", None)

assert result is True
# Verify scan_iter was called with the correct pattern
mock_redis.scan_iter.assert_called_once()
call_args = mock_redis.scan_iter.call_args
assert call_args[1]["match"] == f"{CACHE_KEY_PREFIX}:*:test-token"

# Verify delete was called with all matched keys
mock_redis.delete.assert_called_once_with(
b"api_token:app:test-token",
b"api_token:dataset:test-token",
)

@patch("services.api_token_service.redis_client")
def test_redis_fallback_on_exception(self, mock_redis):
"""Test Redis fallback when Redis is unavailable."""
from redis import RedisError

mock_redis.get.side_effect = RedisError("Connection failed")

result = ApiTokenCache.get("test-token", "app")

# Should return None (fallback) instead of raising exception
assert result is None


class TestApiTokenCacheIntegration:
"""Integration test scenarios."""

@patch("services.api_token_service.redis_client")
def test_full_cache_lifecycle(self, mock_redis):
"""Test complete cache lifecycle: set -> get -> delete."""
# Setup mock token
mock_token = MagicMock()
mock_token.id = "id-123"
mock_token.app_id = "app-456"
mock_token.tenant_id = "tenant-789"
mock_token.type = "app"
mock_token.token = "token-abc"
mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0)
mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)

# 1. Set token in cache
ApiTokenCache.set("token-abc", "app", mock_token)
assert mock_redis.setex.called

# 2. Simulate cache hit
cached_data = ApiTokenCache._serialize_token(mock_token)
mock_redis.get.return_value = cached_data  # bytes from model_dump_json().encode()

retrieved = ApiTokenCache.get("token-abc", "app")
assert retrieved is not None
assert isinstance(retrieved, CachedApiToken)

# 3. Delete from cache
ApiTokenCache.delete("token-abc", "app")
assert mock_redis.delete.called

@patch("services.api_token_service.redis_client")
def test_cache_penetration_prevention(self, mock_redis):
"""Test that non-existent tokens are cached as null."""
# Set null token (cache miss)
ApiTokenCache.set("non-existent-token", "app", None)

args = mock_redis.setex.call_args[0]
assert args[2] == b"null"
assert args[1] == CACHE_NULL_TTL_SECONDS  # Shorter TTL for null values
@@ -492,45 +492,3 @@ class TestEndUserServiceGetOrCreateEndUserByType:
# Assert
added_user = mock_session.add.call_args[0][0]
assert added_user.type == invoke_type


class TestEndUserServiceGetEndUserById:
"""Unit tests for EndUserService.get_end_user_by_id."""

@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_end_user_by_id_returns_end_user(self, mock_db, mock_session_class):
tenant_id = "tenant-123"
app_id = "app-456"
end_user_id = "end-user-789"
existing_user = MagicMock(spec=EndUser)

mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session

mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = existing_user

result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id)

assert result == existing_user
mock_session.query.assert_called_once_with(EndUser)
mock_query.where.assert_called_once()
assert len(mock_query.where.call_args[0]) == 3

@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_end_user_by_id_returns_none(self, mock_db, mock_session_class):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session

mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None

result = EndUserService.get_end_user_by_id(tenant_id="tenant", app_id="app", end_user_id="end-user")

assert result is None

@@ -4,7 +4,6 @@ from unittest.mock import MagicMock

import pytest

from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from models.model import App
from models.tools import WorkflowToolProvider
@@ -90,12 +89,6 @@ def _build_fake_session(app) -> SimpleNamespace:
return SimpleNamespace(query=query)


def _build_parameters() -> list[WorkflowToolParameterConfiguration]:
return [
WorkflowToolParameterConfiguration(name="input", description="input", form=ToolParameter.ToolParameterForm.LLM),
]


def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch):
workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]})
app = SimpleNamespace(workflow=workflow)
@@ -107,6 +100,8 @@ def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch):
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)
mock_invalidate = MagicMock()

parameters = [{"name": "input", "description": "input", "form": "form"}]

with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool(
user_id="user-id",
@@ -116,7 +111,7 @@ def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch):
label="Tool",
icon={"type": "emoji", "emoji": "tool"},
description="desc",
parameters=_build_parameters(),
parameters=parameters,
)

assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
@@ -139,6 +134,7 @@ def test_create_workflow_tool_success(monkeypatch):
mock_from_db = MagicMock()
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)

parameters = [{"name": "input", "description": "input", "form": "form"}]
icon = {"type": "emoji", "emoji": "tool"}

result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool(
@@ -149,7 +145,7 @@ def test_create_workflow_tool_success(monkeypatch):
label="Tool",
icon=icon,
description="desc",
parameters=_build_parameters(),
parameters=parameters,
)

assert result == {"result": "success"}

@@ -83,127 +83,23 @@ def mock_documents(document_ids, dataset_id):
def mock_db_session():
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
sessions = []  # Track all created sessions
# Shared mock data that all sessions will access
shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None}
session = MagicMock()
# Ensure tests that expect session.close() to be called can observe it via the context manager
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
# Link __exit__ to session.close so "close" expectations reflect context manager teardown

def create_session_side_effect():
session = MagicMock()
session.close = MagicMock()
def _exit_side_effect(*args, **kwargs):
session.close()

# Track commit calls
commit_mock = MagicMock()
session.commit = commit_mock
cm = MagicMock()
cm.__enter__.return_value = session
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm

def _exit_side_effect(*args, **kwargs):
session.close()

cm.__exit__.side_effect = _exit_side_effect

# Support session.begin() for transactions
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session

def begin_exit_side_effect(*args, **kwargs):
# Auto-commit on transaction exit (like SQLAlchemy)
session.commit()
# Also mark wrapper's commit as called
if sessions:
sessions[0].commit()

begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect)
session.begin = MagicMock(return_value=begin_cm)

sessions.append(session)

# Setup query with side_effect to handle both Dataset and Document queries
def query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
# Support both .first() and .all() calls with chaining
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)

# Create an iterator for .first() calls if not exists
if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)

where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query

session.query = MagicMock(side_effect=query_side_effect)
return cm

mock_sf.create_session.side_effect = create_session_side_effect

# Create a wrapper that behaves like the first session but has access to all sessions
class SessionWrapper:
def __init__(self):
self._sessions = sessions
self._shared_data = shared_mock_data
# Create a default session for setup phase
self._default_session = MagicMock()
self._default_session.close = MagicMock()
self._default_session.commit = MagicMock()

# Support session.begin() for default session too
begin_cm = MagicMock()
begin_cm.__enter__.return_value = self._default_session

def default_begin_exit_side_effect(*args, **kwargs):
self._default_session.commit()

begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect)
self._default_session.begin = MagicMock(return_value=begin_cm)

def default_query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)

if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)

where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query

self._default_session.query = MagicMock(side_effect=default_query_side_effect)

def __getattr__(self, name):
# Forward all attribute access to the first session, or default if none created yet
target_session = self._sessions[0] if self._sessions else self._default_session
return getattr(target_session, name)

@property
def all_sessions(self):
"""Access all created sessions for testing."""
return self._sessions

wrapper = SessionWrapper()
yield wrapper
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session


@pytest.fixture

@@ -356,9 +252,18 @@ class TestTaskEnqueuing:
|
||||
use the deprecated function.
|
||||
"""
|
||||
# Arrange
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
# Return documents one by one for each call
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -399,9 +304,21 @@ class TestBatchProcessing:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
# Create an iterator for documents
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
# Return documents one by one for each call
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -440,9 +357,19 @@ class TestBatchProcessing:
|
||||
doc.stopped_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||
@@ -480,9 +407,19 @@ class TestBatchProcessing:
|
||||
doc.stopped_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
@@ -507,10 +444,7 @@ class TestBatchProcessing:
|
||||
"""
|
||||
# Arrange
|
||||
document_ids = []
|
||||
|
||||
# Set shared mock data with empty documents list
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = []
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -548,9 +482,19 @@ class TestProgressTracking:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -584,9 +528,19 @@ class TestProgressTracking:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -681,9 +635,19 @@ class TestErrorHandling:
|
||||
doc.stopped_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
# Set up to trigger vector space limit error
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -710,9 +674,17 @@ class TestErrorHandling:
|
||||
Errors during indexing should be caught and logged, but not crash the task.
|
||||
"""
|
||||
# Arrange
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
# Make IndexingRunner raise an exception
|
||||
mock_indexing_runner.run.side_effect = Exception("Indexing failed")
|
||||
@@ -736,9 +708,17 @@ class TestErrorHandling:
|
||||
but not treated as a failure.
|
||||
"""
|
||||
# Arrange
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
# Make IndexingRunner raise DocumentIsPausedError
|
||||
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused")
|
||||
@@ -873,9 +853,17 @@ class TestTaskCancellation:
|
||||
Session cleanup should happen in finally block.
|
||||
"""
|
||||
# Arrange
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -895,9 +883,17 @@ class TestTaskCancellation:
|
||||
Session cleanup should happen even when errors occur.
|
||||
"""
|
||||
# Arrange
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
# Make IndexingRunner raise an exception
|
||||
mock_indexing_runner.run.side_effect = Exception("Test error")
|
||||
@@ -966,7 +962,6 @@ class TestAdvancedScenarios:
|
||||
document_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
|
||||
# Create only 2 documents (simulate one missing)
|
||||
# The new code uses .all() which will only return existing documents
|
||||
mock_documents = []
|
||||
for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one
|
||||
doc = MagicMock(spec=Document)
|
||||
@@ -976,9 +971,21 @@ class TestAdvancedScenarios:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data - .all() will only return existing documents
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
# Create iterator that returns None for missing document
|
||||
doc_responses = [mock_documents[0], None, mock_documents[1]]
|
||||
doc_iter = iter(doc_responses)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1068,9 +1075,19 @@ class TestAdvancedScenarios:
|
||||
doc.stopped_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
# Set vector space exactly at limit
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -1202,9 +1219,19 @@ class TestAdvancedScenarios:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
# Billing disabled - limits should not be checked
|
||||
mock_feature_service.get_features.return_value.billing.enabled = False
|
||||
@@ -1246,9 +1273,19 @@ class TestIntegration:
|
||||
|
||||
# Set up rpop to return None for concurrency check (no more tasks)
|
||||
mock_redis.rpop.side_effect = [None]
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1284,9 +1321,19 @@ class TestIntegration:
|
||||
|
||||
# Set up rpop to return None for concurrency check (no more tasks)
|
||||
mock_redis.rpop.side_effect = [None]
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1368,9 +1415,17 @@ class TestEdgeCases:
|
||||
mock_document.indexing_status = "waiting"
|
||||
mock_document.processing_started_at = None
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = [mock_document]
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: mock_document
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1410,9 +1465,17 @@ class TestEdgeCases:
|
||||
mock_document.indexing_status = "waiting"
|
||||
mock_document.processing_started_at = None
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = [mock_document]
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: mock_document
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1492,9 +1555,19 @@ class TestEdgeCases:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
# Set vector space limit to 0 (unlimited)
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -1539,9 +1612,19 @@ class TestEdgeCases:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
# Set negative vector space limit
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -1592,9 +1675,19 @@ class TestPerformanceScenarios:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
# Configure billing with sufficient limits
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -1733,9 +1826,19 @@ class TestRobustness:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
# Make IndexingRunner raise an exception
|
||||
mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")
|
||||
@@ -1763,7 +1866,7 @@ class TestRobustness:
|
||||
- No exceptions occur
|
||||
|
||||
Expected behavior:
|
||||
- All database sessions are closed
|
||||
- Database session is closed
|
||||
- No connection leaks
|
||||
"""
|
||||
# Arrange
|
||||
@@ -1776,9 +1879,19 @@ class TestRobustness:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1786,11 +1899,10 @@ class TestRobustness:
|
||||
# Act
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
|
||||
# Assert - All created sessions should be closed
|
||||
# The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary)
|
||||
assert len(mock_db_session.all_sessions) >= 1
|
||||
for session in mock_db_session.all_sessions:
|
||||
assert session.close.called, "All sessions should be closed"
|
||||
# Assert
|
||||
assert mock_db_session.close.called
|
||||
# Verify close is called exactly once
|
||||
assert mock_db_session.close.call_count == 1
|
||||
|
||||
def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
|
||||
"""
|
||||
|
||||
@@ -109,87 +109,25 @@ def mock_document_segments(document_id):
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session via session_factory.create_session().
|
||||
|
||||
After session split refactor, the code calls create_session() multiple times.
|
||||
This fixture creates shared query mocks so all sessions use the same
|
||||
query configuration, simulating database persistence across sessions.
|
||||
|
||||
The fixture automatically converts side_effect to cycle to prevent StopIteration.
|
||||
Tests configure mocks the same way as before, but behind the scenes the values
|
||||
are cycled infinitely for all sessions.
|
||||
"""
|
||||
from itertools import cycle
|
||||
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
|
||||
sessions = []
|
||||
session = MagicMock()
|
||||
# Ensure tests can observe session.close() via context manager teardown
|
||||
session.close = MagicMock()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
|
||||
# Shared query mocks - all sessions use these
|
||||
shared_query = MagicMock()
|
||||
shared_filter_by = MagicMock()
|
||||
shared_scalars_result = MagicMock()
|
||||
def _exit_side_effect(*args, **kwargs):
|
||||
session.close()
|
||||
|
||||
# Create custom first mock that auto-cycles side_effect
|
||||
class CyclicMock(MagicMock):
|
||||
def __setattr__(self, name, value):
|
||||
if name == "side_effect" and value is not None:
|
||||
# Convert list/tuple to infinite cycle
|
||||
if isinstance(value, (list, tuple)):
|
||||
value = cycle(value)
|
||||
super().__setattr__(name, value)
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
shared_query.where.return_value.first = CyclicMock()
|
||||
shared_filter_by.first = CyclicMock()
|
||||
|
||||
def _create_session():
|
||||
"""Create a new mock session for each create_session() call."""
|
||||
session = MagicMock()
|
||||
session.close = MagicMock()
|
||||
session.commit = MagicMock()
|
||||
|
||||
# Mock session.begin() context manager
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = session
|
||||
|
||||
def _begin_exit_side_effect(exc_type, exc, tb):
|
||||
# commit on success
|
||||
if exc_type is None:
|
||||
session.commit()
|
||||
# return False to propagate exceptions
|
||||
return False
|
||||
|
||||
begin_cm.__exit__.side_effect = _begin_exit_side_effect
|
||||
session.begin.return_value = begin_cm
|
||||
|
||||
# Mock create_session() context manager
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
|
||||
def _exit_side_effect(exc_type, exc, tb):
|
||||
session.close()
|
||||
return False
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
|
||||
# All sessions use the same shared query mocks
|
||||
session.query.return_value = shared_query
|
||||
shared_query.where.return_value = shared_query
|
||||
shared_query.filter_by.return_value = shared_filter_by
|
||||
session.scalars.return_value = shared_scalars_result
|
||||
|
||||
sessions.append(session)
|
||||
# Attach helpers on the first created session for assertions across all sessions
|
||||
if len(sessions) == 1:
|
||||
session.get_all_sessions = lambda: sessions
|
||||
session.any_close_called = lambda: any(s.close.called for s in sessions)
|
||||
session.any_commit_called = lambda: any(s.commit.called for s in sessions)
|
||||
return cm
|
||||
|
||||
mock_sf.create_session.side_effect = _create_session
|
||||
|
||||
# Create first session and return it
|
||||
_create_session()
|
||||
yield sessions[0]
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
session.scalars.return_value = MagicMock()
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -248,8 +186,8 @@ class TestDocumentIndexingSyncTask:
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert - at least one session should have been closed
|
||||
assert mock_db_session.any_close_called()
|
||||
# Assert
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||
"""Test that task raises error when notion_workspace_id is missing."""
|
||||
@@ -292,7 +230,6 @@ class TestDocumentIndexingSyncTask:
|
||||
"""Test that task handles missing credentials by updating document status."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_datasource_provider_service.get_datasource_credentials.return_value = None
|
||||
|
||||
# Act
|
||||
@@ -302,8 +239,8 @@ class TestDocumentIndexingSyncTask:
|
||||
assert mock_document.indexing_status == "error"
|
||||
assert "Datasource credential not found" in mock_document.error
|
||||
assert mock_document.stopped_at is not None
|
||||
assert mock_db_session.any_commit_called()
|
||||
assert mock_db_session.any_close_called()
|
||||
mock_db_session.commit.assert_called()
|
||||
mock_db_session.close.assert_called()
|
||||
|
||||
def test_page_not_updated(
|
||||
self,
|
||||
@@ -317,7 +254,6 @@ class TestDocumentIndexingSyncTask:
|
||||
"""Test that task does nothing when page has not been updated."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
# Return same time as stored in document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
|
||||
@@ -327,8 +263,8 @@ class TestDocumentIndexingSyncTask:
|
||||
# Assert
|
||||
# Document status should remain unchanged
|
||||
assert mock_document.indexing_status == "completed"
|
||||
# At least one session should have been closed via context manager teardown
|
||||
assert mock_db_session.any_close_called()
|
||||
# Session should still be closed via context manager teardown
|
||||
assert mock_db_session.close.called
|
||||
|
||||
def test_successful_sync_when_page_updated(
|
||||
self,
|
||||
@@ -345,20 +281,7 @@ class TestDocumentIndexingSyncTask:
|
||||
):
|
||||
"""Test successful sync flow when Notion page has been updated."""
|
||||
# Arrange
|
||||
# Set exact sequence of returns across calls to `.first()`:
|
||||
# 1) document (initial fetch)
|
||||
# 2) dataset (pre-check)
|
||||
# 3) dataset (cleaning phase)
|
||||
# 4) document (pre-indexing update)
|
||||
# 5) document (indexing runner fetch)
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document,
|
||||
]
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
# NotionExtractor returns updated time
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
@@ -376,40 +299,28 @@ class TestDocumentIndexingSyncTask:
        mock_processor.clean.assert_called_once()

        # Verify segments were deleted from database in batch (DELETE FROM document_segments)
        # Aggregate execute calls across all created sessions
        execute_sqls = []
        for s in mock_db_session.get_all_sessions():
            execute_sqls.extend([" ".join(str(c[0][0]).split()) for c in s.execute.call_args_list])
        execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
        assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)

        # Verify indexing runner was called
        mock_indexing_runner.run.assert_called_once_with([mock_document])

        # Verify session operations (across any created session)
        assert mock_db_session.any_commit_called()
        assert mock_db_session.any_close_called()
        # Verify session operations
        assert mock_db_session.commit.called
        mock_db_session.close.assert_called_once()

    def test_dataset_not_found_during_cleaning(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_indexing_runner,
        mock_document,
        dataset_id,
        document_id,
    ):
        """Test that task handles dataset not found during cleaning phase."""
        # Arrange
        # Sequence: document (initial), dataset (pre-check), None (cleaning), document (update), document (indexing)
        mock_db_session.query.return_value.where.return_value.first.side_effect = [
            mock_document,
            mock_dataset,
            None,
            mock_document,
            mock_document,
        ]
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None]
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"

        # Act
@@ -418,8 +329,8 @@ class TestDocumentIndexingSyncTask:
        # Assert
        # Document should still be set to parsing
        assert mock_document.indexing_status == "parsing"
        # At least one session should be closed after error
        assert mock_db_session.any_close_called()
        # Session should be closed after error
        mock_db_session.close.assert_called_once()

    def test_cleaning_error_continues_to_indexing(
        self,
@@ -435,14 +346,8 @@ class TestDocumentIndexingSyncTask:
    ):
        """Test that indexing continues even if cleaning fails."""
        # Arrange
        from itertools import cycle

        mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
        # Make the cleaning step fail but not the segment fetch
        processor = mock_index_processor_factory.return_value.init_index_processor.return_value
        processor.clean.side_effect = Exception("Cleaning error")
        mock_db_session.scalars.return_value.all.return_value = []
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"

        # Act
@@ -451,7 +356,7 @@ class TestDocumentIndexingSyncTask:
        # Assert
        # Indexing should still be attempted despite cleaning error
        mock_indexing_runner.run.assert_called_once_with([mock_document])
        assert mock_db_session.any_close_called()
        mock_db_session.close.assert_called_once()

    def test_indexing_runner_document_paused_error(
        self,
@@ -468,10 +373,7 @@ class TestDocumentIndexingSyncTask:
    ):
        """Test that DocumentIsPausedError is handled gracefully."""
        # Arrange
        from itertools import cycle

        mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
        mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
@@ -481,7 +383,7 @@ class TestDocumentIndexingSyncTask:

        # Assert
        # Session should be closed after handling error
        assert mock_db_session.any_close_called()
        mock_db_session.close.assert_called_once()

    def test_indexing_runner_general_error(
        self,
@@ -498,10 +400,7 @@ class TestDocumentIndexingSyncTask:
    ):
        """Test that general exceptions during indexing are handled."""
        # Arrange
        from itertools import cycle

        mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
        mock_indexing_runner.run.side_effect = Exception("Indexing error")
@@ -511,7 +410,7 @@ class TestDocumentIndexingSyncTask:

        # Assert
        # Session should be closed after error
        assert mock_db_session.any_close_called()
        mock_db_session.close.assert_called_once()

    def test_notion_extractor_initialized_with_correct_params(
        self,
@@ -618,14 +517,7 @@ class TestDocumentIndexingSyncTask:
    ):
        """Test that index processor clean is called with correct parameters."""
        # Arrange
        # Sequence: document (initial), dataset (pre-check), dataset (cleaning), document (update), document (indexing)
        mock_db_session.query.return_value.where.return_value.first.side_effect = [
            mock_document,
            mock_dataset,
            mock_dataset,
            mock_document,
            mock_document,
        ]
        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"

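The assertions in these hunks switch between `mock_db_session.close.assert_called_once()` and aggregated helpers such as `any_close_called()` / `get_all_sessions()`, which only differ when the task may open more than one session. A minimal sketch of a session-factory double that supports the aggregated style; the helper names mirror the tests, but this implementation is an illustration, not the project's actual fixture:

```python
from unittest.mock import MagicMock


class FakeSessionFactory:
    """Creates a fresh MagicMock session per call and remembers every one of them."""

    def __init__(self) -> None:
        self._sessions: list[MagicMock] = []

    def __call__(self) -> MagicMock:
        session = MagicMock()
        self._sessions.append(session)
        return session

    def get_all_sessions(self) -> list[MagicMock]:
        return list(self._sessions)

    def any_commit_called(self) -> bool:
        return any(s.commit.called for s in self._sessions)

    def any_close_called(self) -> bool:
        return any(s.close.called for s in self._sessions)
```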
@@ -1,5 +1,4 @@
import base64
from decimal import Decimal
from unittest.mock import Mock, patch

import pytest
@@ -10,10 +9,8 @@ from core.mcp.types import (
    CallToolResult,
    EmbeddedResource,
    ImageContent,
    TextContent,
    TextResourceContents,
)
from core.model_runtime.entities.llm_entities import LLMUsage
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
@@ -123,231 +120,3 @@ class TestMCPToolInvoke:
        # Validate values
        values = {m.message.variable_name: m.message.variable_value for m in var_msgs}
        assert values == {"a": 1, "b": "x"}


class TestMCPToolUsageExtraction:
    """Test usage metadata extraction from MCP tool results."""

    def test_extract_usage_dict_from_direct_usage_field(self) -> None:
        """Test extraction when usage is directly in meta.usage field."""
        meta = {
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150,
                "total_price": "0.001",
                "currency": "USD",
            }
        }
        usage_dict = MCPTool._extract_usage_dict(meta)
        assert usage_dict is not None
        assert usage_dict["prompt_tokens"] == 100
        assert usage_dict["completion_tokens"] == 50
        assert usage_dict["total_tokens"] == 150
        assert usage_dict["total_price"] == "0.001"
        assert usage_dict["currency"] == "USD"

    def test_extract_usage_dict_from_nested_metadata(self) -> None:
        """Test extraction when usage is nested in meta.metadata.usage."""
        meta = {
            "metadata": {
                "usage": {
                    "prompt_tokens": 200,
                    "completion_tokens": 100,
                    "total_tokens": 300,
                }
            }
        }
        usage_dict = MCPTool._extract_usage_dict(meta)
        assert usage_dict is not None
        assert usage_dict["prompt_tokens"] == 200
        assert usage_dict["total_tokens"] == 300

    def test_extract_usage_dict_from_flat_token_fields(self) -> None:
        """Test extraction when token counts are directly in meta."""
        meta = {
            "prompt_tokens": 150,
            "completion_tokens": 75,
            "total_tokens": 225,
            "currency": "EUR",
        }
        usage_dict = MCPTool._extract_usage_dict(meta)
        assert usage_dict is not None
        assert usage_dict["prompt_tokens"] == 150
        assert usage_dict["completion_tokens"] == 75
        assert usage_dict["total_tokens"] == 225
        assert usage_dict["currency"] == "EUR"

    def test_extract_usage_dict_recursive(self) -> None:
        """Test recursive search through nested structures."""
        meta = {
            "custom": {
                "nested": {
                    "usage": {
                        "total_tokens": 500,
                        "prompt_tokens": 300,
                        "completion_tokens": 200,
                    }
                }
            }
        }
        usage_dict = MCPTool._extract_usage_dict(meta)
        assert usage_dict is not None
        assert usage_dict["total_tokens"] == 500

    def test_extract_usage_dict_from_list(self) -> None:
        """Test extraction from nested list structures."""
        meta = {
            "items": [
                {"usage": {"total_tokens": 100}},
                {"other": "data"},
            ]
        }
        usage_dict = MCPTool._extract_usage_dict(meta)
        assert usage_dict is not None
        assert usage_dict["total_tokens"] == 100

    def test_extract_usage_dict_returns_none_when_missing(self) -> None:
        """Test that None is returned when no usage data is present."""
        meta = {"other": "data", "custom": {"nested": {"value": 123}}}
        usage_dict = MCPTool._extract_usage_dict(meta)
        assert usage_dict is None

    def test_extract_usage_dict_empty_meta(self) -> None:
        """Test with empty meta dict."""
        usage_dict = MCPTool._extract_usage_dict({})
        assert usage_dict is None

    def test_derive_usage_from_result_with_meta(self) -> None:
        """Test _derive_usage_from_result with populated meta."""
        meta = {
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150,
                "total_price": "0.0015",
                "currency": "USD",
            }
        }
        result = CallToolResult(content=[], _meta=meta)
        usage = MCPTool._derive_usage_from_result(result)

        assert isinstance(usage, LLMUsage)
        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 150
        assert usage.total_price == Decimal("0.0015")
        assert usage.currency == "USD"

    def test_derive_usage_from_result_without_meta(self) -> None:
        """Test _derive_usage_from_result with no meta returns empty usage."""
        result = CallToolResult(content=[], meta=None)
        usage = MCPTool._derive_usage_from_result(result)

        assert isinstance(usage, LLMUsage)
        assert usage.total_tokens == 0
        assert usage.prompt_tokens == 0
        assert usage.completion_tokens == 0

    def test_derive_usage_from_result_calculates_total_tokens(self) -> None:
        """Test that total_tokens is calculated when missing."""
        meta = {
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                # total_tokens is missing
            }
        }
        result = CallToolResult(content=[], _meta=meta)
        usage = MCPTool._derive_usage_from_result(result)

        assert usage.total_tokens == 150  # 100 + 50
        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50

    def test_invoke_sets_latest_usage_from_meta(self) -> None:
        """Test that _invoke sets _latest_usage from result meta."""
        tool = _make_mcp_tool()
        meta = {
            "usage": {
                "prompt_tokens": 200,
                "completion_tokens": 100,
                "total_tokens": 300,
                "total_price": "0.003",
                "currency": "USD",
            }
        }
        result = CallToolResult(content=[TextContent(type="text", text="test")], _meta=meta)

        with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
            list(tool._invoke(user_id="test_user", tool_parameters={}))

        # Verify latest_usage was set correctly
        assert tool.latest_usage.prompt_tokens == 200
        assert tool.latest_usage.completion_tokens == 100
        assert tool.latest_usage.total_tokens == 300
        assert tool.latest_usage.total_price == Decimal("0.003")

    def test_invoke_with_no_meta_returns_empty_usage(self) -> None:
        """Test that _invoke returns empty usage when no meta is present."""
        tool = _make_mcp_tool()
        result = CallToolResult(content=[TextContent(type="text", text="test")], _meta=None)

        with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
            list(tool._invoke(user_id="test_user", tool_parameters={}))

        # Verify latest_usage is empty
        assert tool.latest_usage.total_tokens == 0
        assert tool.latest_usage.prompt_tokens == 0
        assert tool.latest_usage.completion_tokens == 0

    def test_latest_usage_property_returns_llm_usage(self) -> None:
        """Test that latest_usage property returns LLMUsage instance."""
        tool = _make_mcp_tool()
        assert isinstance(tool.latest_usage, LLMUsage)

    def test_initial_usage_is_empty(self) -> None:
        """Test that MCPTool is initialized with empty usage."""
        tool = _make_mcp_tool()
        assert tool.latest_usage.total_tokens == 0
        assert tool.latest_usage.prompt_tokens == 0
        assert tool.latest_usage.completion_tokens == 0
        assert tool.latest_usage.total_price == Decimal(0)

    @pytest.mark.parametrize(
        "meta_data",
        [
            # Direct usage field
            {"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}},
            # Nested metadata
            {"metadata": {"usage": {"total_tokens": 100}}},
            # Flat token fields
            {"total_tokens": 50, "prompt_tokens": 30, "completion_tokens": 20},
            # With price info
            {
                "usage": {
                    "total_tokens": 150,
                    "total_price": "0.002",
                    "currency": "EUR",
                }
            },
            # Deep nested
            {"level1": {"level2": {"usage": {"total_tokens": 200}}}},
        ],
    )
    def test_various_meta_formats(self, meta_data) -> None:
        """Test that various meta formats are correctly parsed."""
        result = CallToolResult(content=[], _meta=meta_data)
        usage = MCPTool._derive_usage_from_result(result)

        assert isinstance(usage, LLMUsage)
        # Should have at least some usage data
        if meta_data.get("usage", {}).get("total_tokens") or meta_data.get("total_tokens"):
            expected_total = (
                meta_data.get("usage", {}).get("total_tokens")
                or meta_data.get("total_tokens")
                or meta_data.get("metadata", {}).get("usage", {}).get("total_tokens")
                or meta_data.get("level1", {}).get("level2", {}).get("usage", {}).get("total_tokens")
            )
            if expected_total:
                assert usage.total_tokens == expected_total

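The tests above pin down a lookup order for usage metadata: a top-level `usage` dict first, then `metadata.usage`, then flat token fields on the meta object itself, and finally a recursive search through nested dicts and lists. A sketch of an extractor with that contract, inferred from the tests rather than copied from the actual `MCPTool._extract_usage_dict` implementation:

```python
from typing import Any

_TOKEN_KEYS = ("prompt_tokens", "completion_tokens", "total_tokens")


def extract_usage_dict(meta: dict[str, Any]) -> dict[str, Any] | None:
    """Find a usage-like dict anywhere inside an MCP result's meta payload."""
    if not meta:
        return None
    # 1. Direct `usage` field, then nested `metadata.usage`.
    if isinstance(meta.get("usage"), dict):
        return meta["usage"]
    metadata = meta.get("metadata")
    if isinstance(metadata, dict) and isinstance(metadata.get("usage"), dict):
        return metadata["usage"]
    # 2. Flat token fields sitting directly on meta.
    if any(key in meta for key in _TOKEN_KEYS):
        return {k: v for k, v in meta.items() if not isinstance(v, (dict, list))}
    # 3. Recursive search through nested dicts and lists.
    for value in meta.values():
        if isinstance(value, dict):
            found = extract_usage_dict(value)
            if found:
                return found
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    found = extract_usage_dict(item)
                    if found:
                        return found
    return None
```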
api/uv.lock (generated)
@@ -1653,7 +1653,7 @@ requires-dist = [
    { name = "starlette", specifier = "==0.49.1" },
    { name = "tiktoken", specifier = "~=0.9.0" },
    { name = "transformers", specifier = "~=4.56.1" },
    { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.18.18" },
    { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" },
    { name = "weave", specifier = ">=0.52.16" },
    { name = "weaviate-client", specifier = "==4.17.0" },
    { name = "webvtt-py", specifier = "~=0.5.1" },
@@ -6814,12 +6814,12 @@ wheels = [

[[package]]
name = "unstructured"
version = "0.18.31"
version = "0.16.25"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "backoff" },
    { name = "beautifulsoup4" },
    { name = "charset-normalizer" },
    { name = "chardet" },
    { name = "dataclasses-json" },
    { name = "emoji" },
    { name = "filetype" },
@@ -6827,7 +6827,6 @@ dependencies = [
    { name = "langdetect" },
    { name = "lxml" },
    { name = "nltk" },
    { name = "numba" },
    { name = "numpy" },
    { name = "psutil" },
    { name = "python-iso639" },
@@ -6840,9 +6839,9 @@ dependencies = [
    { name = "unstructured-client" },
    { name = "wrapt" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a9/5f/64285bd69a538bc28753f1423fcaa9d64cd79a9e7c097171b1f0d27e9cdb/unstructured-0.18.31.tar.gz", hash = "sha256:af4bbe32d1894ae6e755f0da6fc0dd307a1d0adeebe0e7cc6278f6cf744339ca", size = 1707700, upload-time = "2026-01-27T15:33:05.378Z" }
sdist = { url = "https://files.pythonhosted.org/packages/64/31/98c4c78e305d1294888adf87fd5ee30577a4c393951341ca32b43f167f1e/unstructured-0.16.25.tar.gz", hash = "sha256:73b9b0f51dbb687af572ecdb849a6811710b9cac797ddeab8ee80fa07d8aa5e6", size = 1683097, upload-time = "2025-03-07T11:19:39.507Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/c8/4a/9c43f39d9e443c9bc3f2e379b305bca27110adc653b071221b3132c18de5/unstructured-0.18.31-py3-none-any.whl", hash = "sha256:fab4641176cb9b192ed38048758aa0d9843121d03626d18f42275afb31e5b2d3", size = 1794889, upload-time = "2026-01-27T15:33:03.136Z" },
    { url = "https://files.pythonhosted.org/packages/12/4f/ad08585b5c8a33c82ea119494c4d3023f4796958c56e668b15cc282ec0a0/unstructured-0.16.25-py3-none-any.whl", hash = "sha256:14719ccef2830216cf1c5bf654f75e2bf07b17ca5dcee9da5ac74618130fd337", size = 1769286, upload-time = "2025-03-07T11:19:37.299Z" },
]

[package.optional-dependencies]

@@ -109,7 +109,6 @@ const AgentTools: FC = () => {
        tool_parameters: paramsWithDefaultValue,
        notAuthor: !tool.is_team_authorization,
        enabled: true,
        type: tool.provider_type as CollectionType,
      }
    }
  const handleSelectTool = (tool: ToolDefaultValue) => {

@@ -194,11 +194,11 @@ const ConfigContent: FC<Props> = ({
        </div>
        {type === RETRIEVE_TYPE.multiWay && (
          <>
            <div className="my-2 flex flex-col items-center py-1">
              <div className="system-xs-semibold-uppercase mb-2 mr-2 shrink-0 text-text-secondary">
            <div className="my-2 flex h-6 items-center py-1">
              <div className="system-xs-semibold-uppercase mr-2 shrink-0 text-text-secondary">
                {t('rerankSettings', { ns: 'dataset' })}
              </div>
              <Divider bgStyle="gradient" className="m-0 !h-px" />
              <Divider bgStyle="gradient" className="mx-0 !h-px" />
            </div>
            {
              selectedDatasetsMode.inconsistentEmbeddingModel

@@ -7,7 +7,6 @@ import dayjs from 'dayjs'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import DatePicker from '@/app/components/base/date-and-time-picker/date-picker'
import { useAppContext } from '@/context/app-context'
import useTimestamp from '@/hooks/use-timestamp'
import { cn } from '@/utils/classnames'

@@ -22,7 +21,7 @@ const WrappedDatePicker = ({
  onChange,
}: Props) => {
  const { t } = useTranslation()
  const { userProfile: { timezone } } = useAppContext()
  // const { userProfile: { timezone } } = useAppContext()
  const { formatTime: formatTimestamp } = useTimestamp()

  const handleDateChange = useCallback((date?: dayjs.Dayjs) => {
@@ -65,7 +64,6 @@ const WrappedDatePicker = ({
  return (
    <DatePicker
      value={dayjs(value ? value * 1000 : Date.now())}
      timezone={timezone}
      onChange={handleDateChange}
      onClear={handleDateChange}
      renderTrigger={renderTrigger}

@@ -273,71 +273,6 @@ The text generation application offers non-session support and is ideal for translation, a
</Row>
---

<Heading
  url='/end-users/:end_user_id'
  method='GET'
  title='Get End User'
  name='#end-user'
/>
<Row>
  <Col>
    Retrieve an end user by ID.

    This is useful when other APIs return an end-user ID (e.g. `created_by` from File Upload).

    ### Path Parameters
    - `end_user_id` (uuid) Required
      End user ID.

    ### Response
    Returns an EndUser object.
    - `id` (uuid) ID
    - `tenant_id` (uuid) Tenant ID
    - `app_id` (uuid) App ID
    - `type` (string) End user type
    - `external_user_id` (string) External user ID
    - `name` (string) Name
    - `is_anonymous` (boolean) Whether anonymous
    - `session_id` (string) Session ID
    - `created_at` (string) ISO 8601 datetime
    - `updated_at` (string) ISO 8601 datetime

    ### Errors
    - 404, `end_user_not_found`, end user not found
    - 500, internal server error

  </Col>
  <Col sticky>
    ### Request Example
    <CodeGroup
      title="Request"
      tag="GET"
      label="/end-users/:end_user_id"
      targetCode={`curl -X GET '${props.appDetail.api_base_url}/end-users/6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13' \\
      --header 'Authorization: Bearer {api_key}'`}
    />

    ### Response Example
    <CodeGroup title="Response">
    ```json {{ title: 'Response' }}
    {
      "id": "6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13",
      "tenant_id": "8c0f3f3a-66b0-4b55-a0bf-8b8e0d6aee7d",
      "app_id": "6c8c3f41-2c6f-4e1b-8f4f-7f11c8f2ad2a",
      "type": "service_api",
      "external_user_id": "abc-123",
      "name": "Alice",
      "is_anonymous": false,
      "session_id": "abc-123",
      "created_at": "2024-01-01T00:00:00Z",
      "updated_at": "2024-01-01T00:00:00Z"
    }
    ```
    </CodeGroup>
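    A minimal Python example of the same call (illustrative only: the base URL, API key, and end-user ID below are placeholders; the endpoint path and response fields are as documented above):
    ```python
    import requests

    BASE_URL = "https://your-dify-host/v1"  # placeholder
    API_KEY = "your-api-key"  # placeholder
    end_user_id = "6ad1ab0a-73ff-4ac1-b9e4-cdb312f71f13"

    resp = requests.get(
        f"{BASE_URL}/end-users/{end_user_id}",
        headers={"Authorization": f"Bearer {API_KEY}"},
        timeout=10,
    )
    resp.raise_for_status()  # 404 end_user_not_found / 500 surface here
    end_user = resp.json()
    print(end_user["external_user_id"], end_user["is_anonymous"])
    ```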
  </Col>
</Row>
---

<Heading
  url='/files/:file_id/preview'
  method='GET'

Some files were not shown because too many files have changed in this diff.