Mirror of https://github.com/langgenius/dify.git (synced 2026-04-11 11:49:23 +08:00)

Compare commits
1 commit: 1.11.1 ... feat/fallb
| Author | SHA1 | Date |
|---|---|---|
| | 61a6c6dbcf | |
8 .github/CODEOWNERS (vendored)
@@ -9,14 +9,6 @@
# Backend (default owner, more specific rules below will override)
api/ @QuantumGhost

# Backend - MCP
api/core/mcp/ @Nov1c444
api/core/entities/mcp_provider.py @Nov1c444
api/services/tools/mcp_tools_manage_service.py @Nov1c444
api/controllers/mcp/ @Nov1c444
api/controllers/console/app/mcp_server.py @Nov1c444
api/tests/**/*mcp* @Nov1c444

# Backend - Workflow - Engine (Core graph execution engine)
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
14 .github/ISSUE_TEMPLATE/refactor.yml (vendored)
@@ -1,6 +1,8 @@
name: "✨ Refactor or Chore"
description: Refactor existing code or perform maintenance chores to improve readability and reliability.
title: "[Refactor/Chore] "
name: "✨ Refactor"
description: Refactor existing code for improved readability and maintainability.
title: "[Chore/Refactor] "
labels:
- refactor
body:
- type: checkboxes
attributes:
@@ -9,7 +11,7 @@ body:
options:
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true
- label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
- label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
required: true
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
@@ -23,14 +25,14 @@ body:
id: description
attributes:
label: Description
placeholder: "Describe the refactor or chore you are proposing."
placeholder: "Describe the refactor you are proposing."
validations:
required: true
- type: textarea
id: motivation
attributes:
label: Motivation
placeholder: "Explain why this refactor or chore is necessary."
placeholder: "Explain why this refactor is necessary."
validations:
required: false
- type: textarea
13 .github/ISSUE_TEMPLATE/tracker.yml (vendored, new file)
@@ -0,0 +1,13 @@
name: "👾 Tracker"
description: For inner usages, please do not use this template.
title: "[Tracker] "
labels:
- tracker
body:
- type: textarea
id: content
attributes:
label: Blockers
placeholder: "- [ ] ..."
validations:
required: true
21 .github/workflows/semantic-pull-request.yml (vendored)
@@ -1,21 +0,0 @@
name: Semantic Pull Request

on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize

jobs:
lint:
name: Validate PR title
permissions:
pull-requests: read
runs-on: ubuntu-latest
steps:
- name: Check title
uses: amannn/action-semantic-pull-request@v6.1.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -654,9 +654,3 @@ TENANT_ISOLATED_TASK_CONCURRENCY=1

# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0

# Multimodal knowledgebase limit
SINGLE_CHUNK_ATTACHMENT_LIMIT=10
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
IMAGE_FILE_BATCH_LIMIT=10
@@ -360,26 +360,6 @@ class FileUploadConfig(BaseSettings):
default=10,
)

IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a image batch upload operation",
default=10,
)

SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a single chunk attachment",
default=10,
)

ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed image file size for attachments in megabytes",
default=2,
)

ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
description="Timeout for downloading image attachments in seconds",
default=60,
)

inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "
@@ -61,7 +61,6 @@ class ChatMessagesQuery(BaseModel):
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")

@field_validator("message_id")
@classmethod
@@ -325,7 +324,6 @@ class MessageFeedbackApi(Resource):
db.session.delete(feedback)
elif args.rating and feedback:
feedback.rating = args.rating
feedback.content = args.content
elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
@@ -337,7 +335,6 @@ class MessageFeedbackApi(Resource):
conversation_id=message.conversation_id,
message_id=message.id,
rating=rating_value,
content=args.content,
from_source="admin",
from_account_id=current_user.id,
)
@@ -114,7 +114,7 @@ class AppTriggersApi(Resource):

@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
@console_ns.expect(console_ns.models[ParserEnable.__name__])
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
@setup_required
@login_required
@account_initialization_required
@@ -151,7 +151,6 @@ class DatasetUpdatePayload(BaseModel):
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
icon_info: dict[str, Any] | None = None
is_multimodal: bool | None = False

@field_validator("indexing_technique")
@classmethod
@@ -422,18 +421,19 @@ class DatasetApi(Resource):
raise NotFound("Dataset not found.")

payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
payload_data = payload.model_dump(exclude_unset=True)
current_user, current_tenant_id = current_account_with_tenant()

# check embedding model setting
if (
payload.indexing_technique == "high_quality"
and payload.embedding_model_provider is not None
and payload.embedding_model is not None
):
is_multimodal = DatasetService.check_is_multimodal_model(
DatasetService.check_embedding_model_setting(
dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
)
payload.is_multimodal = is_multimodal
payload_data = payload.model_dump(exclude_unset=True)

# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, payload.permission, payload.partial_member_list
@@ -424,10 +424,6 @@ class DatasetInitApi(Resource):
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_config.embedding_model,
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
)
knowledge_config.is_multimodal = is_multimodal
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
@@ -51,7 +51,6 @@ class SegmentCreatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None

class SegmentUpdatePayload(BaseModel):
@@ -59,7 +58,6 @@ class SegmentUpdatePayload(BaseModel):
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None

class BatchImportPayload(BaseModel):
@@ -1,7 +1,7 @@
import logging
from typing import Any

from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

@@ -33,7 +33,6 @@ class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
external_retrieval_model: dict[str, Any] | None = None
attachment_ids: list[str] | None = None

class DatasetsHitTestingBase:
@@ -55,28 +54,16 @@ class DatasetsHitTestingBase:
def hit_testing_args_check(args: dict[str, Any]):
HitTestingService.hit_testing_args_check(args)

@staticmethod
def parse_args():
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, required=False, location="json")
.add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
return parser.parse_args()

@staticmethod
def perform_hit_testing(dataset, args):
assert isinstance(current_user, Account)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args.get("query"),
query=args["query"],
account=current_user,
retrieval_model=args.get("retrieval_model"),
external_retrieval_model=args.get("external_retrieval_model"),
attachment_ids=args.get("attachment_ids"),
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
@@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D

@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource):
@console_ns.expect(console_ns.models[Parser.__name__])
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
@setup_required
@login_required
@account_initialization_required
@@ -2,7 +2,7 @@ import logging
from typing import Any, Literal
from uuid import UUID

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError, NotFound

import services
@@ -52,24 +52,10 @@ class ChatMessagePayload(BaseModel):
inputs: dict[str, Any]
query: str
files: list[dict[str, Any]] | None = None
conversation_id: str | None = None
parent_message_id: str | None = None
conversation_id: UUID | None = None
parent_message_id: UUID | None = None
retriever_from: str = Field(default="explore_app")

@field_validator("conversation_id", "parent_message_id", mode="before")
@classmethod
def normalize_uuid(cls, value: str | UUID | None) -> str | None:
"""
Accept blank IDs and validate UUID format when provided.
"""
if not value:
return None

try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("must be a valid UUID") from exc

register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
@@ -3,7 +3,7 @@ from uuid import UUID

from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

@@ -30,16 +30,9 @@ class ConversationListQuery(BaseModel):

class ConversationRenamePayload(BaseModel):
name: str | None = None
name: str
auto_generate: bool = False

@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self

register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
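Illustrative sketch (not part of the diff): how a "name required unless auto_generate" rule behaves with a Pydantic model_validator; the class name RenamePayload and the sample values are invented stand-ins for the payload above.

from pydantic import BaseModel, model_validator

class RenamePayload(BaseModel):  # hypothetical stand-in for ConversationRenamePayload
    name: str | None = None
    auto_generate: bool = False

    @model_validator(mode="after")
    def validate_name_requirement(self):
        if not self.auto_generate and (self.name is None or not self.name.strip()):
            raise ValueError("name is required when auto_generate is false")
        return self

RenamePayload(auto_generate=True)    # accepted: name may be omitted when auto-generating
RenamePayload(name="Weekly sync")    # accepted: explicit name given
# RenamePayload(auto_generate=False) # would raise: name is required when auto_generate is false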
@@ -45,9 +45,6 @@ class FileApi(Resource):
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
"image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
"single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
"attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
}, 200

@setup_required
@@ -230,7 +230,7 @@ class ModelProviderModelApi(Resource):

return {"result": "success"}, 200

@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True)
@setup_required
@login_required
@is_admin_or_owner_required
@@ -282,10 +282,9 @@ class ModelProviderModelCredentialApi(Resource):
tenant_id=tenant_id, provider_name=provider
)
else:
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
model_type = args.model_type
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
)

return jsonable_encoder(
@@ -4,7 +4,7 @@ from uuid import UUID

from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound

import services
@@ -52,23 +52,11 @@ class ChatRequestPayload(BaseModel):
query: str
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: str | None = Field(default=None, description="Conversation UUID")
conversation_id: UUID | None = None
retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")

@field_validator("conversation_id", mode="before")
@classmethod
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
if not value:
return None

try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("conversation_id must be a valid UUID") from exc

register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
@@ -4,7 +4,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from flask_restx._http import HTTPStatus
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound

@@ -37,16 +37,9 @@ class ConversationListQuery(BaseModel):

class ConversationRenamePayload(BaseModel):
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
name: str = Field(description="New conversation name")
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")

@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self

class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
@@ -62,7 +62,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@@ -72,7 +73,7 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowNodeExecutionModel

logger = logging.getLogger(__name__)

@@ -580,7 +581,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):

with self._database_session() as session:
# Save message
self._save_message(session=session, graph_runtime_state=resolved_state)
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)

yield workflow_finish_resp
elif event.stopped_by in (
@@ -590,7 +591,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# When hitting input-moderation or annotation-reply, the workflow will not start
with self._database_session() as session:
# Save message
self._save_message(session=session)
self._save_message(session=session, trace_manager=trace_manager)

yield self._message_end_to_stream_response()

@@ -599,6 +600,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
event: QueueAdvancedChatMessageEndEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
@@ -616,7 +618,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):

# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)

yield self._message_end_to_stream_response()

@@ -770,7 +772,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if self._conversation_name_generate_thread:
logger.debug("Conversation name generation running as daemon thread")

def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
def _save_message(
self,
*,
session: Session,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
):
message = self._get_message(session=session)

# If there are assistant files, remove markdown image links from answer
@@ -809,6 +817,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):

metadata = self._task_state.metadata.model_dump()
message.message_metadata = json.dumps(jsonable_encoder(metadata))

# Extract model provider and model_id from workflow node executions for tracing
if message.workflow_run_id:
model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
if model_info:
message.model_provider = model_info.get("provider")
message.model_id = model_info.get("model")

message_files = [
MessageFile(
message_id=message.id,
@@ -826,6 +842,68 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)

# Trigger MESSAGE_TRACE for tracing integrations
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
)
)

def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
"""
Extract model provider and model_id from workflow node executions.
Returns dict with 'provider' and 'model' keys, or None if not found.
"""
try:
# Query workflow node executions for LLM or Agent nodes
stmt = (
select(WorkflowNodeExecutionModel)
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
.order_by(WorkflowNodeExecutionModel.created_at.desc())
.limit(1)
)
node_execution = session.scalar(stmt)

if not node_execution:
return None

# Try to extract from execution_metadata for agent nodes
if node_execution.execution_metadata:
try:
metadata = json.loads(node_execution.execution_metadata)
agent_log = metadata.get("agent_log", [])
# Look for the first agent thought with provider info
for log_entry in agent_log:
entry_metadata = log_entry.get("metadata", {})
provider_str = entry_metadata.get("provider")
if provider_str:
# Parse format like "langgenius/deepseek/deepseek"
parts = provider_str.split("/")
if len(parts) >= 3:
return {"provider": parts[1], "model": parts[2]}
elif len(parts) == 2:
return {"provider": parts[0], "model": parts[1]}
except (json.JSONDecodeError, KeyError, AttributeError) as e:
logger.debug("Failed to parse execution_metadata: %s", e)

# Try to extract from process_data for llm nodes
if node_execution.process_data:
try:
process_data = json.loads(node_execution.process_data)
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider and model:
return {"provider": provider, "model": model}
except (json.JSONDecodeError, KeyError) as e:
logger.debug("Failed to parse process_data: %s", e)

return None
except Exception as e:
logger.warning("Failed to extract model info from workflow: %s", e)
return None

def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
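Standalone sketch (not part of the diff) of how the provider-string parsing above resolves the two shapes it handles; the helper name and sample strings are invented for illustration.

def split_provider_string(provider_str: str) -> dict[str, str] | None:
    parts = provider_str.split("/")
    if len(parts) >= 3:  # e.g. "langgenius/deepseek/deepseek" -> plugin org / provider / model
        return {"provider": parts[1], "model": parts[2]}
    if len(parts) == 2:  # e.g. "deepseek/deepseek-chat" -> provider / model
        return {"provider": parts[0], "model": parts[1]}
    return None

assert split_provider_string("langgenius/deepseek/deepseek") == {"provider": "deepseek", "model": "deepseek"}
assert split_provider_string("deepseek/deepseek-chat") == {"provider": "deepseek", "model": "deepseek-chat"}
assert split_provider_string("deepseek") is None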
@@ -83,7 +83,6 @@ class AppRunner:
context: str | None = None,
memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
"""
Organize prompt messages
@@ -112,7 +111,6 @@ class AppRunner:
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
context_files=context_files,
)
else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
@@ -11,7 +11,6 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
@@ -147,7 +146,6 @@ class ChatAppRunner(AppRunner):

# get context from datasets
context = None
context_files: list[File] = []
if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
@@ -158,7 +156,7 @@ class ChatAppRunner(AppRunner):
)

dataset_retrieval = DatasetRetrieval(application_generate_entity)
context, retrieved_files = dataset_retrieval.retrieve(
context = dataset_retrieval.retrieve(
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
@@ -173,11 +171,7 @@ class ChatAppRunner(AppRunner):
memory=memory,
message_id=message.id,
inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
),
)
context_files = retrieved_files or []

# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@@ -192,7 +186,6 @@ class ChatAppRunner(AppRunner):
context=context,
memory=memory,
image_detail_config=image_detail_config,
context_files=context_files,
)

# check hosting moderation
@@ -10,7 +10,6 @@ from core.app.entities.app_invoke_entities import (
CompletionAppGenerateEntity,
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.moderation.base import ModerationError
@@ -103,7 +102,6 @@ class CompletionAppRunner(AppRunner):

# get context from datasets
context = None
context_files: list[File] = []
if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
@@ -118,7 +116,7 @@ class CompletionAppRunner(AppRunner):
query = inputs.get(dataset_config.retrieve_config.query_variable, "")

dataset_retrieval = DatasetRetrieval(application_generate_entity)
context, retrieved_files = dataset_retrieval.retrieve(
context = dataset_retrieval.retrieve(
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
@@ -132,11 +130,7 @@ class CompletionAppRunner(AppRunner):
hit_callback=hit_callback,
message_id=message.id,
inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
),
)
context_files = retrieved_files or []

# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@@ -150,7 +144,6 @@ class CompletionAppRunner(AppRunner):
query=query,
context=context,
image_detail_config=image_detail_config,
context_files=context_files,
)

# check hosting moderation
@@ -40,6 +40,9 @@ class EasyUITaskState(TaskState):
"""

llm_result: LLMResult
first_token_time: float | None = None
last_token_time: float | None = None
is_streaming_response: bool = False

class WorkflowTaskState(TaskState):
@@ -332,6 +332,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages

# Track streaming response times
if self._task_state.first_token_time is None:
self._task_state.first_token_time = time.perf_counter()
self._task_state.is_streaming_response = True
self._task_state.last_token_time = time.perf_counter()

# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
@@ -398,6 +404,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency

# Add streaming metrics to usage if available
if self._task_state.is_streaming_response and self._task_state.first_token_time:
start_time = self.start_at
first_token_time = self._task_state.first_token_time
last_token_time = self._task_state.last_token_time or first_token_time
usage.time_to_first_token = round(first_token_time - start_time, 3)
usage.time_to_generate = round(last_token_time - first_token_time, 3)

# Update metadata with the complete usage info
self._task_state.metadata.usage = usage

message.message_metadata = self._task_state.metadata.model_dump_json()

if trace_manager:
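Minimal sketch (not part of the diff) of the streaming-latency arithmetic above, using made-up timestamps; in the pipeline these values come from time.perf_counter().

start_at = 100.000          # request start
first_token_time = 100.412  # first streamed chunk observed
last_token_time = 103.275   # final streamed chunk observed

time_to_first_token = round(first_token_time - start_at, 3)      # 0.412 seconds
time_to_generate = round(last_token_time - first_token_time, 3)  # 2.863 seconds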
@@ -7,7 +7,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
@@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
document_id,
)
continue
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel

class PreviewDetail(BaseModel):
@@ -20,7 +20,7 @@ class IndexingEstimate(BaseModel):
class PipelineDataset(BaseModel):
id: str
name: str
description: str | None = Field(default="", description="knowledge dataset description")
description: str
chunk_structure: str
@@ -213,23 +213,12 @@ class MCPProviderEntity(BaseModel):
return None

def retrieve_tokens(self) -> OAuthTokens | None:
"""Retrieve OAuth tokens if authentication is complete.

Returns:
OAuthTokens if the provider has been authenticated, None otherwise.
"""
"""OAuth tokens if available"""
if not self.credentials:
return None
credentials = self.decrypt_credentials()
access_token = credentials.get("access_token", "")
# Return None if access_token is empty to avoid generating invalid "Authorization: Bearer " header.
# Note: We don't check for whitespace-only strings here because:
# 1. OAuth servers don't return whitespace-only access tokens in practice
# 2. Even if they did, the server would return 401, triggering the OAuth flow correctly
if not access_token:
return None
return OAuthTokens(
access_token=access_token,
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
refresh_token=credentials.get("refresh_token", ""),
@@ -7,7 +7,7 @@ import time
import uuid
from typing import Any

from flask import Flask, current_app
from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError

@@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
@@ -36,7 +36,6 @@ from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from libs.datetime_utils import naive_utc_now
from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
@@ -90,17 +89,8 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())

# transform
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
if not current_user:
raise ValueError("no current user found")
current_user.set_tenant_id(dataset.tenant_id)
documents = self._transform(
index_processor,
dataset,
text_docs,
requeried_document.doc_language,
processing_rule.to_dict(),
current_user=current_user,
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
)
# save segment
self._load_segments(dataset, requeried_document, documents)
@@ -146,7 +136,7 @@ class IndexingRunner:

for document_segment in document_segments:
db.session.delete(document_segment)
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
@@ -162,17 +152,8 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())

# transform
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
if not current_user:
raise ValueError("no current user found")
current_user.set_tenant_id(dataset.tenant_id)
documents = self._transform(
index_processor,
dataset,
text_docs,
requeried_document.doc_language,
processing_rule.to_dict(),
current_user=current_user,
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
)
# save segment
self._load_segments(dataset, requeried_document, documents)
@@ -228,7 +209,7 @@ class IndexingRunner:
"dataset_id": document_segment.dataset_id,
},
)
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = document_segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -321,7 +302,6 @@ class IndexingRunner:
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
documents = index_processor.transform(
text_docs,
current_user=None,
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule.to_dict(),
tenant_id=tenant_id,
@@ -571,10 +551,7 @@ class IndexingRunner:
indexing_start_at = time.perf_counter()
tokens = 0
create_keyword_thread = None
if (
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
):
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
@@ -613,7 +590,7 @@ class IndexingRunner:
for future in futures:
tokens += future.result()
if (
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
and create_keyword_thread is not None
):
@@ -658,13 +635,7 @@ class IndexingRunner:
db.session.commit()

def _process_chunk(
self,
flask_app: Flask,
index_processor: BaseIndexProcessor,
chunk_documents: list[Document],
dataset: Dataset,
dataset_document: DatasetDocument,
embedding_model_instance: ModelInstance | None,
self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance
):
with flask_app.app_context():
# check document is paused
@@ -675,15 +646,8 @@ class IndexingRunner:
page_content_list = [document.page_content for document in chunk_documents]
tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))

multimodal_documents = []
for document in chunk_documents:
if document.attachments and dataset.is_multimodal:
multimodal_documents.extend(document.attachments)

# load index
index_processor.load(
dataset, chunk_documents, multimodal_documents=multimodal_documents, with_keywords=False
)
index_processor.load(dataset, chunk_documents, with_keywords=False)

document_ids = [document.metadata["doc_id"] for document in chunk_documents]
db.session.query(DocumentSegment).where(
@@ -746,7 +710,6 @@ class IndexingRunner:
text_docs: list[Document],
doc_language: str,
process_rule: dict,
current_user: Account | None = None,
) -> list[Document]:
# get embedding model instance
embedding_model_instance = None
@@ -766,7 +729,6 @@ class IndexingRunner:

documents = index_processor.transform(
text_docs,
current_user,
embedding_model_instance=embedding_model_instance,
process_rule=process_rule,
tenant_id=dataset.tenant_id,
@@ -775,16 +737,14 @@ class IndexingRunner:

return documents

def _load_segments(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]):
def _load_segments(self, dataset, dataset_document, documents):
# save node to document segment
doc_store = DatasetDocumentStore(
dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
)

# add document segments
doc_store.add_documents(
docs=documents, save_child=dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX
)
doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)

# update document status to indexing
cur_time = naive_utc_now()
@@ -554,16 +554,11 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)

generated_raw = response.message.get_text_content()
generated_raw = cast(str, response.message.content)
first_brace = generated_raw.find("{")
last_brace = generated_raw.rfind("}")
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
json_str = generated_raw[first_brace : last_brace + 1]
data = json_repair.loads(json_str)
if not isinstance(data, dict):
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
return data
return {**json.loads(generated_raw[first_brace : last_brace + 1])}

except InvokeError as e:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}
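Standalone sketch (not part of the diff) of the brace-slicing JSON extraction used above; the sample response text is invented, and json_repair tolerates minor JSON damage where json.loads would not.

import json_repair

raw = 'Sure! Here is the result:\n{"name": "demo", "score": 0.9}\nHope that helps.'
first_brace = raw.find("{")
last_brace = raw.rfind("}")
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
    raise ValueError("Could not find a valid JSON object in response")
data = json_repair.loads(raw[first_brace : last_brace + 1])  # parse only the braced span
assert isinstance(data, dict) and data["name"] == "demo"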
@@ -10,9 +10,9 @@ from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
@@ -200,7 +200,7 @@ class ModelInstance:

def invoke_text_embedding(
self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
) -> EmbeddingResult:
) -> TextEmbeddingResult:
"""
Invoke large language model

@@ -212,7 +212,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
EmbeddingResult,
TextEmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
@@ -223,34 +223,6 @@ class ModelInstance:
),
)

def invoke_multimodal_embedding(
self,
multimodel_documents: list[dict],
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> EmbeddingResult:
"""
Invoke large language model

:param multimodel_documents: multimodel documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
user=user,
input_type=input_type,
),
)

def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
"""
Get number of tokens for text embedding
@@ -304,40 +276,6 @@ class ModelInstance:
),
)

def invoke_multimodal_rerank(
self,
query: dict,
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model

:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)

def invoke_moderation(self, text: str, user: str | None = None) -> bool:
"""
Invoke moderation model
@@ -523,32 +461,6 @@ class ModelManager:
model=default_model_entity.model,
)

def check_model_support_vision(self, tenant_id: str, provider: str, model: str, model_type: ModelType) -> bool:
"""
Check if model supports vision
:param tenant_id: tenant id
:param provider: provider name
:param model: model name
:return: True if model supports vision, False otherwise
"""
model_instance = self.get_model_instance(tenant_id, provider, model_type, model)
model_type_instance = model_instance.model_type_instance
match model_type:
case ModelType.LLM:
model_type_instance = cast(LargeLanguageModel, model_type_instance)
case ModelType.TEXT_EMBEDDING:
model_type_instance = cast(TextEmbeddingModel, model_type_instance)
case ModelType.RERANK:
model_type_instance = cast(RerankModel, model_type_instance)
case _:
raise ValueError(f"Model type {model_type} is not supported")
model_schema = model_type_instance.get_model_schema(model, model_instance.credentials)
if not model_schema:
return False
if model_schema.features and ModelFeature.VISION in model_schema.features:
return True
return False

class LBModelManager:
def __init__(
@@ -19,7 +19,7 @@ class EmbeddingUsage(ModelUsage):
latency: float

class EmbeddingResult(BaseModel):
class TextEmbeddingResult(BaseModel):
"""
Model class for text embedding result.
"""
@@ -27,13 +27,3 @@ class EmbeddingResult(BaseModel):
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage

class FileEmbeddingResult(BaseModel):
"""
Model class for file embedding result.
"""

model: str
embeddings: list[list[float]]
usage: EmbeddingUsage
@@ -50,43 +50,3 @@ class RerankModel(AIModel):
)
except Exception as e:
raise self._transform_invoke_error(e)

def invoke_multimodal_rerank(
self,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke multimodal rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
try:
from core.plugin.impl.model import PluginModelClient

plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_multimodal_rerank(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
except Exception as e:
raise self._transform_invoke_error(e)
@@ -2,7 +2,7 @@ from pydantic import ConfigDict

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel

@@ -20,18 +20,16 @@ class TextEmbeddingModel(AIModel):
self,
model: str,
credentials: dict,
texts: list[str] | None = None,
multimodel_documents: list[dict] | None = None,
texts: list[str],
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> EmbeddingResult:
) -> TextEmbeddingResult:
"""
Invoke text embedding model

:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param files: files to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
@@ -40,29 +38,16 @@ class TextEmbeddingModel(AIModel):

try:
plugin_model_manager = PluginModelClient()
if texts:
return plugin_model_manager.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
if multimodel_documents:
return plugin_model_manager.invoke_multimodal_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
documents=multimodel_documents,
input_type=input_type,
)
raise ValueError("No texts or files provided")
return plugin_model_manager.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
except Exception as e:
raise self._transform_invoke_error(e)
@@ -222,6 +222,59 @@ class TencentSpanBuilder:
links=links,
)

@staticmethod
def build_message_llm_span(
trace_info: MessageTraceInfo, trace_id: int, parent_span_id: int, user_id: str
) -> SpanData:
"""Build LLM span for message traces with detailed LLM attributes."""
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)

# Extract model information from `metadata`` or `message_data`
trace_metadata = trace_info.metadata or {}
message_data = trace_info.message_data or {}

model_provider = trace_metadata.get("ls_provider") or (
message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
)
model_name = trace_metadata.get("ls_model_name") or (
message_data.get("model_id", "") if isinstance(message_data, dict) else ""
)

inputs_str = str(trace_info.inputs or "")
outputs_str = str(trace_info.outputs or "")

attributes = {
GEN_AI_SESSION_ID: trace_metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: str(model_name),
GEN_AI_PROVIDER: str(model_provider),
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens or 0),
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens or 0),
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens or 0),
GEN_AI_PROMPT: inputs_str,
GEN_AI_COMPLETION: outputs_str,
INPUT_VALUE: inputs_str,
OUTPUT_VALUE: outputs_str,
}

if trace_info.is_streaming_request:
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"

return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "llm"),
name="GENERATION",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes=attributes,
status=status,
)

@staticmethod
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build tool span."""

@@ -107,9 +107,13 @@ class TencentDataTrace(BaseTraceInstance):
links.append(TencentTraceUtils.create_link(trace_info.trace_id))

message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)

self.trace_client.add_span(message_span)

# Add LLM child span with detailed attributes
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
llm_span = TencentSpanBuilder.build_message_llm_span(trace_info, trace_id, parent_span_id, str(user_id))
self.trace_client.add_span(llm_span)

self._record_message_llm_metrics(trace_info)

# Record trace duration for entry span
@@ -6,7 +6,7 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
@@ -243,14 +243,14 @@ class PluginModelClient(BasePluginClient):
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
input_type: str,
|
||||
) -> EmbeddingResult:
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
|
||||
type_=EmbeddingResult,
|
||||
type_=TextEmbeddingResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
@@ -275,48 +275,6 @@ class PluginModelClient(BasePluginClient):
|
||||
|
||||
raise ValueError("Failed to invoke text embedding")
|
||||
|
def invoke_multimodal_embedding(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
documents: list[dict],
input_type: str,
) -> EmbeddingResult:
"""
Invoke file embedding
"""
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
"provider": provider,
"model_type": "text-embedding",
"model": model,
"credentials": credentials,
"documents": documents,
"input_type": input_type,
},
}
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)

for resp in response:
return resp

raise ValueError("Failed to invoke file embedding")

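A minimal usage sketch for the multimodal embedding call shown above. The document shape ("content" / "content_type" / "file_id") mirrors what Vector.create_multimodal builds later in this diff; the client instance, plugin id, provider, model, and credentials are placeholder assumptions:

import base64

def embed_image_sketch(client, tenant_id: str, user_id: str, image_bytes: bytes) -> None:
    document = {
        "content": base64.b64encode(image_bytes).decode(),
        "content_type": "image",
        "file_id": "example-file-id",  # hypothetical id; normally the UploadFile id
    }
    result = client.invoke_multimodal_embedding(
        tenant_id=tenant_id,
        user_id=user_id,
        plugin_id="example/plugin",        # assumption
        provider="example_provider",       # assumption
        model="example-embedding-model",   # assumption
        credentials={},                    # assumption
        documents=[document],
        input_type="document",
    )
    # One embedding vector is expected per input document.
    print(len(result.embeddings[0]))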
||||
def get_text_embedding_num_tokens(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -403,51 +361,6 @@ class PluginModelClient(BasePluginClient):
|
||||
|
||||
raise ValueError("Failed to invoke rerank")
|
||||
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: dict,
|
||||
docs: list[dict],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke multimodal rerank
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
|
||||
type_=RerankResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "rerank",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"query": query,
|
||||
"docs": docs,
|
||||
"score_threshold": score_threshold,
|
||||
"top_n": top_n,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("Failed to invoke multimodal rerank")
|
||||
|
||||
def invoke_tts(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
||||
@@ -49,7 +49,6 @@ class SimplePromptTransform(PromptTransform):
|
||||
memory: TokenBufferMemory | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
context_files: list["File"] | None = None,
|
||||
) -> tuple[list[PromptMessage], list[str] | None]:
|
||||
inputs = {key: str(value) for key, value in inputs.items()}
|
||||
|
||||
@@ -65,7 +64,6 @@ class SimplePromptTransform(PromptTransform):
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
image_detail_config=image_detail_config,
|
||||
context_files=context_files,
|
||||
)
|
||||
else:
|
||||
prompt_messages, stops = self._get_completion_model_prompt_messages(
|
||||
@@ -78,7 +76,6 @@ class SimplePromptTransform(PromptTransform):
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
image_detail_config=image_detail_config,
|
||||
context_files=context_files,
|
||||
)
|
||||
|
||||
return prompt_messages, stops
|
||||
@@ -190,7 +187,6 @@ class SimplePromptTransform(PromptTransform):
|
||||
memory: TokenBufferMemory | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
context_files: list["File"] | None = None,
|
||||
) -> tuple[list[PromptMessage], list[str] | None]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
|
||||
@@ -220,9 +216,9 @@ class SimplePromptTransform(PromptTransform):
|
||||
)
|
||||
|
||||
if query:
|
||||
prompt_messages.append(self._get_last_user_message(query, files, image_detail_config, context_files))
|
||||
prompt_messages.append(self._get_last_user_message(query, files, image_detail_config))
|
||||
else:
|
||||
prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config, context_files))
|
||||
prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config))
|
||||
|
||||
return prompt_messages, None
|
||||
|
||||
@@ -237,7 +233,6 @@ class SimplePromptTransform(PromptTransform):
|
||||
memory: TokenBufferMemory | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
context_files: list["File"] | None = None,
|
||||
) -> tuple[list[PromptMessage], list[str] | None]:
|
||||
# get prompt
|
||||
prompt, prompt_rules = self._get_prompt_str_and_rules(
|
||||
@@ -280,27 +275,20 @@ class SimplePromptTransform(PromptTransform):
|
||||
if stops is not None and len(stops) == 0:
|
||||
stops = None
|
||||
|
||||
return [self._get_last_user_message(prompt, files, image_detail_config, context_files)], stops
|
||||
return [self._get_last_user_message(prompt, files, image_detail_config)], stops
|
||||
|
||||
def _get_last_user_message(
|
||||
self,
|
||||
prompt: str,
|
||||
files: Sequence["File"],
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
context_files: list["File"] | None = None,
|
||||
) -> UserPromptMessage:
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
if files:
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
if context_files:
|
||||
for file in context_files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
if prompt_message_contents:
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||
|
||||
prompt_message = UserPromptMessage(content=prompt_message_contents)
|
||||
|
||||
@@ -2,7 +2,6 @@ from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rag.data_post_processor.reorder import ReorderRunner
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
|
||||
from core.rag.rerank.rerank_base import BaseRerankRunner
|
||||
@@ -31,10 +30,9 @@ class DataPostProcessor:
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||
) -> list[Document]:
|
||||
if self.rerank_runner:
|
||||
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type)
|
||||
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
|
||||
|
||||
if self.reorder_runner:
|
||||
documents = self.reorder_runner.run(documents)
|
||||
|
||||
@@ -1,30 +1,23 @@
|
||||
import concurrent.futures
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.embedding.retrieval import RetrievalSegments
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.signature import sign_upload_file
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
|
||||
from models.dataset import ChildChunk, Dataset, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import UploadFile
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
@@ -44,15 +37,14 @@ class RetrievalService:
|
||||
retrieval_method: RetrievalMethod,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int = 4,
|
||||
top_k: int,
|
||||
score_threshold: float | None = 0.0,
|
||||
reranking_model: dict | None = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: dict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_ids: list | None = None,
|
||||
):
|
||||
if not query and not attachment_ids:
|
||||
if not query:
|
||||
return []
|
||||
dataset = cls._get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
@@ -64,52 +56,69 @@ class RetrievalService:
|
||||
# Optimize multithreading with thread pools
|
||||
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
|
||||
futures = []
|
||||
retrieval_service = RetrievalService()
|
||||
if query:
|
||||
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
retrieval_service._retrieve,
|
||||
cls.keyword_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
retrieval_method=retrieval_method,
|
||||
dataset=dataset,
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
all_documents=all_documents,
|
||||
exceptions=exceptions,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
)
|
||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.embedding_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
reranking_mode=reranking_mode,
|
||||
weights=weights,
|
||||
document_ids_filter=document_ids_filter,
|
||||
attachment_id=None,
|
||||
all_documents=all_documents,
|
||||
retrieval_method=retrieval_method,
|
||||
exceptions=exceptions,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
)
|
||||
if attachment_ids:
|
||||
for attachment_id in attachment_ids:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
retrieval_service._retrieve,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
retrieval_method=retrieval_method,
|
||||
dataset=dataset,
|
||||
query=None,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
reranking_mode=reranking_mode,
|
||||
weights=weights,
|
||||
document_ids_filter=document_ids_filter,
|
||||
attachment_id=attachment_id,
|
||||
all_documents=all_documents,
|
||||
exceptions=exceptions,
|
||||
)
|
||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.full_text_index_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
all_documents=all_documents,
|
||||
retrieval_method=retrieval_method,
|
||||
exceptions=exceptions,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
|
||||
concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
|
||||
)
|
||||
concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
|
||||
|
||||
if exceptions:
|
||||
raise ValueError(";\n".join(exceptions))
|
||||
|
||||
# Deduplicate documents for hybrid search to avoid duplicate chunks
|
||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
|
||||
all_documents = cls._deduplicate_documents(all_documents)
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
|
||||
)
|
||||
all_documents = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
)
|
||||
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
@@ -214,7 +223,6 @@ class RetrievalService:
|
||||
retrieval_method: RetrievalMethod,
|
||||
exceptions: list,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
@@ -223,30 +231,14 @@ class RetrievalService:
|
||||
raise ValueError("dataset not found")
|
||||
|
||||
vector = Vector(dataset=dataset)
|
||||
documents = []
|
||||
if query_type == QueryType.TEXT_QUERY:
|
||||
documents.extend(
|
||||
vector.search_by_vector(
|
||||
query,
|
||||
search_type="similarity_score_threshold",
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
filter={"group_id": [dataset.id]},
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
)
|
||||
if query_type == QueryType.IMAGE_QUERY:
|
||||
if not dataset.is_multimodal:
|
||||
return
|
||||
documents.extend(
|
||||
vector.search_by_file(
|
||||
file_id=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
filter={"group_id": [dataset.id]},
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
)
|
||||
documents = vector.search_by_vector(
|
||||
query,
|
||||
search_type="similarity_score_threshold",
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
filter={"group_id": [dataset.id]},
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
|
||||
if documents:
|
||||
if (
|
||||
@@ -258,37 +250,14 @@ class RetrievalService:
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
|
||||
)
|
||||
if dataset.is_multimodal:
|
||||
model_manager = ModelManager()
|
||||
is_support_vision = model_manager.check_model_support_vision(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=reranking_model.get("reranking_provider_name") or "",
|
||||
model=reranking_model.get("reranking_model_name") or "",
|
||||
model_type=ModelType.RERANK,
|
||||
)
|
||||
if is_support_vision:
|
||||
all_documents.extend(
|
||||
data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=len(documents),
|
||||
query_type=query_type,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# not effective, return original documents
|
||||
all_documents.extend(documents)
|
||||
else:
|
||||
all_documents.extend(
|
||||
data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=len(documents),
|
||||
query_type=query_type,
|
||||
)
|
||||
all_documents.extend(
|
||||
data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=len(documents),
|
||||
)
|
||||
)
|
||||
else:
|
||||
all_documents.extend(documents)
|
||||
except Exception as e:
|
||||
@@ -370,161 +339,103 @@ class RetrievalService:
|
||||
records = []
|
||||
include_segment_ids = set()
|
||||
segment_child_map = {}
|
||||
segment_file_map = {}
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# Process documents
|
||||
for document in documents:
|
||||
segment_id = None
|
||||
attachment_info = None
|
||||
child_chunk = None
|
||||
document_id = document.metadata.get("document_id")
|
||||
if document_id not in dataset_documents:
|
||||
|
||||
# Process documents
|
||||
for document in documents:
|
||||
document_id = document.metadata.get("document_id")
|
||||
if document_id not in dataset_documents:
|
||||
continue
|
||||
|
||||
dataset_document = dataset_documents[document_id]
|
||||
if not dataset_document:
|
||||
continue
|
||||
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
# Handle parent-child documents
|
||||
child_index_node_id = document.metadata.get("doc_id")
|
||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
|
||||
child_chunk = db.session.scalar(child_chunk_stmt)
|
||||
|
||||
if not child_chunk:
|
||||
continue
|
||||
|
||||
dataset_document = dataset_documents[document_id]
|
||||
if not dataset_document:
|
||||
continue
|
||||
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
# Handle parent-child documents
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
attachment_info_dict = cls.get_segment_attachment_info(
|
||||
dataset_document.dataset_id,
|
||||
dataset_document.tenant_id,
|
||||
document.metadata.get("doc_id") or "",
|
||||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
attachment_info = attachment_info_dict["attachment_info"]
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
else:
|
||||
child_index_node_id = document.metadata.get("doc_id")
|
||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
|
||||
child_chunk = session.scalar(child_chunk_stmt)
|
||||
|
||||
if not child_chunk:
|
||||
continue
|
||||
segment_id = child_chunk.segment_id
|
||||
|
||||
if not segment_id:
|
||||
continue
|
||||
|
||||
segment = (
|
||||
session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == segment_id,
|
||||
)
|
||||
.first()
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == child_chunk.segment_id,
|
||||
)
|
||||
.options(
|
||||
load_only(
|
||||
DocumentSegment.id,
|
||||
DocumentSegment.content,
|
||||
DocumentSegment.answer,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
if not segment:
|
||||
continue
|
||||
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
if child_chunk:
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
segment_child_map[segment.id] = map_detail
|
||||
record = {
|
||||
"segment": segment,
|
||||
}
|
||||
if attachment_info:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
records.append(record)
|
||||
else:
|
||||
if child_chunk:
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
if segment.id in segment_child_map:
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
||||
)
|
||||
else:
|
||||
segment_child_map[segment.id] = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
if attachment_info:
|
||||
if segment.id in segment_file_map:
|
||||
segment_file_map[segment.id].append(attachment_info)
|
||||
else:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
segment_child_map[segment.id] = map_detail
|
||||
record = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
else:
|
||||
# Handle normal documents
|
||||
segment = None
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
attachment_info_dict = cls.get_segment_attachment_info(
|
||||
dataset_document.dataset_id,
|
||||
dataset_document.tenant_id,
|
||||
document.metadata.get("doc_id") or "",
|
||||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
attachment_info = attachment_info_dict["attachment_info"]
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == segment_id,
|
||||
)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
if segment:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
else:
|
||||
index_node_id = document.metadata.get("doc_id")
|
||||
if not index_node_id:
|
||||
continue
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.index_node_id == index_node_id,
|
||||
)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
||||
)
|
||||
else:
|
||||
# Handle normal documents
|
||||
index_node_id = document.metadata.get("doc_id")
|
||||
if not index_node_id:
|
||||
continue
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.index_node_id == index_node_id,
|
||||
)
|
||||
segment = db.session.scalar(document_segment_stmt)
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata.get("score"), # type: ignore
|
||||
}
|
||||
if attachment_info:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
records.append(record)
|
||||
else:
|
||||
if attachment_info:
|
||||
attachment_infos = segment_file_map.get(segment.id, [])
|
||||
if attachment_info not in attachment_infos:
|
||||
attachment_infos.append(attachment_info)
|
||||
segment_file_map[segment.id] = attachment_infos
|
||||
if not segment:
|
||||
continue
|
||||
|
||||
include_segment_ids.add(segment.id)
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata.get("score"), # type: ignore
|
||||
}
|
||||
records.append(record)
|
||||
|
||||
# Add child chunks information to records
|
||||
for record in records:
|
||||
if record["segment"].id in segment_child_map:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||
if record["segment"].id in segment_file_map:
|
||||
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
|
||||
|
||||
result = []
|
||||
for record in records:
|
||||
@@ -536,11 +447,6 @@ class RetrievalService:
|
||||
if not isinstance(child_chunks, list):
|
||||
child_chunks = None
|
||||
|
||||
# Extract files, ensuring it's a list or None
|
||||
files = record.get("files")
|
||||
if not isinstance(files, list):
|
||||
files = None
|
||||
|
||||
# Extract score, ensuring it's a float or None
|
||||
score_value = record.get("score")
|
||||
score = (
|
||||
@@ -550,149 +456,10 @@ class RetrievalService:
|
||||
)
|
||||
|
||||
# Create RetrievalSegments object
|
||||
retrieval_segment = RetrievalSegments(
|
||||
segment=segment, child_chunks=child_chunks, score=score, files=files
|
||||
)
|
||||
retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
|
||||
result.append(retrieval_segment)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
raise e
|
||||
|
||||
def _retrieve(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
retrieval_method: RetrievalMethod,
|
||||
dataset: Dataset,
|
||||
query: str | None = None,
|
||||
top_k: int = 4,
|
||||
score_threshold: float | None = 0.0,
|
||||
reranking_model: dict | None = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: dict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_id: str | None = None,
|
||||
all_documents: list[Document] = [],
|
||||
exceptions: list[str] = [],
|
||||
):
|
||||
if not query and not attachment_id:
|
||||
return
|
||||
with flask_app.app_context():
|
||||
all_documents_item: list[Document] = []
|
||||
# Optimize multithreading with thread pools
|
||||
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
|
||||
futures = []
|
||||
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
self.keyword_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
all_documents=all_documents_item,
|
||||
exceptions=exceptions,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
)
|
||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||
if query:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
self.embedding_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
all_documents=all_documents_item,
|
||||
retrieval_method=retrieval_method,
|
||||
exceptions=exceptions,
|
||||
document_ids_filter=document_ids_filter,
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
)
|
||||
if attachment_id:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
self.embedding_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset.id,
|
||||
query=attachment_id,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
all_documents=all_documents_item,
|
||||
retrieval_method=retrieval_method,
|
||||
exceptions=exceptions,
|
||||
document_ids_filter=document_ids_filter,
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
)
|
||||
)
|
||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
self.full_text_index_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
all_documents=all_documents_item,
|
||||
retrieval_method=retrieval_method,
|
||||
exceptions=exceptions,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
)
|
||||
concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
|
||||
|
||||
if exceptions:
|
||||
raise ValueError(";\n".join(exceptions))
|
||||
|
||||
# Deduplicate documents for hybrid search to avoid duplicate chunks
|
||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
|
||||
if attachment_id and reranking_mode == RerankMode.WEIGHTED_SCORE:
|
||||
all_documents.extend(all_documents_item)
|
||||
all_documents_item = self._deduplicate_documents(all_documents_item)
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
|
||||
)
|
||||
|
||||
query = query or attachment_id
|
||||
if not query:
|
||||
return
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY,
|
||||
)
|
||||
|
||||
all_documents.extend(all_documents_item)
|
||||
|
||||
@classmethod
|
||||
def get_segment_attachment_info(
|
||||
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
|
||||
) -> dict[str, Any] | None:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
|
||||
if upload_file:
|
||||
attachment_binding = (
|
||||
session.query(SegmentAttachmentBinding)
|
||||
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
|
||||
.first()
|
||||
)
|
||||
if attachment_binding:
|
||||
attachment_info = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
|
||||
"size": upload_file.size,
|
||||
}
|
||||
return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
|
||||
return None
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -13,13 +12,10 @@ from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from models.dataset import Dataset, Whitelist
|
||||
from models.model import UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -207,47 +203,6 @@ class Vector:
|
||||
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
|
||||
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
|
||||
|
||||
def create_multimodal(self, file_documents: list | None = None, **kwargs):
|
||||
if file_documents:
|
||||
start = time.time()
|
||||
logger.info("start embedding %s files %s", len(file_documents), start)
|
||||
batch_size = 1000
|
||||
total_batches = len(file_documents) + batch_size - 1
|
||||
for i in range(0, len(file_documents), batch_size):
|
||||
batch = file_documents[i : i + batch_size]
|
||||
batch_start = time.time()
|
||||
logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch))
|
||||
|
||||
# Batch query all upload files to avoid N+1 queries
|
||||
attachment_ids = [doc.metadata["doc_id"] for doc in batch]
|
||||
stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids))
|
||||
upload_files = db.session.scalars(stmt).all()
|
||||
upload_file_map = {str(f.id): f for f in upload_files}
|
||||
|
||||
file_base64_list = []
|
||||
real_batch = []
|
||||
for document in batch:
|
||||
attachment_id = document.metadata["doc_id"]
|
||||
doc_type = document.metadata["doc_type"]
|
||||
upload_file = upload_file_map.get(attachment_id)
|
||||
if upload_file:
|
||||
blob = storage.load_once(upload_file.key)
|
||||
file_base64_str = base64.b64encode(blob).decode()
|
||||
file_base64_list.append(
|
||||
{
|
||||
"content": file_base64_str,
|
||||
"content_type": doc_type,
|
||||
"file_id": attachment_id,
|
||||
}
|
||||
)
|
||||
real_batch.append(document)
|
||||
batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list)
|
||||
logger.info(
|
||||
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
|
||||
)
|
||||
self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs)
|
||||
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
|
||||
|
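The create_multimodal loop above slices file_documents into fixed-size batches and logs "batch X/Y" progress. A small self-contained sketch of that pattern, assuming the total is meant to be the ceiling-division batch count:

def iter_batches(items: list, batch_size: int = 1000):
    # Ceiling division gives the number of batches produced by the i : i + batch_size slices.
    total_batches = (len(items) + batch_size - 1) // batch_size
    for i in range(0, len(items), batch_size):
        batch = items[i : i + batch_size]
        yield i // batch_size + 1, total_batches, batch

for batch_no, total, batch in iter_batches(list(range(2500)), batch_size=1000):
    print(f"Processing batch {batch_no}/{total} ({len(batch)} items)")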
||||
def add_texts(self, documents: list[Document], **kwargs):
|
||||
if kwargs.get("duplicate_check", False):
|
||||
documents = self._filter_duplicate_texts(documents)
|
||||
@@ -268,22 +223,6 @@ class Vector:
|
||||
query_vector = self._embeddings.embed_query(query)
|
||||
return self._vector_processor.search_by_vector(query_vector, **kwargs)
|
||||
|
||||
def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
|
||||
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
|
||||
if not upload_file:
|
||||
return []
|
||||
blob = storage.load_once(upload_file.key)
|
||||
file_base64_str = base64.b64encode(blob).decode()
|
||||
multimodal_vector = self._embeddings.embed_multimodal_query(
|
||||
{
|
||||
"content": file_base64_str,
|
||||
"content_type": DocType.IMAGE,
|
||||
"file_id": file_id,
|
||||
}
|
||||
)
|
||||
return self._vector_processor.search_by_vector(multimodal_vector, **kwargs)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self._vector_processor.search_by_full_text(query, **kwargs)
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ from sqlalchemy import func, select
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.models.document import AttachmentDocument, Document
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
|
||||
from models.dataset import ChildChunk, Dataset, DocumentSegment
|
||||
|
||||
|
||||
class DatasetDocumentStore:
|
||||
@@ -120,9 +120,6 @@ class DatasetDocumentStore:
|
||||
|
||||
db.session.add(segment_document)
|
||||
db.session.flush()
|
||||
self.add_multimodel_documents_binding(
|
||||
segment_id=segment_document.id, multimodel_documents=doc.attachments
|
||||
)
|
||||
if save_child:
|
||||
if doc.children:
|
||||
for position, child in enumerate(doc.children, start=1):
|
||||
@@ -147,9 +144,6 @@ class DatasetDocumentStore:
|
||||
segment_document.index_node_hash = doc.metadata.get("doc_hash")
|
||||
segment_document.word_count = len(doc.page_content)
|
||||
segment_document.tokens = tokens
|
||||
self.add_multimodel_documents_binding(
|
||||
segment_id=segment_document.id, multimodel_documents=doc.attachments
|
||||
)
|
||||
if save_child and doc.children:
|
||||
# delete the existing child chunks
|
||||
db.session.query(ChildChunk).where(
|
||||
@@ -239,15 +233,3 @@ class DatasetDocumentStore:
|
||||
document_segment = db.session.scalar(stmt)
|
||||
|
||||
return document_segment
|
||||
|
||||
def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None):
|
||||
if multimodel_documents:
|
||||
for multimodel_document in multimodel_documents:
|
||||
binding = SegmentAttachmentBinding(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
dataset_id=self._dataset.id,
|
||||
document_id=self._document_id,
|
||||
segment_id=segment_id,
|
||||
attachment_id=multimodel_document.metadata["doc_id"],
|
||||
)
|
||||
db.session.add(binding)
|
||||
|
||||
@@ -104,88 +104,6 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
return text_embeddings
|
||||
|
||||
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
|
||||
"""Embed file documents."""
|
||||
# use doc embedding cache or store if not exists
|
||||
multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]
|
||||
embedding_queue_indices = []
|
||||
for i, multimodel_document in enumerate(multimodel_documents):
|
||||
file_id = multimodel_document["file_id"]
|
||||
embedding = (
|
||||
db.session.query(Embedding)
|
||||
.filter_by(
|
||||
model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if embedding:
|
||||
multimodel_embeddings[i] = embedding.get_embedding()
|
||||
else:
|
||||
embedding_queue_indices.append(i)
|
||||
|
||||
# NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
|
||||
|
||||
if embedding_queue_indices:
|
||||
embedding_queue_multimodel_documents = [multimodel_documents[i] for i in embedding_queue_indices]
|
||||
embedding_queue_embeddings = []
|
||||
try:
|
||||
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
self._model_instance.model, self._model_instance.credentials
|
||||
)
|
||||
max_chunks = (
|
||||
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
||||
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
|
||||
else 1
|
||||
)
|
||||
for i in range(0, len(embedding_queue_multimodel_documents), max_chunks):
|
||||
batch_multimodel_documents = embedding_queue_multimodel_documents[i : i + max_chunks]
|
||||
|
||||
embedding_result = self._model_instance.invoke_multimodal_embedding(
|
||||
multimodel_documents=batch_multimodel_documents,
|
||||
user=self._user,
|
||||
input_type=EmbeddingInputType.DOCUMENT,
|
||||
)
|
||||
|
||||
for vector in embedding_result.embeddings:
|
||||
try:
|
||||
# FIXME: type ignore for numpy here
|
||||
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore
|
||||
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
|
||||
if np.isnan(normalized_embedding).any():
|
||||
# for issue #11827 float values are not json compliant
|
||||
logger.warning("Normalized embedding is nan: %s", normalized_embedding)
|
||||
continue
|
||||
embedding_queue_embeddings.append(normalized_embedding)
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
except Exception:
|
||||
logger.exception("Failed transform embedding")
|
||||
cache_embeddings = []
|
||||
try:
|
||||
for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
|
||||
multimodel_embeddings[i] = n_embedding
|
||||
file_id = multimodel_documents[i]["file_id"]
|
||||
if file_id not in cache_embeddings:
|
||||
embedding_cache = Embedding(
|
||||
model_name=self._model_instance.model,
|
||||
hash=file_id,
|
||||
provider_name=self._model_instance.provider,
|
||||
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
|
||||
)
|
||||
embedding_cache.set_embedding(n_embedding)
|
||||
db.session.add(embedding_cache)
|
||||
cache_embeddings.append(file_id)
|
||||
db.session.commit()
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
except Exception as ex:
|
||||
db.session.rollback()
|
||||
logger.exception("Failed to embed documents")
|
||||
raise ex
|
||||
|
||||
return multimodel_embeddings
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
@@ -228,46 +146,3 @@ class CacheEmbedding(Embeddings):
|
||||
raise ex
|
||||
|
||||
return embedding_results # type: ignore
|
||||
|
||||
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
|
||||
"""Embed multimodal documents."""
|
||||
# use doc embedding cache or store if not exists
|
||||
file_id = multimodel_document["file_id"]
|
||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
|
||||
embedding = redis_client.get(embedding_cache_key)
|
||||
if embedding:
|
||||
redis_client.expire(embedding_cache_key, 600)
|
||||
decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float")
|
||||
return [float(x) for x in decoded_embedding]
|
||||
try:
|
||||
embedding_result = self._model_instance.invoke_multimodal_embedding(
|
||||
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
|
||||
)
|
||||
|
||||
embedding_results = embedding_result.embeddings[0]
|
||||
# FIXME: type ignore for numpy here
|
||||
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore
|
||||
if np.isnan(embedding_results).any():
|
||||
raise ValueError("Normalized embedding is nan please try again")
|
||||
except Exception as ex:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Failed to embed multimodal document '%s'", multimodel_document["file_id"])
|
||||
raise ex
|
||||
|
||||
try:
|
||||
# encode embedding to base64
|
||||
embedding_vector = np.array(embedding_results)
|
||||
vector_bytes = embedding_vector.tobytes()
|
||||
# Transform to Base64
|
||||
encoded_vector = base64.b64encode(vector_bytes)
|
||||
# Transform to string
|
||||
encoded_str = encoded_vector.decode("utf-8")
|
||||
redis_client.setex(embedding_cache_key, 600, encoded_str)
|
||||
except Exception as ex:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception(
|
||||
"Failed to add embedding to redis for the multimodal document '%s'", multimodel_document["file_id"]
|
||||
)
|
||||
raise ex
|
||||
|
||||
return embedding_results # type: ignore
|
||||
|
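The query-embedding cache above serializes the normalized vector with numpy tobytes plus base64 before writing it to Redis, and decodes it with np.frombuffer on a hit. A self-contained sketch of that round trip (a plain dict stands in for redis_client):

import base64
import numpy as np

cache: dict[str, str] = {}  # stand-in for redis_client

def cache_vector(key: str, vector: list[float]) -> None:
    # np.array of Python floats is float64, matching the dtype used when reading back.
    encoded = base64.b64encode(np.array(vector).tobytes()).decode("utf-8")
    cache[key] = encoded  # redis_client.setex(key, 600, encoded) in the real code

def load_vector(key: str) -> list[float] | None:
    encoded = cache.get(key)
    if encoded is None:
        return None
    decoded = np.frombuffer(base64.b64decode(encoded), dtype="float")
    return [float(x) for x in decoded]

cache_vector("provider_model_file-id", [0.1, 0.2, 0.3])
print(load_vector("provider_model_file-id"))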
||||
@@ -9,21 +9,11 @@ class Embeddings(ABC):
|
||||
"""Embed search docs."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
|
||||
"""Embed file documents."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
|
||||
"""Embed multimodal query."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Asynchronous Embed search docs."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -19,4 +19,3 @@ class RetrievalSegments(BaseModel):
|
||||
segment: DocumentSegment
|
||||
child_chunks: list[RetrievalChildChunk] | None = None
|
||||
score: float | None = None
|
||||
files: list[dict[str, str | int]] | None = None
|
||||
|
||||
@@ -21,4 +21,3 @@ class RetrievalSourceMetadata(BaseModel):
|
||||
page: int | None = None
|
||||
doc_metadata: dict[str, Any] | None = None
|
||||
title: str | None = None
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import os
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
from openpyxl import load_workbook
|
||||
@@ -10,12 +10,6 @@ from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class Candidate(TypedDict):
|
||||
idx: int
|
||||
count: int
|
||||
map: dict[int, str]
|
||||
|
||||
|
||||
class ExcelExtractor(BaseExtractor):
|
||||
"""Load Excel files.
|
||||
|
||||
@@ -36,38 +30,32 @@ class ExcelExtractor(BaseExtractor):
|
||||
file_extension = os.path.splitext(self._file_path)[-1].lower()
|
||||
|
||||
if file_extension == ".xlsx":
|
||||
wb = load_workbook(self._file_path, read_only=True, data_only=True)
|
||||
try:
|
||||
for sheet_name in wb.sheetnames:
|
||||
sheet = wb[sheet_name]
|
||||
header_row_idx, column_map, max_col_idx = self._find_header_and_columns(sheet)
|
||||
if not column_map:
|
||||
continue
|
||||
start_row = header_row_idx + 1
|
||||
for row in sheet.iter_rows(min_row=start_row, max_col=max_col_idx, values_only=False):
|
||||
if all(cell.value is None for cell in row):
|
||||
continue
|
||||
page_content = []
|
||||
for col_idx, cell in enumerate(row):
|
||||
value = cell.value
|
||||
if col_idx in column_map:
|
||||
col_name = column_map[col_idx]
|
||||
if hasattr(cell, "hyperlink") and cell.hyperlink:
|
||||
target = getattr(cell.hyperlink, "target", None)
|
||||
if target:
|
||||
value = f"[{value}]({target})"
|
||||
if value is None:
|
||||
value = ""
|
||||
elif not isinstance(value, str):
|
||||
value = str(value)
|
||||
value = value.strip().replace('"', '\\"')
|
||||
page_content.append(f'"{col_name}":"{value}"')
|
||||
if page_content:
|
||||
documents.append(
|
||||
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
|
||||
)
|
||||
finally:
|
||||
wb.close()
|
||||
wb = load_workbook(self._file_path, data_only=True)
|
||||
for sheet_name in wb.sheetnames:
|
||||
sheet = wb[sheet_name]
|
||||
data = sheet.values
|
||||
cols = next(data, None)
|
||||
if cols is None:
|
||||
continue
|
||||
df = pd.DataFrame(data, columns=cols)
|
||||
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
for index, row in df.iterrows():
|
||||
page_content = []
|
||||
for col_index, (k, v) in enumerate(row.items()):
|
||||
if pd.notna(v):
|
||||
cell = sheet.cell(
|
||||
row=cast(int, index) + 2, column=col_index + 1
|
||||
) # +2 to account for header and 1-based index
|
||||
if cell.hyperlink:
|
||||
value = f"[{v}]({cell.hyperlink.target})"
|
||||
page_content.append(f'"{k}":"{value}"')
|
||||
else:
|
||||
page_content.append(f'"{k}":"{v}"')
|
||||
documents.append(
|
||||
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
|
||||
)
|
||||
|
||||
elif file_extension == ".xls":
|
||||
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
|
||||
@@ -75,9 +63,9 @@ class ExcelExtractor(BaseExtractor):
|
||||
df = excel_file.parse(sheet_name=excel_sheet_name)
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
for _, series_row in df.iterrows():
|
||||
for _, row in df.iterrows():
|
||||
page_content = []
|
||||
for k, v in series_row.items():
|
||||
for k, v in row.items():
|
||||
if pd.notna(v):
|
||||
page_content.append(f'"{k}":"{v}"')
|
||||
documents.append(
|
||||
@@ -87,61 +75,3 @@ class ExcelExtractor(BaseExtractor):
|
||||
raise ValueError(f"Unsupported file extension: {file_extension}")
|
||||
|
||||
return documents
|
||||
|
||||
def _find_header_and_columns(self, sheet, scan_rows=10) -> tuple[int, dict[int, str], int]:
|
||||
"""
|
||||
Scan first N rows to find the most likely header row.
|
||||
Returns:
|
||||
header_row_idx: 1-based index of the header row
|
||||
column_map: Dict mapping 0-based column index to column name
|
||||
max_col_idx: 1-based index of the last valid column (for iter_rows boundary)
|
||||
"""
|
||||
# Store potential candidates: (row_index, non_empty_count, column_map)
|
||||
candidates: list[Candidate] = []
|
||||
|
||||
# Limit scan to avoid performance issues on huge files
|
||||
# We iterate manually to control the read scope
|
||||
for current_row_idx, row in enumerate(sheet.iter_rows(min_row=1, max_row=scan_rows, values_only=True), start=1):
|
||||
# Filter out empty cells and build a temp map for this row
|
||||
# col_idx is 0-based
|
||||
row_map = {}
|
||||
for col_idx, cell_value in enumerate(row):
|
||||
if cell_value is not None and str(cell_value).strip():
|
||||
row_map[col_idx] = str(cell_value).strip().replace('"', '\\"')
|
||||
|
||||
if not row_map:
|
||||
continue
|
||||
|
||||
non_empty_count = len(row_map)
|
||||
|
||||
# Header selection heuristic (implemented):
|
||||
# - Prefer the first row with at least 2 non-empty columns.
|
||||
# - Fallback: choose the row with the most non-empty columns
|
||||
# (tie-breaker: smaller row index).
|
||||
candidates.append({"idx": current_row_idx, "count": non_empty_count, "map": row_map})
|
||||
|
||||
if not candidates:
|
||||
return 0, {}, 0
|
||||
|
||||
# Choose the best candidate header row.
|
||||
|
||||
best_candidate: Candidate | None = None
|
||||
|
||||
# Strategy: prefer the first row with >= 2 non-empty columns; otherwise fallback.
|
||||
|
||||
for cand in candidates:
|
||||
if cand["count"] >= 2:
|
||||
best_candidate = cand
|
||||
break
|
||||
|
||||
# Fallback: if no row has >= 2 columns, or all have 1, just take the one with max columns
|
||||
if not best_candidate:
|
||||
# Sort by count desc, then index asc
|
||||
candidates.sort(key=lambda x: (-x["count"], x["idx"]))
|
||||
best_candidate = candidates[0]
|
||||
|
||||
# Determine max_col_idx (1-based for openpyxl)
|
||||
# It is the index of the last valid column in our map + 1
|
||||
max_col_idx = max(best_candidate["map"].keys()) + 1
|
||||
|
||||
return best_candidate["idx"], best_candidate["map"], max_col_idx
|
||||
|
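A standalone sketch of the header-selection heuristic documented in _find_header_and_columns above: prefer the first scanned row with at least two non-empty cells, otherwise fall back to the row with the most non-empty cells, with the earlier row winning ties (sample rows are illustrative):

def pick_header_row(rows: list[list]) -> tuple[int, dict[int, str]]:
    """Return (1-based header row index, {0-based column index: header name}) or (0, {})."""
    candidates = []
    for row_idx, row in enumerate(rows, start=1):
        row_map = {
            col_idx: str(value).strip()
            for col_idx, value in enumerate(row)
            if value is not None and str(value).strip()
        }
        if row_map:
            candidates.append({"idx": row_idx, "count": len(row_map), "map": row_map})
    if not candidates:
        return 0, {}
    for cand in candidates:  # prefer the first row with >= 2 non-empty cells
        if cand["count"] >= 2:
            return cand["idx"], cand["map"]
    candidates.sort(key=lambda c: (-c["count"], c["idx"]))  # fallback: most cells, earliest row
    return candidates[0]["idx"], candidates[0]["map"]

# Example: a title-only first row is skipped in favour of the real header row.
print(pick_header_row([["Report", None, None], ["name", "email", "age"], ["a", "b", 1]]))
# -> (2, {0: 'name', 1: 'email', 2: 'age'})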
||||
@@ -84,46 +84,22 @@ class WordExtractor(BaseExtractor):
|
||||
image_count = 0
|
||||
image_map = {}
|
||||
|
||||
for rId, rel in doc.part.rels.items():
|
||||
for rel in doc.part.rels.values():
|
||||
if "image" in rel.target_ref:
|
||||
image_count += 1
|
||||
if rel.is_external:
|
||||
url = rel.target_ref
|
||||
if not self._is_valid_url(url):
|
||||
continue
|
||||
try:
|
||||
response = ssrf_proxy.get(url)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to download image from URL: %s: %s", url, str(e))
|
||||
continue
|
||||
response = ssrf_proxy.get(url)
|
||||
if response.status_code == 200:
|
||||
image_ext = mimetypes.guess_extension(response.headers.get("Content-Type", ""))
|
||||
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
|
||||
if image_ext is None:
|
||||
continue
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + image_ext
|
||||
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
|
||||
mime_type, _ = mimetypes.guess_type(file_key)
|
||||
storage.save(file_key, response.content)
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self.tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
extension=str(image_ext),
|
||||
mime_type=mime_type or "",
|
||||
created_by=self.user_id,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=naive_utc_now(),
|
||||
used=True,
|
||||
used_by=self.user_id,
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
# Use rId as key for external images since target_part is undefined
|
||||
image_map[rId] = f""
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
image_ext = rel.target_ref.split(".")[-1]
|
||||
if image_ext is None:
|
||||
@@ -134,28 +110,26 @@ class WordExtractor(BaseExtractor):
|
||||
mime_type, _ = mimetypes.guess_type(file_key)
|
||||
|
||||
storage.save(file_key, rel.target_part.blob)
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self.tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
extension=str(image_ext),
|
||||
mime_type=mime_type or "",
|
||||
created_by=self.user_id,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=naive_utc_now(),
|
||||
used=True,
|
||||
used_by=self.user_id,
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
# Use target_part as key for internal images
|
||||
image_map[rel.target_part] = (
|
||||
f""
|
||||
)
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self.tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
extension=str(image_ext),
|
||||
mime_type=mime_type or "",
|
||||
created_by=self.user_id,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=naive_utc_now(),
|
||||
used=True,
|
||||
used_by=self.user_id,
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
image_map[rel.target_part] = f""
|
||||
|
||||
return image_map
|
||||
|
||||
@@ -212,17 +186,11 @@ class WordExtractor(BaseExtractor):
|
||||
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
|
||||
if not image_id:
|
||||
continue
|
||||
rel = paragraph.part.rels.get(image_id)
|
||||
if rel is None:
|
||||
continue
|
||||
# For external images, use image_id as key; for internal, use target_part
|
||||
if rel.is_external:
|
||||
if image_id in image_map:
|
||||
paragraph_content.append(image_map[image_id])
|
||||
else:
|
||||
image_part = rel.target_part
|
||||
if image_part in image_map:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
image_part = paragraph.part.rels[image_id].target_part
|
||||
|
||||
if image_part in image_map:
|
||||
image_link = image_map[image_part]
|
||||
paragraph_content.append(image_link)
|
||||
else:
|
||||
paragraph_content.append(run.text)
|
||||
return "".join(paragraph_content).strip()
|
||||
@@ -259,18 +227,6 @@ class WordExtractor(BaseExtractor):
|
||||
|
||||
def parse_paragraph(paragraph):
|
||||
paragraph_content = []
|
||||
|
||||
def append_image_link(image_id, has_drawing):
|
||||
"""Helper to append image link from image_map based on relationship type."""
|
||||
rel = doc.part.rels[image_id]
|
||||
if rel.is_external:
|
||||
if image_id in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_id])
|
||||
else:
|
||||
image_part = rel.target_part
|
||||
if image_part in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
|
||||
for run in paragraph.runs:
|
||||
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
|
||||
# Process drawing type images
|
||||
@@ -287,18 +243,10 @@ class WordExtractor(BaseExtractor):
|
||||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
|
||||
)
|
||||
if embed_id:
|
||||
rel = doc.part.rels.get(embed_id)
|
||||
if rel is not None and rel.is_external:
|
||||
# External image: use embed_id as key
|
||||
if embed_id in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[embed_id])
|
||||
else:
|
||||
# Internal image: use target_part as key
|
||||
image_part = doc.part.related_parts.get(embed_id)
|
||||
if image_part in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[image_part])
|
||||
image_part = doc.part.related_parts.get(embed_id)
|
||||
if image_part in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[image_part])
|
||||
# Process pict type images
|
||||
shape_elements = run.element.findall(
|
||||
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
|
||||
@@ -313,7 +261,9 @@ class WordExtractor(BaseExtractor):
|
||||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
|
||||
)
|
||||
if image_id and image_id in doc.part.rels:
|
||||
append_image_link(image_id, has_drawing)
|
||||
image_part = doc.part.rels[image_id].target_part
|
||||
if image_part in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
# Find imagedata element in VML
|
||||
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
|
||||
if image_data is not None:
|
||||
@@ -321,7 +271,9 @@ class WordExtractor(BaseExtractor):
|
||||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
|
||||
)
|
||||
if image_id and image_id in doc.part.rels:
|
||||
append_image_link(image_id, has_drawing)
|
||||
image_part = doc.part.rels[image_id].target_part
|
||||
if image_part in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
if run.text.strip():
|
||||
paragraph_content.append(run.text.strip())
|
||||
return "".join(paragraph_content) if paragraph_content else ""
|
||||
|
||||
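
The parse_paragraph logic above resolves each embedded or externally linked picture in a .docx run back to a markdown link prepared earlier in image_map. A minimal standalone sketch of that lookup, assuming python-docx and a pre-populated image_map keyed by target_part for internal images and by relationship id for external ones; the helper name is illustrative and not part of this diff:

# Illustrative sketch: resolve every image relationship in a docx to its markdown link.
from docx import Document as DocxDocument

def resolve_image_links(docx_path: str, image_map: dict) -> list[str]:
    doc = DocxDocument(docx_path)
    links = []
    for rel_id, rel in doc.part.rels.items():
        if "image" not in rel.reltype:
            continue
        if rel.is_external:
            # external pictures are keyed by their relationship id
            if rel_id in image_map:
                links.append(image_map[rel_id])
        else:
            # embedded pictures are keyed by the target part object
            if rel.target_part in image_map:
                links.append(image_map[rel.target_part])
    return links
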
@@ -1,6 +0,0 @@
from enum import StrEnum


class DocType(StrEnum):
    TEXT = "text"
    IMAGE = "image"
@@ -1,12 +1,7 @@
from enum import StrEnum


class IndexStructureType(StrEnum):
class IndexType(StrEnum):
    PARAGRAPH_INDEX = "text_model"
    QA_INDEX = "qa_model"
    PARENT_CHILD_INDEX = "hierarchical_model"


class IndexTechniqueType(StrEnum):
    ECONOMY = "economy"
    HIGH_QUALITY = "high_quality"
@@ -1,6 +0,0 @@
from enum import StrEnum


class QueryType(StrEnum):
    TEXT_QUERY = "text_query"
    IMAGE_QUERY = "image_query"
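
These constants are StrEnum members, so they compare equal to their raw string values; that is what lets stored columns such as doc_form or indexing_technique be compared against the enum directly elsewhere in this diff. A quick illustrative check (requires Python 3.11+, not part of the diff):

from enum import StrEnum

class IndexType(StrEnum):
    PARAGRAPH_INDEX = "text_model"
    QA_INDEX = "qa_model"
    PARENT_CHILD_INDEX = "hierarchical_model"

assert IndexType.PARAGRAPH_INDEX == "text_model"          # member == stored string
assert IndexType("hierarchical_model") is IndexType.PARENT_CHILD_INDEX  # lookup by value
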
@@ -1,34 +1,20 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import cgi
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.models.document import AttachmentDocument, Document
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.rag.splitter.fixed_text_splitter import (
|
||||
EnhanceRecursiveCharacterTextSplitter,
|
||||
FixedRecursiveCharacterTextSplitter,
|
||||
)
|
||||
from core.rag.splitter.text_splitter import TextSplitter
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models import Account, ToolFile
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import UploadFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.model_manager import ModelInstance
|
||||
@@ -42,18 +28,11 @@ class BaseIndexProcessor(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
documents: list[Document],
|
||||
multimodal_documents: list[AttachmentDocument] | None = None,
|
||||
with_keywords: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@@ -117,178 +96,3 @@ class BaseIndexProcessor(ABC):
|
||||
)
|
||||
|
||||
return character_splitter # type: ignore
|
||||
|
||||
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
|
||||
"""
|
||||
Get the content files from the document.
|
||||
"""
|
||||
multi_model_documents: list[AttachmentDocument] = []
|
||||
text = document.page_content
|
||||
images = self._extract_markdown_images(text)
|
||||
if not images:
|
||||
return multi_model_documents
|
||||
upload_file_id_list = []
|
||||
|
||||
for image in images:
|
||||
# Collect all upload_file_ids including duplicates to preserve occurrence count
|
||||
|
||||
# For data before v0.10.0
|
||||
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
|
||||
match = re.search(pattern, image)
|
||||
if match:
|
||||
upload_file_id = match.group(1)
|
||||
upload_file_id_list.append(upload_file_id)
|
||||
continue
|
||||
|
||||
# For data after v0.10.0
|
||||
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
|
||||
match = re.search(pattern, image)
|
||||
if match:
|
||||
upload_file_id = match.group(1)
|
||||
upload_file_id_list.append(upload_file_id)
|
||||
continue
|
||||
|
||||
# For tools directory - direct file formats (e.g., .png, .jpg, etc.)
|
||||
# Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes)
|
||||
pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
|
||||
match = re.search(pattern, image)
|
||||
if match:
|
||||
if current_user:
|
||||
tool_file_id = match.group(1)
|
||||
upload_file_id = self._download_tool_file(tool_file_id, current_user)
|
||||
if upload_file_id:
|
||||
upload_file_id_list.append(upload_file_id)
|
||||
continue
|
||||
if current_user:
|
||||
upload_file_id = self._download_image(image.split(" ")[0], current_user)
|
||||
if upload_file_id:
|
||||
upload_file_id_list.append(upload_file_id)
|
||||
|
||||
if not upload_file_id_list:
|
||||
return multi_model_documents
|
||||
|
||||
# Get unique IDs for database query
|
||||
unique_upload_file_ids = list(set(upload_file_id_list))
|
||||
upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all()
|
||||
|
||||
# Create a mapping from ID to UploadFile for quick lookup
|
||||
upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
|
||||
|
||||
# Create a Document for each occurrence (including duplicates)
|
||||
for upload_file_id in upload_file_id_list:
|
||||
upload_file = upload_file_map.get(upload_file_id)
|
||||
if upload_file:
|
||||
multi_model_documents.append(
|
||||
AttachmentDocument(
|
||||
page_content=upload_file.name,
|
||||
metadata={
|
||||
"doc_id": upload_file.id,
|
||||
"doc_hash": "",
|
||||
"document_id": document.metadata.get("document_id"),
|
||||
"dataset_id": document.metadata.get("dataset_id"),
|
||||
"doc_type": DocType.IMAGE,
|
||||
},
|
||||
)
|
||||
)
|
||||
return multi_model_documents
|
||||
|
||||
def _extract_markdown_images(self, text: str) -> list[str]:
    """
    Extract the markdown images from the text.
    """
    pattern = r"!\[.*?\]\((.*?)\)"
    return re.findall(pattern, text)

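
Together with the /files/... patterns earlier in _get_content_files, this regex is what locates image attachments inside segment text: first every markdown image URL is pulled out, then the upload-file id is extracted from each URL. A small self-contained sketch of that flow, with made-up URLs and a single condensed pattern standing in for the separate image-preview and file-preview patterns used above:

import re

text = (
    "Intro ![chart](https://example.com/files/0a1b2c3d-0000-0000-0000-000000000001/file-preview?x=1) "
    "and ![old](https://example.com/files/0a1b2c3d-0000-0000-0000-000000000002/image-preview)"
)

images = re.findall(r"!\[.*?\]\((.*?)\)", text)  # -> the two URLs
ids = []
for url in images:
    m = re.search(r"/files/([a-f0-9\-]+)/(?:file|image)-preview(?:\?.*?)?", url)
    if m:
        ids.append(m.group(1))
# ids == ["0a1b2c3d-0000-0000-0000-000000000001", "0a1b2c3d-0000-0000-0000-000000000002"]
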
def _download_image(self, image_url: str, current_user: Account) -> str | None:
|
||||
"""
|
||||
Download the image from the URL.
|
||||
Image size must not exceed 2MB.
|
||||
"""
|
||||
from services.file_service import FileService
|
||||
|
||||
MAX_IMAGE_SIZE = dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
|
||||
DOWNLOAD_TIMEOUT = dify_config.ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT
|
||||
|
||||
try:
|
||||
# Download with timeout
|
||||
response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
|
||||
# Check Content-Length header if available
|
||||
content_length = response.headers.get("Content-Length")
|
||||
if content_length and int(content_length) > MAX_IMAGE_SIZE:
|
||||
logging.warning("Image from %s exceeds 2MB limit (size: %s bytes)", image_url, content_length)
|
||||
return None
|
||||
|
||||
filename = None
|
||||
|
||||
content_disposition = response.headers.get("content-disposition")
|
||||
if content_disposition:
|
||||
_, params = cgi.parse_header(content_disposition)
|
||||
if "filename" in params:
|
||||
filename = params["filename"]
|
||||
filename = unquote(filename)
|
||||
|
||||
if not filename:
    parsed_url = urlparse(image_url)
    # unquote decodes percent-encoded characters in the URL path (e.g., Chinese filenames)
    path = unquote(parsed_url.path)
    filename = os.path.basename(path)

if not filename:
    filename = "downloaded_image_file"

name, current_ext = os.path.splitext(filename)

content_type = response.headers.get("content-type", "").split(";")[0].strip()

real_ext = mimetypes.guess_extension(content_type)

if (not current_ext and real_ext) or (current_ext in [".php", ".jsp", ".asp", ".html"] and real_ext):
    filename = f"{name}{real_ext}"
# Download content with size limit
|
||||
blob = b""
|
||||
for chunk in response.iter_bytes(chunk_size=8192):
|
||||
blob += chunk
|
||||
if len(blob) > MAX_IMAGE_SIZE:
|
||||
logging.warning("Image from %s exceeds 2MB limit during download", image_url)
|
||||
return None
|
||||
|
||||
if not blob:
|
||||
logging.warning("Image from %s is empty", image_url)
|
||||
return None
|
||||
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=filename,
|
||||
content=blob,
|
||||
mimetype=content_type,
|
||||
user=current_user,
|
||||
)
|
||||
return upload_file.id
|
||||
except httpx.TimeoutException:
|
||||
logging.warning("Timeout downloading image from %s after %s seconds", image_url, DOWNLOAD_TIMEOUT)
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logging.warning("Error downloading image from %s: %s", image_url, str(e))
|
||||
return None
|
||||
except Exception:
|
||||
logging.exception("Unexpected error downloading image from %s", image_url)
|
||||
return None
|
||||
|
||||
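
_download_image guards the size twice: it rejects early on the declared Content-Length, then also caps the bytes actually accumulated while reading chunks. A minimal sketch of the same two-stage guard using plain httpx streaming instead of the ssrf_proxy helper; the 2MB constant and function name are assumptions for the example, not part of the diff:

import httpx

MAX_BYTES = 2 * 1024 * 1024  # assumed cap for the example

def fetch_image(url: str) -> bytes | None:
    with httpx.stream("GET", url, timeout=10.0) as response:
        response.raise_for_status()
        content_length = response.headers.get("Content-Length")
        if content_length and int(content_length) > MAX_BYTES:
            return None  # declared size already too large
        blob = b""
        for chunk in response.iter_bytes(chunk_size=8192):
            blob += chunk
            if len(blob) > MAX_BYTES:
                return None  # actual size exceeded the cap mid-download
        return blob or None
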
def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None:
|
||||
"""
|
||||
Download the tool file from the ID.
|
||||
"""
|
||||
from services.file_service import FileService
|
||||
|
||||
tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
|
||||
if not tool_file:
|
||||
return None
|
||||
blob = storage.load_once(tool_file.file_key)
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=tool_file.name,
|
||||
content=blob,
|
||||
mimetype=tool_file.mimetype,
|
||||
user=current_user,
|
||||
)
|
||||
return upload_file.id
|
||||
|
||||
@@ -1,6 +1,6 @@
"""Abstract interface for document loader implementations."""

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
@@ -19,11 +19,11 @@ class IndexProcessorFactory:
        if not self._index_type:
            raise ValueError("Index type must be specified.")

        if self._index_type == IndexStructureType.PARAGRAPH_INDEX:
        if self._index_type == IndexType.PARAGRAPH_INDEX:
            return ParagraphIndexProcessor()
        elif self._index_type == IndexStructureType.QA_INDEX:
        elif self._index_type == IndexType.QA_INDEX:
            return QAIndexProcessor()
        elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX:
        elif self._index_type == IndexType.PARENT_CHILD_INDEX:
            return ParentChildIndexProcessor()
        else:
            raise ValueError(f"Index type {self._index_type} is not supported.")

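
The factory's only job here is the dispatch from index type to processor. A hedged standalone rendering of that dispatch, written with the IndexType spelling (this diff renames between IndexType and IndexStructureType); the helper name and the qa_index_processor import path are assumptions based on the neighbouring imports, not guaranteed by the diff:

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor

def pick_index_processor(index_type: str):
    # map the stored doc_form / index type string onto a concrete processor
    if index_type == IndexType.PARAGRAPH_INDEX:
        return ParagraphIndexProcessor()
    elif index_type == IndexType.QA_INDEX:
        return QAIndexProcessor()
    elif index_type == IndexType.PARENT_CHILD_INDEX:
        return ParentChildIndexProcessor()
    raise ValueError(f"Index type {index_type} is not supported.")
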
@@ -11,17 +11,14 @@ from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.account_service import AccountService
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
|
||||
|
||||
@@ -36,7 +33,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
return text_docs
|
||||
|
||||
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
process_rule = kwargs.get("process_rule")
|
||||
if not process_rule:
|
||||
raise ValueError("No process rule found.")
|
||||
@@ -72,11 +69,6 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
if document_node.metadata is not None:
|
||||
document_node.metadata["doc_id"] = doc_id
|
||||
document_node.metadata["doc_hash"] = hash
|
||||
multimodal_documents = (
|
||||
self._get_content_files(document_node, current_user) if document_node.metadata else None
|
||||
)
|
||||
if multimodal_documents:
|
||||
document_node.attachments = multimodal_documents
|
||||
# delete Splitter character
|
||||
page_content = remove_leading_symbols(document_node.page_content).strip()
|
||||
if len(page_content) > 0:
|
||||
@@ -85,19 +77,10 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
all_documents.extend(split_documents)
|
||||
return all_documents
|
||||
|
||||
def load(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
documents: list[Document],
|
||||
multimodal_documents: list[AttachmentDocument] | None = None,
|
||||
with_keywords: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
if multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(multimodal_documents)
|
||||
with_keywords = False
|
||||
if with_keywords:
|
||||
keywords_list = kwargs.get("keywords_list")
|
||||
@@ -151,9 +134,8 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
return docs
|
||||
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
|
||||
documents: list[Any] = []
|
||||
all_multimodal_documents: list[Any] = []
|
||||
if isinstance(chunks, list):
|
||||
documents = []
|
||||
for content in chunks:
|
||||
metadata = {
|
||||
"dataset_id": dataset.id,
|
||||
@@ -162,68 +144,26 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
"doc_hash": helper.generate_text_hash(content),
|
||||
}
|
||||
doc = Document(page_content=content, metadata=metadata)
|
||||
attachments = self._get_content_files(doc)
|
||||
if attachments:
|
||||
doc.attachments = attachments
|
||||
all_multimodal_documents.extend(attachments)
|
||||
documents.append(doc)
|
||||
if documents:
|
||||
# save node to document segment
|
||||
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
|
||||
# add document segments
|
||||
doc_store.add_documents(docs=documents, save_child=False)
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
elif dataset.indexing_technique == "economy":
|
||||
keyword = Keyword(dataset)
|
||||
keyword.add_texts(documents)
|
||||
else:
|
||||
multimodal_general_structure = MultimodalGeneralStructureChunk.model_validate(chunks)
|
||||
for general_chunk in multimodal_general_structure.general_chunks:
|
||||
metadata = {
|
||||
"dataset_id": dataset.id,
|
||||
"document_id": document.id,
|
||||
"doc_id": str(uuid.uuid4()),
|
||||
"doc_hash": helper.generate_text_hash(general_chunk.content),
|
||||
}
|
||||
doc = Document(page_content=general_chunk.content, metadata=metadata)
|
||||
if general_chunk.files:
|
||||
attachments = []
|
||||
for file in general_chunk.files:
|
||||
file_metadata = {
|
||||
"doc_id": file.id,
|
||||
"doc_hash": "",
|
||||
"document_id": document.id,
|
||||
"dataset_id": dataset.id,
|
||||
"doc_type": DocType.IMAGE,
|
||||
}
|
||||
file_document = AttachmentDocument(
|
||||
page_content=file.filename or "image_file", metadata=file_metadata
|
||||
)
|
||||
attachments.append(file_document)
|
||||
all_multimodal_documents.append(file_document)
|
||||
doc.attachments = attachments
|
||||
else:
|
||||
account = AccountService.load_user(document.created_by)
|
||||
if not account:
|
||||
raise ValueError("Invalid account")
|
||||
doc.attachments = self._get_content_files(doc, current_user=account)
|
||||
if doc.attachments:
|
||||
all_multimodal_documents.extend(doc.attachments)
|
||||
documents.append(doc)
|
||||
if documents:
|
||||
# save node to document segment
|
||||
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
|
||||
# add document segments
|
||||
doc_store.add_documents(docs=documents, save_child=False)
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
if all_multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(all_multimodal_documents)
|
||||
elif dataset.indexing_technique == "economy":
|
||||
keyword = Keyword(dataset)
|
||||
keyword.add_texts(documents)
|
||||
raise ValueError("Chunks is not a list")
|
||||
|
||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||
if isinstance(chunks, list):
|
||||
preview = []
|
||||
for content in chunks:
|
||||
preview.append({"content": content})
|
||||
return {
|
||||
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
|
||||
"preview": preview,
|
||||
"total_segments": len(chunks),
|
||||
}
|
||||
return {"chunk_structure": IndexType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks)}
|
||||
else:
|
||||
raise ValueError("Chunks is not a list")
|
||||
|
||||
@@ -13,17 +13,14 @@ from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
|
||||
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models import Account
|
||||
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.account_service import AccountService
|
||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
||||
|
||||
|
||||
@@ -38,7 +35,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
return text_docs
|
||||
|
||||
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
process_rule = kwargs.get("process_rule")
|
||||
if not process_rule:
|
||||
raise ValueError("No process rule found.")
|
||||
@@ -80,9 +77,6 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
page_content = page_content
|
||||
if len(page_content) > 0:
|
||||
document_node.page_content = page_content
|
||||
multimodel_documents = self._get_content_files(document_node, current_user)
|
||||
if multimodel_documents:
|
||||
document_node.attachments = multimodel_documents
|
||||
# parse document to child nodes
|
||||
child_nodes = self._split_child_nodes(
|
||||
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
|
||||
@@ -93,9 +87,6 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
elif rules.parent_mode == ParentMode.FULL_DOC:
|
||||
page_content = "\n".join([document.page_content for document in documents])
|
||||
document = Document(page_content=page_content, metadata=documents[0].metadata)
|
||||
multimodel_documents = self._get_content_files(document)
|
||||
if multimodel_documents:
|
||||
document.attachments = multimodel_documents
|
||||
# parse document to child nodes
|
||||
child_nodes = self._split_child_nodes(
|
||||
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
|
||||
@@ -113,14 +104,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
return all_documents
|
||||
|
||||
def load(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
documents: list[Document],
|
||||
multimodal_documents: list[AttachmentDocument] | None = None,
|
||||
with_keywords: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
for document in documents:
|
||||
@@ -130,8 +114,6 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
Document.model_validate(child_document.model_dump()) for child_document in child_documents
|
||||
]
|
||||
vector.create(formatted_child_documents)
|
||||
if multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(multimodal_documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
# node_ids is segment's node_ids
|
||||
@@ -262,24 +244,6 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
}
|
||||
child_documents.append(ChildDocument(page_content=child, metadata=child_metadata))
|
||||
doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents)
|
||||
if parent_child.files and len(parent_child.files) > 0:
|
||||
attachments = []
|
||||
for file in parent_child.files:
|
||||
file_metadata = {
|
||||
"doc_id": file.id,
|
||||
"doc_hash": "",
|
||||
"document_id": document.id,
|
||||
"dataset_id": dataset.id,
|
||||
"doc_type": DocType.IMAGE,
|
||||
}
|
||||
file_document = AttachmentDocument(page_content=file.filename or "", metadata=file_metadata)
|
||||
attachments.append(file_document)
|
||||
doc.attachments = attachments
|
||||
else:
|
||||
account = AccountService.load_user(document.created_by)
|
||||
if not account:
|
||||
raise ValueError("Invalid account")
|
||||
doc.attachments = self._get_content_files(doc, current_user=account)
|
||||
documents.append(doc)
|
||||
if documents:
|
||||
# update document parent mode
|
||||
@@ -303,17 +267,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
doc_store.add_documents(docs=documents, save_child=True)
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
all_child_documents = []
|
||||
all_multimodal_documents = []
|
||||
for doc in documents:
|
||||
if doc.children:
|
||||
all_child_documents.extend(doc.children)
|
||||
if doc.attachments:
|
||||
all_multimodal_documents.extend(doc.attachments)
|
||||
vector = Vector(dataset)
|
||||
if all_child_documents:
|
||||
vector = Vector(dataset)
|
||||
vector.create(all_child_documents)
|
||||
if all_multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(all_multimodal_documents)
|
||||
|
||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
||||
@@ -321,7 +280,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
for parent_child in parent_childs.parent_child_chunks:
|
||||
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
|
||||
return {
|
||||
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
|
||||
"chunk_structure": IndexType.PARENT_CHILD_INDEX,
|
||||
"parent_mode": parent_childs.parent_mode,
|
||||
"preview": preview,
|
||||
"total_segments": len(parent_childs.parent_child_chunks),
|
||||
|
||||
@@ -18,13 +18,12 @@ from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
|
||||
from core.rag.models.document import Document, QAStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
@@ -42,7 +41,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
)
|
||||
return text_docs
|
||||
|
||||
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
preview = kwargs.get("preview")
|
||||
process_rule = kwargs.get("process_rule")
|
||||
if not process_rule:
|
||||
@@ -117,7 +116,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
try:
|
||||
# Skip the first row
|
||||
df = pd.read_csv(file) # type: ignore
|
||||
df = pd.read_csv(file)
|
||||
text_docs = []
|
||||
for _, row in df.iterrows():
|
||||
data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
|
||||
@@ -129,19 +128,10 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
raise ValueError(str(e))
|
||||
return text_docs
|
||||
|
||||
def load(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
documents: list[Document],
|
||||
multimodal_documents: list[AttachmentDocument] | None = None,
|
||||
with_keywords: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
if multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(multimodal_documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
vector = Vector(dataset)
|
||||
@@ -207,7 +197,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
for qa_chunk in qa_chunks.qa_chunks:
|
||||
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
|
||||
return {
|
||||
"chunk_structure": IndexStructureType.QA_INDEX,
|
||||
"chunk_structure": IndexType.QA_INDEX,
|
||||
"qa_preview": preview,
|
||||
"total_segments": len(qa_chunks.qa_chunks),
|
||||
}
|
||||
|
||||
@@ -4,8 +4,6 @@ from typing import Any

from pydantic import BaseModel, Field

from core.file import File


class ChildDocument(BaseModel):
    """Class for storing a piece of text and associated metadata."""
@@ -17,19 +15,7 @@ class ChildDocument(BaseModel):
    """Arbitrary metadata about the page content (e.g., source, relationships to other
    documents, etc.).
    """
    metadata: dict[str, Any] = Field(default_factory=dict)


class AttachmentDocument(BaseModel):
    """Class for storing a piece of text and associated metadata."""

    page_content: str

    provider: str | None = "dify"

    vector: list[float] | None = None

    metadata: dict[str, Any] = Field(default_factory=dict)
    metadata: dict = Field(default_factory=dict)


class Document(BaseModel):
@@ -42,31 +28,12 @@ class Document(BaseModel):
    """Arbitrary metadata about the page content (e.g., source, relationships to other
    documents, etc.).
    """
    metadata: dict[str, Any] = Field(default_factory=dict)
    metadata: dict = Field(default_factory=dict)

    provider: str | None = "dify"

    children: list[ChildDocument] | None = None

    attachments: list[AttachmentDocument] | None = None


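
On the multimodal side of this diff, a retrieved chunk is a Document whose image references ride along as AttachmentDocument entries. An illustrative construction, with made-up ids and names, assuming the multimodal variants of these models are in place:

from core.rag.models.document import AttachmentDocument, Document

attachment = AttachmentDocument(
    page_content="diagram.png",
    metadata={"doc_id": "upload-file-id", "doc_hash": "", "doc_type": "image"},
)
doc = Document(
    page_content="The architecture is shown in the diagram below.",
    metadata={"doc_id": "segment-node-id", "doc_hash": "abc123"},
    attachments=[attachment],
)
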
class GeneralChunk(BaseModel):
|
||||
"""
|
||||
General Chunk.
|
||||
"""
|
||||
|
||||
content: str
|
||||
files: list[File] | None = None
|
||||
|
||||
|
||||
class MultimodalGeneralStructureChunk(BaseModel):
|
||||
"""
|
||||
Multimodal General Structure Chunk.
|
||||
"""
|
||||
|
||||
general_chunks: list[GeneralChunk]
|
||||
|
||||
|
||||
class GeneralStructureChunk(BaseModel):
|
||||
"""
|
||||
@@ -83,7 +50,6 @@ class ParentChildChunk(BaseModel):
|
||||
|
||||
parent_content: str
|
||||
child_contents: list[str]
|
||||
files: list[File] | None = None
|
||||
|
||||
|
||||
class ParentChildStructureChunk(BaseModel):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
@@ -13,7 +12,6 @@ class BaseRerankRunner(ABC):
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Run rerank model
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
import base64
|
||||
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank_base import BaseRerankRunner
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
class RerankModelRunner(BaseRerankRunner):
|
||||
@@ -23,7 +14,6 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Run rerank model
|
||||
@@ -34,31 +24,38 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
:param user: unique user id if needed
|
||||
:return:
|
||||
"""
|
||||
model_manager = ModelManager()
|
||||
is_support_vision = model_manager.check_model_support_vision(
|
||||
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
|
||||
provider=self.rerank_model_instance.provider,
|
||||
model=self.rerank_model_instance.model,
|
||||
model_type=ModelType.RERANK,
|
||||
docs = []
|
||||
doc_ids = set()
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
if (
|
||||
document.provider == "dify"
|
||||
and document.metadata is not None
|
||||
and document.metadata["doc_id"] not in doc_ids
|
||||
):
|
||||
doc_ids.add(document.metadata["doc_id"])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
elif document.provider == "external":
|
||||
if document not in unique_documents:
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
|
||||
documents = unique_documents
|
||||
|
||||
rerank_result = self.rerank_model_instance.invoke_rerank(
|
||||
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
|
||||
)
|
||||
if not is_support_vision:
|
||||
if query_type == QueryType.TEXT_QUERY:
|
||||
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
|
||||
else:
|
||||
return documents
|
||||
else:
|
||||
rerank_result, unique_documents = self.fetch_multimodal_rerank(
|
||||
query, documents, score_threshold, top_n, user, query_type
|
||||
)
|
||||
|
||||
rerank_documents = []
|
||||
|
||||
for result in rerank_result.docs:
|
||||
if score_threshold is None or result.score >= score_threshold:
|
||||
# format document
|
||||
rerank_document = Document(
|
||||
page_content=result.text,
|
||||
metadata=unique_documents[result.index].metadata,
|
||||
provider=unique_documents[result.index].provider,
|
||||
metadata=documents[result.index].metadata,
|
||||
provider=documents[result.index].provider,
|
||||
)
|
||||
if rerank_document.metadata is not None:
|
||||
rerank_document.metadata["score"] = result.score
|
||||
@@ -66,126 +63,3 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
|
||||
rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
|
||||
return rerank_documents[:top_n] if top_n else rerank_documents
|
||||
|
||||
def fetch_text_rerank(
|
||||
self,
|
||||
query: str,
|
||||
documents: list[Document],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> tuple[RerankResult, list[Document]]:
|
||||
"""
|
||||
Fetch text rerank
|
||||
:param query: search query
|
||||
:param documents: documents for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id if needed
|
||||
:return:
|
||||
"""
|
||||
docs = []
|
||||
doc_ids = set()
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
if (
|
||||
document.provider == "dify"
|
||||
and document.metadata is not None
|
||||
and document.metadata["doc_id"] not in doc_ids
|
||||
):
|
||||
if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT:
|
||||
doc_ids.add(document.metadata["doc_id"])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
elif document.provider == "external":
|
||||
if document not in unique_documents:
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
|
||||
rerank_result = self.rerank_model_instance.invoke_rerank(
|
||||
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
|
||||
)
|
||||
return rerank_result, unique_documents
|
||||
|
||||
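
fetch_text_rerank (and the older inline version of run) de-duplicates candidates before invoking the rerank model: "dify" documents are kept once per doc_id, "external" documents are kept unless the same object is already present. A small standalone sketch of that pattern, not part of the diff:

def dedupe_documents(documents):
    seen_ids = set()
    unique = []
    for document in documents:
        if document.provider == "dify" and document.metadata and "doc_id" in document.metadata:
            if document.metadata["doc_id"] in seen_ids:
                continue  # first occurrence wins
            seen_ids.add(document.metadata["doc_id"])
            unique.append(document)
        elif document.provider == "external" and document not in unique:
            unique.append(document)
    return unique
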
def fetch_multimodal_rerank(
|
||||
self,
|
||||
query: str,
|
||||
documents: list[Document],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||
) -> tuple[RerankResult, list[Document]]:
|
||||
"""
|
||||
Fetch multimodal rerank
|
||||
:param query: search query
|
||||
:param documents: documents for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id if needed
|
||||
:param query_type: query type
|
||||
:return: rerank result
|
||||
"""
|
||||
docs = []
|
||||
doc_ids = set()
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
if (
|
||||
document.provider == "dify"
|
||||
and document.metadata is not None
|
||||
and document.metadata["doc_id"] not in doc_ids
|
||||
):
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
# Query file info within db.session context to ensure thread-safe access
|
||||
upload_file = (
|
||||
db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first()
|
||||
)
|
||||
if upload_file:
|
||||
blob = storage.load_once(upload_file.key)
|
||||
document_file_base64 = base64.b64encode(blob).decode()
|
||||
document_file_dict = {
|
||||
"content": document_file_base64,
|
||||
"content_type": document.metadata["doc_type"],
|
||||
}
|
||||
docs.append(document_file_dict)
|
||||
else:
|
||||
document_text_dict = {
|
||||
"content": document.page_content,
|
||||
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
|
||||
}
|
||||
docs.append(document_text_dict)
|
||||
doc_ids.add(document.metadata["doc_id"])
|
||||
unique_documents.append(document)
|
||||
elif document.provider == "external":
|
||||
if document not in unique_documents:
|
||||
docs.append(
|
||||
{
|
||||
"content": document.page_content,
|
||||
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
|
||||
}
|
||||
)
|
||||
unique_documents.append(document)
|
||||
|
||||
documents = unique_documents
|
||||
if query_type == QueryType.TEXT_QUERY:
|
||||
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
|
||||
return rerank_result, unique_documents
|
||||
elif query_type == QueryType.IMAGE_QUERY:
|
||||
# Query file info within db.session context to ensure thread-safe access
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first()
|
||||
if upload_file:
|
||||
blob = storage.load_once(upload_file.key)
|
||||
file_query = base64.b64encode(blob).decode()
|
||||
file_query_dict = {
|
||||
"content": file_query,
|
||||
"content_type": DocType.IMAGE,
|
||||
}
|
||||
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
|
||||
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
|
||||
)
|
||||
return rerank_result, unique_documents
|
||||
else:
|
||||
raise ValueError(f"Upload file not found for query: {query}")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Query type {query_type} is not supported")
|
||||
|
||||
@@ -7,8 +7,6 @@ from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.entity.weight import VectorSetting, Weights
|
||||
from core.rag.rerank.rerank_base import BaseRerankRunner
|
||||
@@ -26,7 +24,6 @@ class WeightRerankRunner(BaseRerankRunner):
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Run rerank model
|
||||
@@ -46,10 +43,8 @@ class WeightRerankRunner(BaseRerankRunner):
|
||||
and document.metadata is not None
|
||||
and document.metadata["doc_id"] not in doc_ids
|
||||
):
|
||||
# weight rerank only support text documents
|
||||
if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT:
|
||||
doc_ids.add(document.metadata["doc_id"])
|
||||
unique_documents.append(document)
|
||||
doc_ids.add(document.metadata["doc_id"])
|
||||
unique_documents.append(document)
|
||||
else:
|
||||
if document not in unique_documents:
|
||||
unique_documents.append(document)
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import Any, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@@ -20,7 +19,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
@@ -39,9 +37,7 @@ from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@@ -56,12 +52,10 @@ from core.rag.retrieval.template_prompts import (
|
||||
METADATA_FILTER_USER_PROMPT_2,
|
||||
METADATA_FILTER_USER_PROMPT_3,
|
||||
)
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models import UploadFile
|
||||
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
|
||||
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
@@ -105,8 +99,7 @@ class DatasetRetrieval:
|
||||
message_id: str,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
inputs: Mapping[str, Any] | None = None,
|
||||
vision_enabled: bool = False,
|
||||
) -> tuple[str | None, list[File] | None]:
|
||||
) -> str | None:
|
||||
"""
|
||||
Retrieve dataset.
|
||||
:param app_id: app_id
|
||||
@@ -125,7 +118,7 @@ class DatasetRetrieval:
|
||||
"""
|
||||
dataset_ids = config.dataset_ids
|
||||
if len(dataset_ids) == 0:
|
||||
return None, []
|
||||
return None
|
||||
retrieve_config = config.retrieve_config
|
||||
|
||||
# check model is support tool calling
|
||||
@@ -143,7 +136,7 @@ class DatasetRetrieval:
|
||||
)
|
||||
|
||||
if not model_schema:
|
||||
return None, []
|
||||
return None
|
||||
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
features = model_schema.features
|
||||
@@ -189,8 +182,8 @@ class DatasetRetrieval:
|
||||
tenant_id,
|
||||
user_id,
|
||||
user_from,
|
||||
query,
|
||||
available_datasets,
|
||||
query,
|
||||
model_instance,
|
||||
model_config,
|
||||
planning_strategy,
|
||||
@@ -220,7 +213,6 @@ class DatasetRetrieval:
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
document_context_list: list[DocumentContext] = []
|
||||
context_files: list[File] = []
|
||||
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||
# deal with external documents
|
||||
for item in external_documents:
|
||||
@@ -256,31 +248,6 @@ class DatasetRetrieval:
|
||||
score=record.score,
|
||||
)
|
||||
)
|
||||
if vision_enabled:
|
||||
attachments_with_bindings = db.session.execute(
|
||||
select(SegmentAttachmentBinding, UploadFile)
|
||||
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||
.where(
|
||||
SegmentAttachmentBinding.segment_id == segment.id,
|
||||
)
|
||||
).all()
|
||||
if attachments_with_bindings:
|
||||
for _, upload_file in attachments_with_bindings:
|
||||
attachment_info = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=segment.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
url=sign_upload_file(upload_file.id, upload_file.extension),
|
||||
)
|
||||
context_files.append(attachment_info)
|
||||
if show_retrieve_source:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
@@ -321,10 +288,8 @@ class DatasetRetrieval:
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
return str(
|
||||
"\n".join([document_context.content for document_context in document_context_list])
|
||||
), context_files
|
||||
return "", context_files
|
||||
return str("\n".join([document_context.content for document_context in document_context_list]))
|
||||
return ""
|
||||
|
||||
def single_retrieve(
|
||||
self,
|
||||
@@ -332,8 +297,8 @@ class DatasetRetrieval:
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
query: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
model_instance: ModelInstance,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
@@ -371,7 +336,7 @@ class DatasetRetrieval:
|
||||
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
|
||||
|
||||
self._record_usage(router_usage)
|
||||
timer = None
|
||||
|
||||
if dataset_id:
|
||||
# get retrieval model config
|
||||
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
@@ -441,19 +406,10 @@ class DatasetRetrieval:
|
||||
weights=retrieval_model_config.get("weights", None),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
self._on_query(query, None, [dataset_id], app_id, user_from, user_id)
|
||||
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
||||
|
||||
if results:
|
||||
thread = threading.Thread(
|
||||
target=self._on_retrieval_end,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"documents": results,
|
||||
"message_id": message_id,
|
||||
"timer": timer,
|
||||
},
|
||||
)
|
||||
thread.start()
|
||||
self._on_retrieval_end(results, message_id, timer)
|
||||
|
||||
return results
|
||||
return []
|
||||
@@ -465,7 +421,7 @@ class DatasetRetrieval:
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str | None,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_mode: str,
|
||||
@@ -475,11 +431,10 @@ class DatasetRetrieval:
|
||||
message_id: str | None = None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
):
|
||||
if not available_datasets:
|
||||
return []
|
||||
all_threads = []
|
||||
threads = []
|
||||
all_documents: list[Document] = []
|
||||
dataset_ids = [dataset.id for dataset in available_datasets]
|
||||
index_type_check = all(
|
||||
@@ -512,187 +467,102 @@ class DatasetRetrieval:
|
||||
0
|
||||
].embedding_model_provider
|
||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||
with measure_time() as timer:
|
||||
if query:
|
||||
query_thread = threading.Thread(
|
||||
target=self._multiple_retrieve_thread,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"available_datasets": available_datasets,
|
||||
"metadata_condition": metadata_condition,
|
||||
"metadata_filter_document_ids": metadata_filter_document_ids,
|
||||
"all_documents": all_documents,
|
||||
"tenant_id": tenant_id,
|
||||
"reranking_enable": reranking_enable,
|
||||
"reranking_mode": reranking_mode,
|
||||
"reranking_model": reranking_model,
|
||||
"weights": weights,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"query": query,
|
||||
"attachment_id": None,
|
||||
},
|
||||
)
|
||||
all_threads.append(query_thread)
|
||||
query_thread.start()
|
||||
if attachment_ids:
|
||||
for attachment_id in attachment_ids:
|
||||
attachment_thread = threading.Thread(
|
||||
target=self._multiple_retrieve_thread,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"available_datasets": available_datasets,
|
||||
"metadata_condition": metadata_condition,
|
||||
"metadata_filter_document_ids": metadata_filter_document_ids,
|
||||
"all_documents": all_documents,
|
||||
"tenant_id": tenant_id,
|
||||
"reranking_enable": reranking_enable,
|
||||
"reranking_mode": reranking_mode,
|
||||
"reranking_model": reranking_model,
|
||||
"weights": weights,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"query": None,
|
||||
"attachment_id": attachment_id,
|
||||
},
|
||||
)
|
||||
all_threads.append(attachment_thread)
|
||||
attachment_thread.start()
|
||||
for thread in all_threads:
|
||||
thread.join()
|
||||
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
|
||||
|
||||
if all_documents:
|
||||
# add thread to call _on_retrieval_end
|
||||
retrieval_end_thread = threading.Thread(
|
||||
target=self._on_retrieval_end,
|
||||
for dataset in available_datasets:
|
||||
index_type = dataset.indexing_technique
|
||||
document_ids_filter = None
|
||||
if dataset.provider != "external":
|
||||
if metadata_condition and not metadata_filter_document_ids:
|
||||
continue
|
||||
if metadata_filter_document_ids:
|
||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if document_ids:
|
||||
document_ids_filter = document_ids
|
||||
else:
|
||||
continue
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"documents": all_documents,
|
||||
"message_id": message_id,
|
||||
"timer": timer,
|
||||
"dataset_id": dataset.id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents,
|
||||
"document_ids_filter": document_ids_filter,
|
||||
"metadata_condition": metadata_condition,
|
||||
},
|
||||
)
|
||||
retrieval_end_thread.start()
|
||||
retrieval_resource_list = []
|
||||
doc_ids_filter = []
|
||||
for document in all_documents:
|
||||
if document.provider == "dify":
|
||||
doc_id = document.metadata.get("doc_id")
|
||||
if doc_id and doc_id not in doc_ids_filter:
|
||||
doc_ids_filter.append(doc_id)
|
||||
retrieval_resource_list.append(document)
|
||||
elif document.provider == "external":
|
||||
retrieval_resource_list.append(document)
|
||||
return retrieval_resource_list
|
||||
threads.append(retrieval_thread)
|
||||
retrieval_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
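
Both versions of multiple_retrieve shown above share the same fan-out/join shape: spawn one worker per unit of work (per dataset in the old code, per query or attachment in the new code), let each append into a shared list, then join every worker before reranking. A hedged sketch of that shape with illustrative names; the real code additionally pushes a Flask app context into each worker, which is omitted here:

import threading

def fan_out_retrieval(work_items, retrieve_one):
    all_documents = []
    lock = threading.Lock()

    def worker(item):
        results = retrieve_one(item)
        with lock:  # the sketch guards the shared list explicitly
            all_documents.extend(results)

    threads = [threading.Thread(target=worker, args=(item,)) for item in work_items]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return all_documents
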
def _on_retrieval_end(
|
||||
self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
|
||||
):
|
||||
with measure_time() as timer:
|
||||
if reranking_enable:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
|
||||
all_documents = data_post_processor.invoke(
|
||||
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
|
||||
)
|
||||
else:
|
||||
if index_type == "economy":
|
||||
all_documents = self.calculate_keyword_score(query, all_documents, top_k)
|
||||
elif index_type == "high_quality":
|
||||
all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
|
||||
else:
|
||||
all_documents = all_documents[:top_k] if top_k else all_documents
|
||||
|
||||
self._on_query(query, dataset_ids, app_id, user_from, user_id)
|
||||
|
||||
if all_documents:
|
||||
self._on_retrieval_end(all_documents, message_id, timer)
|
||||
|
||||
return all_documents
|
||||
|
||||
    def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None):
        """Handle retrieval end."""
        with flask_app.app_context():
            dify_documents = [document for document in documents if document.provider == "dify"]
            if not dify_documents:
                self._send_trace_task(message_id, documents, timer)
                return

            with Session(db.engine) as session:
                # Collect all document_ids and batch fetch DatasetDocuments
                document_ids = {
                    doc.metadata["document_id"]
                    for doc in dify_documents
                    if doc.metadata and "document_id" in doc.metadata
                }
                if not document_ids:
                    self._send_trace_task(message_id, documents, timer)
                    return

                dataset_docs_stmt = select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))
                dataset_docs = session.scalars(dataset_docs_stmt).all()
                dataset_doc_map = {str(doc.id): doc for doc in dataset_docs}

                # Categorize documents by type and collect necessary IDs
                parent_child_text_docs: list[tuple[Document, DatasetDocument]] = []
                parent_child_image_docs: list[tuple[Document, DatasetDocument]] = []
                normal_text_docs: list[tuple[Document, DatasetDocument]] = []
                normal_image_docs: list[tuple[Document, DatasetDocument]] = []

                for doc in dify_documents:
                    if not doc.metadata or "document_id" not in doc.metadata:
                        continue
                    dataset_doc = dataset_doc_map.get(doc.metadata["document_id"])
                    if not dataset_doc:
                        continue

                    is_image = doc.metadata.get("doc_type") == DocType.IMAGE
                    is_parent_child = dataset_doc.doc_form == IndexStructureType.PARENT_CHILD_INDEX

                    if is_parent_child:
                        if is_image:
                            parent_child_image_docs.append((doc, dataset_doc))
                        else:
                            parent_child_text_docs.append((doc, dataset_doc))
                    else:
                        if is_image:
                            normal_image_docs.append((doc, dataset_doc))
                        else:
                            normal_text_docs.append((doc, dataset_doc))

                segment_ids_to_update: set[str] = set()

                # Process PARENT_CHILD_INDEX text documents - batch fetch ChildChunks
                if parent_child_text_docs:
                    index_node_ids = [doc.metadata["doc_id"] for doc, _ in parent_child_text_docs if doc.metadata]
                    if index_node_ids:
                        child_chunks_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(index_node_ids))
                        child_chunks = session.scalars(child_chunks_stmt).all()
                        child_chunk_map = {chunk.index_node_id: chunk.segment_id for chunk in child_chunks}
                        for doc, _ in parent_child_text_docs:
                            if doc.metadata:
                                segment_id = child_chunk_map.get(doc.metadata["doc_id"])
                                if segment_id:
                                    segment_ids_to_update.add(str(segment_id))

                # Process non-PARENT_CHILD_INDEX text documents - batch fetch DocumentSegments
                if normal_text_docs:
                    index_node_ids = [doc.metadata["doc_id"] for doc, _ in normal_text_docs if doc.metadata]
                    if index_node_ids:
                        segments_stmt = select(DocumentSegment).where(DocumentSegment.index_node_id.in_(index_node_ids))
                        segments = session.scalars(segments_stmt).all()
                        segment_map = {seg.index_node_id: seg.id for seg in segments}
                        for doc, _ in normal_text_docs:
                            if doc.metadata:
                                segment_id = segment_map.get(doc.metadata["doc_id"])
                                if segment_id:
                                    segment_ids_to_update.add(str(segment_id))

                # Process IMAGE documents - batch fetch SegmentAttachmentBindings
                all_image_docs = parent_child_image_docs + normal_image_docs
                if all_image_docs:
                    attachment_ids = [
                        doc.metadata["doc_id"]
                        for doc, _ in all_image_docs
                        if doc.metadata and doc.metadata.get("doc_id")
                    ]
                    if attachment_ids:
                        bindings_stmt = select(SegmentAttachmentBinding).where(
                            SegmentAttachmentBinding.attachment_id.in_(attachment_ids)
                        )
                        bindings = session.scalars(bindings_stmt).all()
                        segment_ids_to_update.update(str(binding.segment_id) for binding in bindings)

                # Batch update hit_count for all segments
                if segment_ids_to_update:
                    session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
                        {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                        synchronize_session=False,
                    )
                    session.commit()

            self._send_trace_task(message_id, documents, timer)

        # Other side of the diff for this method: the same hit counting done per retrieved
        # document instead of in one batched update.
        dify_documents = [document for document in documents if document.provider == "dify"]
        for document in dify_documents:
            if document.metadata is not None:
                dataset_document_stmt = select(DatasetDocument).where(
                    DatasetDocument.id == document.metadata["document_id"]
                )
                dataset_document = db.session.scalar(dataset_document_stmt)
                if dataset_document:
                    if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                        child_chunk_stmt = select(ChildChunk).where(
                            ChildChunk.index_node_id == document.metadata["doc_id"],
                            ChildChunk.dataset_id == dataset_document.dataset_id,
                            ChildChunk.document_id == dataset_document.id,
                        )
                        child_chunk = db.session.scalar(child_chunk_stmt)
                        if child_chunk:
                            _ = (
                                db.session.query(DocumentSegment)
                                .where(DocumentSegment.id == child_chunk.segment_id)
                                .update(
                                    {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                                    synchronize_session=False,
                                )
                            )
                    else:
                        query = db.session.query(DocumentSegment).where(
                            DocumentSegment.index_node_id == document.metadata["doc_id"]
                        )
                        # if 'dataset_id' in document.metadata:
                        if "dataset_id" in document.metadata:
                            query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])

                        # add hit count to document segment
                        query.update(
                            {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
                        )

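For reference, a minimal self-contained sketch of the batched hit-count pattern used above. The Segment model, table name, and SQLite engine here are illustrative assumptions, not dify code.

```python
# Sketch only: one UPDATE over an IN(...) filter instead of one query per document.
from sqlalchemy import Integer, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Segment(Base):
    __tablename__ = "segments"  # assumed table, stands in for DocumentSegment
    id: Mapped[str] = mapped_column(String, primary_key=True)
    hit_count: Mapped[int] = mapped_column(Integer, default=0)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Segment(id="a"), Segment(id="b"), Segment(id="c")])
    session.commit()

    segment_ids_to_update = {"a", "c"}
    # Single batched UPDATE; synchronize_session=False skips in-memory syncing.
    session.query(Segment).where(Segment.id.in_(segment_ids_to_update)).update(
        {Segment.hit_count: Segment.hit_count + 1},
        synchronize_session=False,
    )
    session.commit()

    print([(s.id, s.hit_count) for s in session.scalars(select(Segment))])
```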
    def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
        """Send trace task if trace manager is available."""
        db.session.commit()

        # get tracing instance
        trace_manager: TraceQueueManager | None = (
            self.application_generate_entity.trace_manager if self.application_generate_entity else None
        )

@@ -703,40 +573,25 @@ class DatasetRetrieval:
                )
            )

    def _on_query(
        self,
        query: str | None,
        attachment_ids: list[str] | None,
        dataset_ids: list[str],
        app_id: str,
        user_from: str,
        user_id: str,
    ):
    def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str):
        """
        Handle query.
        """
        if not query and not attachment_ids:
        if not query:
            return
        dataset_queries = []
        for dataset_id in dataset_ids:
            contents = []
            if query:
                contents.append({"content_type": QueryType.TEXT_QUERY, "content": query})
            if attachment_ids:
                for attachment_id in attachment_ids:
                    contents.append({"content_type": QueryType.IMAGE_QUERY, "content": attachment_id})
            if contents:
                dataset_query = DatasetQuery(
                    dataset_id=dataset_id,
                    content=json.dumps(contents),
                    source="app",
                    source_app_id=app_id,
                    created_by_role=user_from,
                    created_by=user_id,
                )
                dataset_queries.append(dataset_query)
        if dataset_queries:
            db.session.add_all(dataset_queries)
            dataset_query = DatasetQuery(
                dataset_id=dataset_id,
                content=query,
                source="app",
                source_app_id=app_id,
                created_by_role=user_from,
                created_by=user_id,
            )
            dataset_queries.append(dataset_query)
        if dataset_queries:
            db.session.add_all(dataset_queries)
        db.session.commit()

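A short sketch of what the multimodal variant of _on_query stores in DatasetQuery.content. The literal content_type strings and the attachment ID below are assumptions for illustration only.

```python
# Sketch only: serialize each dataset query as a JSON list of typed entries.
import json

TEXT_QUERY = "text_query"    # assumption: stands in for QueryType.TEXT_QUERY
IMAGE_QUERY = "image_query"  # assumption: stands in for QueryType.IMAGE_QUERY

contents = [
    {"content_type": TEXT_QUERY, "content": "how do refunds work?"},
    {"content_type": IMAGE_QUERY, "content": "attachment-uuid-placeholder"},
]
print(json.dumps(contents))  # stored in DatasetQuery.content instead of a bare string
```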
    def _retriever(

@@ -748,7 +603,6 @@ class DatasetRetrieval:
        all_documents: list,
        document_ids_filter: list[str] | None = None,
        metadata_condition: MetadataCondition | None = None,
        attachment_ids: list[str] | None = None,
    ):
        with flask_app.app_context():
            dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)

@@ -757,7 +611,7 @@ class DatasetRetrieval:
            if not dataset:
                return []

            if dataset.provider == "external" and query:
            if dataset.provider == "external":
                external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
                    tenant_id=dataset.tenant_id,
                    dataset_id=dataset_id,

@@ -809,7 +663,6 @@ class DatasetRetrieval:
                    reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                    weights=retrieval_model.get("weights", None),
                    document_ids_filter=document_ids_filter,
                    attachment_ids=attachment_ids,
                )

                all_documents.extend(documents)

@@ -1369,86 +1222,3 @@ class DatasetRetrieval:
            usage = LLMUsage.empty_usage()

        return full_text, usage

    def _multiple_retrieve_thread(
        self,
        flask_app: Flask,
        available_datasets: list,
        metadata_condition: MetadataCondition | None,
        metadata_filter_document_ids: dict[str, list[str]] | None,
        all_documents: list[Document],
        tenant_id: str,
        reranking_enable: bool,
        reranking_mode: str,
        reranking_model: dict | None,
        weights: dict[str, Any] | None,
        top_k: int,
        score_threshold: float,
        query: str | None,
        attachment_id: str | None,
    ):
        with flask_app.app_context():
            threads = []
            all_documents_item: list[Document] = []
            index_type = None
            for dataset in available_datasets:
                index_type = dataset.indexing_technique
                document_ids_filter = None
                if dataset.provider != "external":
                    if metadata_condition and not metadata_filter_document_ids:
                        continue
                    if metadata_filter_document_ids:
                        document_ids = metadata_filter_document_ids.get(dataset.id, [])
                        if document_ids:
                            document_ids_filter = document_ids
                        else:
                            continue
                retrieval_thread = threading.Thread(
                    target=self._retriever,
                    kwargs={
                        "flask_app": flask_app,
                        "dataset_id": dataset.id,
                        "query": query,
                        "top_k": top_k,
                        "all_documents": all_documents_item,
                        "document_ids_filter": document_ids_filter,
                        "metadata_condition": metadata_condition,
                        "attachment_ids": [attachment_id] if attachment_id else None,
                    },
                )
                threads.append(retrieval_thread)
                retrieval_thread.start()
            for thread in threads:
                thread.join()

            if reranking_enable:
                # do rerank for searched documents
                data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
                if query:
                    all_documents_item = data_post_processor.invoke(
                        query=query,
                        documents=all_documents_item,
                        score_threshold=score_threshold,
                        top_n=top_k,
                        query_type=QueryType.TEXT_QUERY,
                    )
                if attachment_id:
                    all_documents_item = data_post_processor.invoke(
                        documents=all_documents_item,
                        score_threshold=score_threshold,
                        top_n=top_k,
                        query_type=QueryType.IMAGE_QUERY,
                        query=attachment_id,
                    )
            else:
                if index_type == IndexTechniqueType.ECONOMY:
                    if not query:
                        all_documents_item = []
                    else:
                        all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
                elif index_type == IndexTechniqueType.HIGH_QUALITY:
                    all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
                else:
                    all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
            if all_documents_item:
                all_documents.extend(all_documents_item)

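A stripped-down sketch of the thread fan-out/join pattern _multiple_retrieve_thread relies on. The retrieve_one stand-in and the dataset IDs are invented for illustration.

```python
# Sketch only: spawn one worker per dataset, join all, aggregate into a shared list.
import threading


def retrieve_one(dataset_id: str, results: list[str]) -> None:
    # Stand-in for self._retriever(); appends into the shared all_documents_item list.
    results.append(f"doc-from-{dataset_id}")


results: list[str] = []
threads: list[threading.Thread] = []
for dataset_id in ["ds-1", "ds-2", "ds-3"]:
    t = threading.Thread(target=retrieve_one, kwargs={"dataset_id": dataset_id, "results": results})
    threads.append(t)
    t.start()
for t in threads:
    t.join()

print(results)  # order may vary; list.append is atomic under CPython's GIL
```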
@@ -1,65 +0,0 @@
|
||||
{
|
||||
"$id": "https://dify.ai/schemas/v1/multimodal_general_structure.json",
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"version": "1.0.0",
|
||||
"type": "array",
|
||||
"title": "Multimodal General Structure",
|
||||
"description": "Schema for multimodal general structure (v1) - array of objects",
|
||||
"properties": {
|
||||
"general_chunks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The content"
|
||||
},
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "file name"
|
||||
},
|
||||
"size": {
|
||||
"type": "number",
|
||||
"description": "file size"
|
||||
},
|
||||
"extension": {
|
||||
"type": "string",
|
||||
"description": "file extension"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "file type"
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string",
|
||||
"description": "file mime type"
|
||||
},
|
||||
"transfer_method": {
|
||||
"type": "string",
|
||||
"description": "file transfer method"
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "file url"
|
||||
},
|
||||
"related_id": {
|
||||
"type": "string",
|
||||
"description": "file related id"
|
||||
}
|
||||
},
|
||||
"description": "List of files"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["content"]
|
||||
},
|
||||
"description": "List of content and files"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
{
|
||||
"$id": "https://dify.ai/schemas/v1/multimodal_parent_child_structure.json",
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"version": "1.0.0",
|
||||
"type": "object",
|
||||
"title": "Multimodal Parent-Child Structure",
|
||||
"description": "Schema for multimodal parent-child structure (v1)",
|
||||
"properties": {
|
||||
"parent_mode": {
|
||||
"type": "string",
|
||||
"description": "The mode of parent-child relationship"
|
||||
},
|
||||
"parent_child_chunks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"parent_content": {
|
||||
"type": "string",
|
||||
"description": "The parent content"
|
||||
},
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "file name"
|
||||
},
|
||||
"size": {
|
||||
"type": "number",
|
||||
"description": "file size"
|
||||
},
|
||||
"extension": {
|
||||
"type": "string",
|
||||
"description": "file extension"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "file type"
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string",
|
||||
"description": "file mime type"
|
||||
},
|
||||
"transfer_method": {
|
||||
"type": "string",
|
||||
"description": "file transfer method"
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "file url"
|
||||
},
|
||||
"related_id": {
|
||||
"type": "string",
|
||||
"description": "file related id"
|
||||
}
|
||||
},
|
||||
"required": ["name", "size", "extension", "type", "mime_type", "transfer_method", "url", "related_id"]
|
||||
},
|
||||
"description": "List of files"
|
||||
},
|
||||
"child_contents": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "List of child contents"
|
||||
}
|
||||
},
|
||||
"required": ["parent_content", "child_contents"]
|
||||
},
|
||||
"description": "List of parent-child chunk pairs"
|
||||
}
|
||||
},
|
||||
"required": ["parent_mode", "parent_child_chunks"]
|
||||
}
|
||||
@@ -25,24 +25,6 @@ def sign_tool_file(tool_file_id: str, extension: str) -> str:
    return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"


def sign_upload_file(upload_file_id: str, extension: str) -> str:
    """
    sign file to get a temporary url for plugin access
    """
    # Use internal URL for plugin/tool file access in Docker environments
    base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
    file_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"

    timestamp = str(int(time.time()))
    nonce = os.urandom(16).hex()
    data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
    secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
    sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
    encoded_sign = base64.urlsafe_b64encode(sign).decode()

    return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"


def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    """
    verify signature

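The body of verify_tool_file_signature is truncated above. Below is a hedged sketch of the usual counterpart to the signing code shown: recompute the HMAC over the same string and compare it in constant time, optionally rejecting stale timestamps. The secret, TTL, and function name are assumptions, not dify's actual implementation.

```python
# Sketch only, mirroring sign_upload_file above under stated assumptions.
import base64
import hashlib
import hmac
import time

SECRET_KEY = b"example-secret"   # assumption: stands in for dify_config.SECRET_KEY
SIGNATURE_TTL_SECONDS = 300      # assumption: example expiry window


def verify_upload_file_signature(upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
    expected = hmac.new(SECRET_KEY, data_to_sign.encode(), hashlib.sha256).digest()
    expected_encoded = base64.urlsafe_b64encode(expected).decode()
    # Constant-time comparison avoids timing side channels.
    if not hmac.compare_digest(expected_encoded, sign):
        return False
    # Reject signatures older than the TTL.
    return int(time.time()) - int(timestamp) <= SIGNATURE_TTL_SECONDS
```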
@@ -13,5 +13,5 @@ def remove_leading_symbols(text: str) -> str:
    """
    # Match Unicode ranges for punctuation and symbols
    # FIXME this pattern is confused quick fix for #11868 maybe refactor it later
    pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
    pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
    return re.sub(pattern, "", text)

@@ -221,7 +221,7 @@ class WorkflowToolProviderController(ToolProviderController):
            session.query(WorkflowToolProvider)
            .where(
                WorkflowToolProvider.tenant_id == tenant_id,
                WorkflowToolProvider.id == self.provider_id,
                WorkflowToolProvider.app_id == self.provider_id,
            )
            .first()
        )

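A quick comparison of the two leading-symbol patterns from the remove_leading_symbols hunk above; the sample strings are invented.

```python
# Sketch only: the first pattern also strips ASCII square brackets but not "!",
# while the second strips "!" but leaves "[" and "]" alone.
import re

PATTERN_A = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
PATTERN_B = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"

print(re.sub(PATTERN_A, "", "[]...hello"))  # -> "hello"
print(re.sub(PATTERN_B, "", "[]...hello"))  # -> "[]...hello" (leading "[" not matched)
print(re.sub(PATTERN_A, "", "!!hello"))     # -> "!!hello" ("!" not in the class)
print(re.sub(PATTERN_B, "", "!!hello"))     # -> "hello"
```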
@@ -3,7 +3,6 @@ from datetime import datetime
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.file import File
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
@@ -15,7 +14,6 @@ from .base import NodeEventBase
|
||||
class RunRetrieverResourceEvent(NodeEventBase):
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
context_files: list[File] | None = Field(default=None, description="context files")
|
||||
|
||||
|
||||
class ModelInvokeCompletedEvent(NodeEventBase):
|
||||
|
||||
@@ -59,7 +59,7 @@ class OutputVariableEntity(BaseModel):
    """

    variable: str
    value_type: OutputVariableType = OutputVariableType.ANY
    value_type: OutputVariableType
    value_selector: Sequence[str]

    @field_validator("value_type", mode="before")

@@ -412,20 +412,16 @@ class Executor:
                body_string += f"--{boundary}\r\n"
                body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
                # decode content safely
                # Do not decode binary content; use a placeholder with file metadata instead.
                # Includes filename, size, and MIME type for better logging context.
                body_string += (
                    f"<file_content_binary: '{file_entry[1][0] or 'unknown'}', "
                    f"type='{file_entry[1][2] if len(file_entry[1]) > 2 else 'unknown'}', "
                    f"size={len(content)} bytes>\r\n"
                )
                try:
                    body_string += content.decode("utf-8")
                except UnicodeDecodeError:
                    body_string += content.decode("utf-8", errors="replace")
                body_string += "\r\n"
            body_string += f"--{boundary}--\r\n"
        elif self.node_data.body:
            if self.content:
                # If content is bytes, do not decode it; show a placeholder with size.
                # Provides content size information for binary data without exposing the raw bytes.
                if isinstance(self.content, bytes):
                    body_string = f"<binary_content: size={len(self.content)} bytes>"
                    body_string = self.content.decode("utf-8", errors="replace")
                else:
                    body_string = self.content
            elif self.data and self.node_data.body.type == "x-www-form-urlencoded":

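A tiny sketch contrasting the two body-logging strategies in the Executor hunk above, a size placeholder versus a lossy UTF-8 decode; the byte string is an invented example.

```python
# Sketch only: how the two approaches render the same non-UTF-8 payload.
content = b"\x89PNG\r\n\x1a\nbinary-image-bytes"

placeholder = f"<binary_content: size={len(content)} bytes>"      # constant-size summary
lossy_text = content.decode("utf-8", errors="replace")            # replacement chars for bad bytes

print(placeholder)
print(repr(lossy_text))
```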
@@ -114,8 +114,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
"""
|
||||
|
||||
type: str = "knowledge-retrieval"
|
||||
query_variable_selector: list[str] | None | str = None
|
||||
query_attachment_selector: list[str] | None | str = None
|
||||
query_variable_selector: list[str]
|
||||
dataset_ids: list[str]
|
||||
retrieval_mode: Literal["single", "multiple"]
|
||||
multiple_retrieval_config: MultipleRetrievalConfig | None = None
|
||||
|
||||
@@ -25,8 +25,6 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.variables import (
|
||||
ArrayFileSegment,
|
||||
FileSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.variables.segments import ArrayObjectSegment
|
||||
@@ -121,41 +119,20 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={},
|
||||
process_data={},
|
||||
outputs={},
|
||||
metadata={},
|
||||
llm_usage=LLMUsage.empty_usage(),
|
||||
)
|
||||
variables: dict[str, Any] = {}
|
||||
# extract variables
|
||||
if self._node_data.query_variable_selector:
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
|
||||
if not isinstance(variable, StringSegment):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error="Query variable is not string type.",
|
||||
)
|
||||
query = variable.value
|
||||
variables["query"] = query
|
||||
|
||||
if self._node_data.query_attachment_selector:
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_attachment_selector)
|
||||
if not isinstance(variable, ArrayFileSegment) and not isinstance(variable, FileSegment):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error="Attachments variable is not array file or file type.",
|
||||
)
|
||||
if isinstance(variable, ArrayFileSegment):
|
||||
variables["attachments"] = variable.value
|
||||
else:
|
||||
variables["attachments"] = [variable.value]
|
||||
|
||||
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
|
||||
if not isinstance(variable, StringSegment):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error="Query variable is not string type.",
|
||||
)
|
||||
query = variable.value
|
||||
variables = {"query": query}
|
||||
if not query:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
|
||||
)
|
||||
# TODO(-LAN-): Move this check outside.
|
||||
# check rate limit
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
|
||||
@@ -184,7 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
# retrieve knowledge
|
||||
usage = LLMUsage.empty_usage()
|
||||
try:
|
||||
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
|
||||
results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
|
||||
outputs = {"result": ArrayObjectSegment(value=results)}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@@ -221,16 +198,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
db.session.close()
|
||||
|
||||
def _fetch_dataset_retriever(
|
||||
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
|
||||
self, node_data: KnowledgeRetrievalNodeData, query: str
|
||||
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
||||
usage = LLMUsage.empty_usage()
|
||||
available_datasets = []
|
||||
dataset_ids = node_data.dataset_ids
|
||||
query = variables.get("query")
|
||||
attachments = variables.get("attachments")
|
||||
metadata_filter_document_ids = None
|
||||
metadata_condition = None
|
||||
metadata_usage = LLMUsage.empty_usage()
|
||||
|
||||
# Subquery: Count the number of available documents for each dataset
|
||||
subquery = (
|
||||
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
|
||||
@@ -261,14 +234,13 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
if not dataset:
|
||||
continue
|
||||
available_datasets.append(dataset)
|
||||
if query:
|
||||
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
|
||||
[dataset.id for dataset in available_datasets], query, node_data
|
||||
)
|
||||
usage = self._merge_usage(usage, metadata_usage)
|
||||
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
|
||||
[dataset.id for dataset in available_datasets], query, node_data
|
||||
)
|
||||
usage = self._merge_usage(usage, metadata_usage)
|
||||
all_documents = []
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
# fetch model config
|
||||
if node_data.single_retrieval_config is None:
|
||||
raise ValueError("single_retrieval_config is required")
|
||||
@@ -300,7 +272,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||
@@ -347,7 +319,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
|
||||
)
|
||||
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
|
||||
|
||||
@@ -356,7 +327,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
retrieval_resource_list = []
|
||||
# deal with external documents
|
||||
for item in external_documents:
|
||||
source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = {
|
||||
source = {
|
||||
"metadata": {
|
||||
"_source": "knowledge",
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
@@ -413,7 +384,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
"doc_metadata": document.doc_metadata,
|
||||
},
|
||||
"title": document.name,
|
||||
"files": list(record.files) if record.files else None,
|
||||
}
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
|
||||
@@ -423,21 +393,13 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
if retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list,
|
||||
key=self._score, # type: ignore[arg-type, return-value]
|
||||
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
|
||||
reverse=True,
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||
item["metadata"]["position"] = position # type: ignore[index]
|
||||
item["metadata"]["position"] = position
|
||||
return retrieval_resource_list, usage
|
||||
|
||||
def _score(self, item: dict[str, Any]) -> float:
|
||||
meta = item.get("metadata")
|
||||
if isinstance(meta, dict):
|
||||
s = meta.get("score")
|
||||
if isinstance(s, (int, float)):
|
||||
return float(s)
|
||||
return 0.0
|
||||
|
||||
def _get_metadata_filter_condition(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
|
||||
@@ -697,10 +659,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
if typed_node_data.query_variable_selector:
|
||||
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
|
||||
if typed_node_data.query_attachment_selector:
|
||||
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
|
||||
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
|
||||
return variable_mapping
|
||||
|
||||
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
|
||||
@@ -7,10 +7,8 @@ import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.file import FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
@@ -46,7 +44,6 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.variables import (
|
||||
ArrayFileSegment,
|
||||
ArraySegment,
|
||||
@@ -75,9 +72,6 @@ from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import SegmentAttachmentBinding
|
||||
from models.model import UploadFile
|
||||
|
||||
from . import llm_utils
|
||||
from .entities import (
|
||||
@@ -185,17 +179,12 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# fetch context value
|
||||
generator = self._fetch_context(node_data=self.node_data)
|
||||
context = None
|
||||
context_files: list[File] = []
|
||||
for event in generator:
|
||||
context = event.context
|
||||
context_files = event.context_files or []
|
||||
yield event
|
||||
if context:
|
||||
node_inputs["#context#"] = context
|
||||
|
||||
if context_files:
|
||||
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
|
||||
|
||||
# fetch model config
|
||||
model_instance, model_config = LLMNode._fetch_model_config(
|
||||
node_data_model=self.node_data.model,
|
||||
@@ -231,7 +220,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
tenant_id=self.tenant_id,
|
||||
context_files=context_files,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
@@ -334,7 +322,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
error_type=type(e).__name__,
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -345,8 +332,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
error=str(e),
|
||||
inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
error_type=type(e).__name__,
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -669,13 +654,10 @@ class LLMNode(Node[LLMNodeData]):
|
||||
context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
|
||||
if context_value_variable:
|
||||
if isinstance(context_value_variable, StringSegment):
|
||||
yield RunRetrieverResourceEvent(
|
||||
retriever_resources=[], context=context_value_variable.value, context_files=[]
|
||||
)
|
||||
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
|
||||
elif isinstance(context_value_variable, ArraySegment):
|
||||
context_str = ""
|
||||
original_retriever_resource: list[RetrievalSourceMetadata] = []
|
||||
context_files: list[File] = []
|
||||
for item in context_value_variable.value:
|
||||
if isinstance(item, str):
|
||||
context_str += item + "\n"
|
||||
@@ -688,34 +670,9 @@ class LLMNode(Node[LLMNodeData]):
|
||||
retriever_resource = self._convert_to_original_retriever_resource(item)
|
||||
if retriever_resource:
|
||||
original_retriever_resource.append(retriever_resource)
|
||||
attachments_with_bindings = db.session.execute(
|
||||
select(SegmentAttachmentBinding, UploadFile)
|
||||
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||
.where(
|
||||
SegmentAttachmentBinding.segment_id == retriever_resource.segment_id,
|
||||
)
|
||||
).all()
|
||||
if attachments_with_bindings:
|
||||
for _, upload_file in attachments_with_bindings:
|
||||
attachment_info = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
url=sign_upload_file(upload_file.id, upload_file.extension),
|
||||
)
|
||||
context_files.append(attachment_info)
|
||||
|
||||
yield RunRetrieverResourceEvent(
|
||||
retriever_resources=original_retriever_resource,
|
||||
context=context_str.strip(),
|
||||
context_files=context_files,
|
||||
retriever_resources=original_retriever_resource, context=context_str.strip()
|
||||
)
|
||||
|
||||
def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None:
|
||||
@@ -743,7 +700,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
content=context_dict.get("content"),
|
||||
page=metadata.get("page"),
|
||||
doc_metadata=metadata.get("doc_metadata"),
|
||||
files=context_dict.get("files"),
|
||||
)
|
||||
|
||||
return source
|
||||
@@ -785,7 +741,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
variable_pool: VariablePool,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
tenant_id: str,
|
||||
context_files: list["File"] | None = None,
|
||||
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
|
||||
@@ -898,23 +853,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
# The context_files
|
||||
if vision_enabled and context_files:
|
||||
file_prompts = []
|
||||
for file in context_files:
|
||||
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
||||
file_prompts.append(file_prompt)
|
||||
# If last prompt is a user prompt, add files into its contents,
|
||||
# otherwise append a new user prompt
|
||||
if (
|
||||
len(prompt_messages) > 0
|
||||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||
and isinstance(prompt_messages[-1].content, list)
|
||||
):
|
||||
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
# Remove empty messages and filter unsupported content
|
||||
filtered_prompt_messages = []
|
||||
for prompt_message in prompt_messages:
|
||||
|
||||
@@ -221,7 +221,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
|
||||
@@ -97,27 +97,11 @@ dataset_detail_fields = {
|
||||
"total_documents": fields.Integer,
|
||||
"total_available_documents": fields.Integer,
|
||||
"enable_api": fields.Boolean,
|
||||
"is_multimodal": fields.Boolean,
|
||||
}
|
||||
|
||||
file_info_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"size": fields.Integer,
|
||||
"extension": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"source_url": fields.String,
|
||||
}
|
||||
|
||||
content_fields = {
|
||||
"content_type": fields.String,
|
||||
"content": fields.String,
|
||||
"file_info": fields.Nested(file_info_fields, allow_null=True),
|
||||
}
|
||||
|
||||
dataset_query_detail_fields = {
|
||||
"id": fields.String,
|
||||
"queries": fields.Nested(content_fields),
|
||||
"content": fields.String,
|
||||
"source": fields.String,
|
||||
"source_app_id": fields.String,
|
||||
"created_by_role": fields.String,
|
||||
|
||||
@@ -9,8 +9,6 @@ upload_config_fields = {
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"workflow_file_upload_limit": fields.Integer,
|
||||
"image_file_batch_limit": fields.Integer,
|
||||
"single_chunk_attachment_limit": fields.Integer,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -43,19 +43,9 @@ child_chunk_fields = {
|
||||
"score": fields.Float,
|
||||
}
|
||||
|
||||
files_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"size": fields.Integer,
|
||||
"extension": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"source_url": fields.String,
|
||||
}
|
||||
|
||||
hit_testing_record_fields = {
|
||||
"segment": fields.Nested(segment_fields),
|
||||
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
|
||||
"score": fields.Float,
|
||||
"tsne_position": fields.Raw,
|
||||
"files": fields.List(fields.Nested(files_fields)),
|
||||
}
|
||||
|
||||
@@ -13,15 +13,6 @@ child_chunk_fields = {
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
attachment_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"size": fields.Integer,
|
||||
"extension": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"source_url": fields.String,
|
||||
}
|
||||
|
||||
segment_fields = {
|
||||
"id": fields.String,
|
||||
"position": fields.Integer,
|
||||
@@ -48,5 +39,4 @@ segment_fields = {
|
||||
"error": fields.String,
|
||||
"stopped_at": TimestampField,
|
||||
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
|
||||
"attachments": fields.List(fields.Nested(attachment_fields)),
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ def email(email):
|
||||
EmailStr = Annotated[str, AfterValidator(email)]
|
||||
|
||||
|
||||
def uuid_value(value: Any) -> str:
|
||||
def uuid_value(value):
|
||||
if value == "":
|
||||
return str(value)
|
||||
|
||||
@@ -215,11 +215,7 @@ def generate_text_hash(text: str) -> str:
|
||||
|
||||
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(
|
||||
response=json.dumps(jsonable_encoder(response)),
|
||||
status=200,
|
||||
content_type="application/json; charset=utf-8",
|
||||
)
|
||||
return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json")
|
||||
else:
|
||||
|
||||
def generate() -> Generator:
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
"""support-multi-modal
|
||||
|
||||
Revision ID: d57accd375ae
|
||||
Revises: 03f8dcbc611e
|
||||
Create Date: 2025-11-12 15:37:12.363670
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'd57accd375ae'
|
||||
down_revision = '7bb281b7a422'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('segment_attachment_bindings',
|
||||
sa.Column('id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('document_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('segment_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('attachment_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='segment_attachment_binding_pkey')
|
||||
)
|
||||
with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op:
|
||||
batch_op.create_index(
|
||||
'segment_attachment_binding_tenant_dataset_document_segment_idx',
|
||||
['tenant_id', 'dataset_id', 'document_id', 'segment_id'],
|
||||
unique=False
|
||||
)
|
||||
batch_op.create_index('segment_attachment_binding_attachment_idx', ['attachment_id'], unique=False)
|
||||
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('is_multimodal', sa.Boolean(), server_default=sa.text('false'), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.drop_column('is_multimodal')
|
||||
|
||||
|
||||
with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op:
|
||||
batch_op.drop_index('segment_attachment_binding_attachment_idx')
|
||||
batch_op.drop_index('segment_attachment_binding_tenant_dataset_document_segment_idx')
|
||||
|
||||
op.drop_table('segment_attachment_bindings')
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,4 +1,4 @@
|
||||
"""mysql adaptation
|
||||
"""empty message
|
||||
|
||||
Revision ID: 09cfdda155d1
|
||||
Revises: 669ffd70119c
|
||||
@@ -97,31 +97,11 @@ def downgrade():
|
||||
batch_op.alter_column('include_plugins',
|
||||
existing_type=sa.JSON(),
|
||||
type_=postgresql.ARRAY(sa.VARCHAR(length=255)),
|
||||
existing_nullable=False,
|
||||
postgresql_using="""
|
||||
COALESCE(
|
||||
regexp_replace(
|
||||
replace(replace(include_plugins::text, '[', '{'), ']', '}'),
|
||||
'"',
|
||||
'',
|
||||
'g'
|
||||
)::varchar(255)[],
|
||||
ARRAY[]::varchar(255)[]
|
||||
)""")
|
||||
existing_nullable=False)
|
||||
batch_op.alter_column('exclude_plugins',
|
||||
existing_type=sa.JSON(),
|
||||
type_=postgresql.ARRAY(sa.VARCHAR(length=255)),
|
||||
existing_nullable=False,
|
||||
postgresql_using="""
|
||||
COALESCE(
|
||||
regexp_replace(
|
||||
replace(replace(exclude_plugins::text, '[', '{'), ']', '}'),
|
||||
'"',
|
||||
'',
|
||||
'g'
|
||||
)::varchar(255)[],
|
||||
ARRAY[]::varchar(255)[]
|
||||
)""")
|
||||
existing_nullable=False)
|
||||
|
||||
with op.batch_alter_table('external_knowledge_bindings', schema=None) as batch_op:
|
||||
batch_op.alter_column('external_knowledge_id',
|
||||
|
||||
@@ -19,9 +19,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.signature import sign_upload_file
|
||||
from extensions.ext_storage import storage
|
||||
from libs.uuid_utils import uuidv7
|
||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
||||
@@ -78,7 +76,6 @@ class Dataset(Base):
|
||||
pipeline_id = mapped_column(StringUUID, nullable=True)
|
||||
chunk_structure = mapped_column(sa.String(255), nullable=True)
|
||||
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
is_multimodal = mapped_column(sa.Boolean, default=False, nullable=False, server_default=db.text("false"))
|
||||
|
||||
@property
|
||||
def total_documents(self):
|
||||
@@ -731,7 +728,9 @@ class DocumentSegment(Base):
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
error = mapped_column(LongText, nullable=True)
|
||||
@@ -867,47 +866,6 @@ class DocumentSegment(Base):
|
||||
|
||||
return text
|
||||
|
||||
@property
|
||||
def attachments(self) -> list[dict[str, Any]]:
|
||||
# Use JOIN to fetch attachments in a single query instead of two separate queries
|
||||
attachments_with_bindings = db.session.execute(
|
||||
select(SegmentAttachmentBinding, UploadFile)
|
||||
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||
.where(
|
||||
SegmentAttachmentBinding.tenant_id == self.tenant_id,
|
||||
SegmentAttachmentBinding.dataset_id == self.dataset_id,
|
||||
SegmentAttachmentBinding.document_id == self.document_id,
|
||||
SegmentAttachmentBinding.segment_id == self.id,
|
||||
)
|
||||
).all()
|
||||
if not attachments_with_bindings:
|
||||
return []
|
||||
attachment_list = []
|
||||
for _, attachment in attachments_with_bindings:
|
||||
upload_file_id = attachment.id
|
||||
nonce = os.urandom(16).hex()
|
||||
timestamp = str(int(time.time()))
|
||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
reference_url = dify_config.CONSOLE_API_URL or ""
|
||||
base_url = f"{reference_url}/files/{upload_file_id}/image-preview"
|
||||
source_url = f"{base_url}?{params}"
|
||||
attachment_list.append(
|
||||
{
|
||||
"id": attachment.id,
|
||||
"name": attachment.name,
|
||||
"size": attachment.size,
|
||||
"extension": attachment.extension,
|
||||
"mime_type": attachment.mime_type,
|
||||
"source_url": source_url,
|
||||
}
|
||||
)
|
||||
return attachment_list
|
||||
|
||||
|
||||
class ChildChunk(Base):
|
||||
__tablename__ = "child_chunks"
|
||||
@@ -1005,38 +963,6 @@ class DatasetQuery(TypeBase):
|
||||
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
@property
|
||||
def queries(self) -> list[dict[str, Any]]:
|
||||
try:
|
||||
queries = json.loads(self.content)
|
||||
if isinstance(queries, list):
|
||||
for query in queries:
|
||||
if query["content_type"] == QueryType.IMAGE_QUERY:
|
||||
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
|
||||
if file_info:
|
||||
query["file_info"] = {
|
||||
"id": file_info.id,
|
||||
"name": file_info.name,
|
||||
"size": file_info.size,
|
||||
"extension": file_info.extension,
|
||||
"mime_type": file_info.mime_type,
|
||||
"source_url": sign_upload_file(file_info.id, file_info.extension),
|
||||
}
|
||||
else:
|
||||
query["file_info"] = None
|
||||
|
||||
return queries
|
||||
else:
|
||||
return [queries]
|
||||
except JSONDecodeError:
|
||||
return [
|
||||
{
|
||||
"content_type": QueryType.TEXT_QUERY,
|
||||
"content": self.content,
|
||||
"file_info": None,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class DatasetKeywordTable(TypeBase):
|
||||
__tablename__ = "dataset_keyword_tables"
|
||||
@@ -1544,25 +1470,3 @@ class PipelineRecommendedPlugin(TypeBase):
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
|
||||
class SegmentAttachmentBinding(Base):
|
||||
__tablename__ = "segment_attachment_bindings"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"),
|
||||
sa.Index(
|
||||
"segment_attachment_binding_tenant_dataset_document_segment_idx",
|
||||
"tenant_id",
|
||||
"dataset_id",
|
||||
"document_id",
|
||||
"segment_id",
|
||||
),
|
||||
sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"),
|
||||
)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@@ -111,11 +111,7 @@ class App(Base):
|
||||
else:
|
||||
app_model_config = self.app_model_config
|
||||
if app_model_config:
|
||||
pre_prompt = app_model_config.pre_prompt or ""
|
||||
# Truncate to 200 characters with ellipsis if using prompt as description
|
||||
if len(pre_prompt) > 200:
|
||||
return pre_prompt[:200] + "..."
|
||||
return pre_prompt
|
||||
return app_model_config.pre_prompt
|
||||
else:
|
||||
return ""
|
||||
|
||||
@@ -263,7 +259,7 @@ class App(Base):
|
||||
provider_id = tool.get("provider_id", "")
|
||||
|
||||
if provider_type == ToolProviderType.API:
|
||||
if provider_id not in existing_api_providers:
|
||||
if uuid.UUID(provider_id) not in existing_api_providers:
|
||||
deleted_tools.append(
|
||||
{
|
||||
"type": ToolProviderType.API,
|
||||
@@ -839,29 +835,7 @@ class Conversation(Base):
|
||||
|
||||
@property
|
||||
def status_count(self):
|
||||
from models.workflow import WorkflowRun
|
||||
|
||||
# Get all messages with workflow_run_id for this conversation
|
||||
messages = db.session.scalars(
|
||||
select(Message).where(Message.conversation_id == self.id, Message.workflow_run_id.isnot(None))
|
||||
).all()
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Batch load all workflow runs in a single query, filtered by this conversation's app_id
|
||||
workflow_run_ids = [msg.workflow_run_id for msg in messages if msg.workflow_run_id]
|
||||
workflow_runs = {}
|
||||
|
||||
if workflow_run_ids:
|
||||
workflow_runs_query = db.session.scalars(
|
||||
select(WorkflowRun).where(
|
||||
WorkflowRun.id.in_(workflow_run_ids),
|
||||
WorkflowRun.app_id == self.app_id, # Filter by this conversation's app_id
|
||||
)
|
||||
).all()
|
||||
workflow_runs = {run.id: run for run in workflow_runs_query}
|
||||
|
||||
messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all()
|
||||
status_counts = {
|
||||
WorkflowExecutionStatus.RUNNING: 0,
|
||||
WorkflowExecutionStatus.SUCCEEDED: 0,
|
||||
@@ -871,24 +845,18 @@ class Conversation(Base):
|
||||
}
|
||||
|
||||
for message in messages:
|
||||
# Guard against None to satisfy type checker and avoid invalid dict lookups
|
||||
if message.workflow_run_id is None:
|
||||
continue
|
||||
workflow_run = workflow_runs.get(message.workflow_run_id)
|
||||
if not workflow_run:
|
||||
continue
|
||||
if message.workflow_run:
|
||||
status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1
|
||||
|
||||
try:
|
||||
status_counts[WorkflowExecutionStatus(workflow_run.status)] += 1
|
||||
except (ValueError, KeyError):
|
||||
# Handle invalid status values gracefully
|
||||
pass
|
||||
|
||||
return {
|
||||
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
|
||||
"failed": status_counts[WorkflowExecutionStatus.FAILED],
|
||||
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
|
||||
}
|
||||
return (
|
||||
{
|
||||
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
|
||||
"failed": status_counts[WorkflowExecutionStatus.FAILED],
|
||||
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
|
||||
}
|
||||
if messages
|
||||
else None
|
||||
)
|
||||
|
||||
@property
|
||||
def first_message(self):
|
||||
@@ -1287,9 +1255,13 @@ class Message(Base):
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
"conversation_id": self.conversation_id,
|
||||
"model_provider": self.model_provider,
|
||||
"model_id": self.model_id,
|
||||
"inputs": self.inputs,
|
||||
"query": self.query,
|
||||
"message_tokens": self.message_tokens,
|
||||
"answer_tokens": self.answer_tokens,
|
||||
"provider_response_latency": self.provider_response_latency,
|
||||
"total_price": self.total_price,
|
||||
"message": self.message,
|
||||
"answer": self.answer,
|
||||
@@ -1311,8 +1283,12 @@ class Message(Base):
|
||||
id=data["id"],
|
||||
app_id=data["app_id"],
|
||||
conversation_id=data["conversation_id"],
|
||||
model_provider=data.get("model_provider"),
|
||||
model_id=data["model_id"],
|
||||
inputs=data["inputs"],
|
||||
message_tokens=data.get("message_tokens", 0),
|
||||
answer_tokens=data.get("answer_tokens", 0),
|
||||
provider_response_latency=data.get("provider_response_latency", 0.0),
|
||||
total_price=data["total_price"],
|
||||
query=data["query"],
|
||||
message=data["message"],
|
||||
|
||||
@@ -907,29 +907,19 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
@property
|
||||
def extras(self) -> dict[str, Any]:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
|
||||
extras: dict[str, Any] = {}
|
||||
execution_metadata = self.execution_metadata_dict
|
||||
if execution_metadata:
|
||||
if self.node_type == NodeType.TOOL and "tool_info" in execution_metadata:
|
||||
tool_info: dict[str, Any] = execution_metadata["tool_info"]
|
||||
if self.execution_metadata_dict:
|
||||
if self.node_type == NodeType.TOOL and "tool_info" in self.execution_metadata_dict:
|
||||
tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"]
|
||||
extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_type=tool_info["provider_type"],
|
||||
provider_id=tool_info["provider_id"],
|
||||
)
|
||||
elif self.node_type == NodeType.DATASOURCE and "datasource_info" in execution_metadata:
|
||||
datasource_info = execution_metadata["datasource_info"]
|
||||
elif self.node_type == NodeType.DATASOURCE and "datasource_info" in self.execution_metadata_dict:
|
||||
datasource_info = self.execution_metadata_dict["datasource_info"]
|
||||
extras["icon"] = datasource_info.get("icon")
|
||||
elif self.node_type == NodeType.TRIGGER_PLUGIN and "trigger_info" in execution_metadata:
|
||||
trigger_info = execution_metadata["trigger_info"] or {}
|
||||
provider_id = trigger_info.get("provider_id")
|
||||
if provider_id:
|
||||
extras["icon"] = TriggerManager.get_trigger_plugin_icon(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
return extras
|
||||
|
||||
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
|
||||
|
||||
@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.11.1"
version = "1.10.1"
requires-python = ">=3.11,<3.13"

dependencies = [

@@ -151,7 +151,7 @@ dev = [
    "types-pywin32~=310.0.0",
    "types-pyyaml~=6.0.12",
    "types-regex~=2024.11.6",
    "types-shapely~=2.1.0",
    "types-shapely~=2.0.0",
    "types-simplejson>=3.20.0",
    "types-six>=1.17.0",
    "types-tensorflow>=2.18.0",

@@ -1,31 +0,0 @@
|
||||
import base64
|
||||
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
class AttachmentService:
|
||||
_session_maker: sessionmaker
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine | None = None):
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_maker = sessionmaker(bind=session_factory)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_maker = session_factory
|
||||
else:
|
||||
raise AssertionError("must be a sessionmaker or an Engine.")
|
||||
|
||||
def get_file_base64(self, file_id: str) -> str:
|
||||
upload_file = (
|
||||
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
)
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
blob = storage.load_once(upload_file.key)
|
||||
return base64.b64encode(blob).decode()
|
||||
@@ -118,7 +118,7 @@ class ConversationService:
|
||||
app_model: App,
|
||||
conversation_id: str,
|
||||
user: Union[Account, EndUser] | None,
|
||||
name: str | None,
|
||||
name: str,
|
||||
auto_generate: bool,
|
||||
):
|
||||
conversation = cls.get_conversation(app_model, conversation_id, user)
|
||||
|
||||
@@ -7,7 +7,7 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal, cast
from typing import Any, Literal

import sqlalchemy as sa
from redis.exceptions import LockNotOwnedError
@@ -19,10 +19,9 @@ from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.entities.model_entities import ModelType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from enums.cloud_plan import CloudPlan
from events.dataset_event import dataset_was_deleted
@@ -47,7 +46,6 @@ from models.dataset import (
DocumentSegment,
ExternalKnowledgeBindings,
Pipeline,
SegmentAttachmentBinding,
)
from models.model import UploadFile
from models.provider_ids import ModelProviderID
@@ -365,27 +363,6 @@ class DatasetService:
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)

@staticmethod
def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
try:
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=model,
)
text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance)
model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials)
if not model_schema:
raise ValueError("Model schema not found")
if model_schema.features and ModelFeature.VISION in model_schema.features:
return True
else:
return False
except LLMBadRequestError:
raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")

@staticmethod
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
try:
@@ -425,13 +402,13 @@ class DatasetService:
if not dataset:
raise ValueError("Dataset not found")
# check if dataset name is exists
if data.get("name") and data.get("name") != dataset.name:
if DatasetService._has_dataset_same_name(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
name=data.get("name", dataset.name),
):
raise ValueError("Dataset name already exists")

if DatasetService._has_dataset_same_name(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
name=data.get("name", dataset.name),
):
raise ValueError("Dataset name already exists")

# Verify user has permission to update this dataset
DatasetService.check_dataset_permission(dataset, user)
@@ -673,8 +650,6 @@ class DatasetService:
Returns:
str: Action to perform ('add', 'remove', 'update', or None)
"""
if "indexing_technique" not in data:
return None
if dataset.indexing_technique != data["indexing_technique"]:
if data["indexing_technique"] == "economy":
# Remove embedding model configuration for economy mode
@@ -869,12 +844,6 @@ class DatasetService:
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_configuration.embedding_model or "",
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
@@ -911,12 +880,6 @@ class DatasetService:
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
dataset.collection_binding_id = dataset_collection_binding.id
dataset.indexing_technique = knowledge_configuration.indexing_technique
except LLMBadRequestError:
@@ -974,12 +937,6 @@ class DatasetService:
)
)
dataset.collection_binding_id = dataset_collection_binding.id
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
@@ -1636,20 +1593,6 @@ class DocumentService:
return [], ""
db.session.add(dataset_process_rule)
db.session.flush()
else:
# Fallback when no process_rule provided in knowledge_config:
# 1) reuse dataset.latest_process_rule if present
# 2) otherwise create an automatic rule
dataset_process_rule = getattr(dataset, "latest_process_rule", None)
if not dataset_process_rule:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode="automatic",
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
db.session.add(dataset_process_rule)
db.session.flush()
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
try:
with redis_client.lock(lock_name, timeout=600):
@@ -1661,67 +1604,65 @@ class DocumentService:
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
files = (
db.session.query(UploadFile)
.where(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id.in_(upload_file_list),
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
.all()
)
if len(files) != len(set(upload_file_list)):
raise FileNotExistsError("One or more files not found.")

file_names = [file.name for file in files]
db_documents = (
db.session.query(Document)
.where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == "upload_file",
Document.enabled == True,
Document.name.in_(file_names),
)
.all()
)
documents_map = {document.name: document for document in db_documents}
for file in files:
# raise error if file not found
if not file:
raise FileNotExistsError()

file_name = file.name
data_source_info: dict[str, str | bool] = {
"upload_file_id": file.id,
"upload_file_id": file_id,
}
document = documents_map.get(file.name)
if knowledge_config.duplicate and document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
else:
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file.name,
batch,
# check duplicate
if knowledge_config.duplicate:
document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
)
.first()
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
@@ -2364,7 +2305,6 @@ class DocumentService:
embedding_model_provider=knowledge_config.embedding_model_provider,
collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
is_multimodal=knowledge_config.is_multimodal,
)

db.session.add(dataset)
@@ -2745,13 +2685,6 @@ class SegmentService:
if "content" not in args or not args["content"] or not args["content"].strip():
raise ValueError("Content is empty")

if args.get("attachment_ids"):
if not isinstance(args["attachment_ids"], list):
raise ValueError("Attachment IDs is invalid")
single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT
if len(args["attachment_ids"]) > single_chunk_attachment_limit:
raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}")

@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
@@ -2798,23 +2731,11 @@ class SegmentService:
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]

db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.commit()

if args["attachment_ids"]:
for attachment_id in args["attachment_ids"]:
binding = SegmentAttachmentBinding(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
segment_id=segment_document.id,
attachment_id=attachment_id,
)
db.session.add(binding)
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.commit()

# save vector index
@@ -2978,7 +2899,7 @@ class SegmentService:
document.word_count = max(0, document.word_count + word_count_change)
db.session.add(document)
# update segment index task
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# regenerate child chunks
# get embedding model instance
if dataset.indexing_technique == "high_quality":
@@ -3005,11 +2926,12 @@ class SegmentService:
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
if not processing_rule:
raise ValueError("No processing rule found.")
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
if args.enabled or keyword_changed:
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
@@ -3054,7 +2976,7 @@ class SegmentService:
db.session.add(document)
db.session.add(segment)
db.session.commit()
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# get embedding model instance
if dataset.indexing_technique == "high_quality":
# check embedding model setting
@@ -3080,15 +3002,15 @@ class SegmentService:
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
if not processing_rule:
raise ValueError("No processing rule found.")
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
# update multimodel vector index
VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)

except Exception as e:
logger.exception("update segment index failed")
segment.enabled = False
@@ -3126,9 +3048,7 @@ class SegmentService:
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]

delete_segment_from_index_task.delay(
[segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
)
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)

db.session.delete(segment)
# update document word count
@@ -3177,9 +3097,7 @@ class SegmentService:

# Start async cleanup with both parent and child node IDs
if index_node_ids or child_node_ids:
delete_segment_from_index_task.delay(
index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids
)
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)

if document.word_count is None:
document.word_count = 0
@@ -29,14 +29,8 @@ def get_current_user():
from models.account import Account
from models.model import EndUser

try:
user_object = current_user._get_current_object()
except AttributeError:
# Handle case where current_user might not be a LocalProxy in test environments
user_object = current_user

if not isinstance(user_object, (Account, EndUser)):
raise TypeError(f"current_user must be Account or EndUser, got {type(user_object).__name__}")
if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore
raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}")
return current_user
@@ -124,14 +124,6 @@ class KnowledgeConfig(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
name: str | None = None
is_multimodal: bool = False


class SegmentCreateArgs(BaseModel):
content: str | None = None
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None


class SegmentUpdateArgs(BaseModel):
@@ -140,7 +132,6 @@ class SegmentUpdateArgs(BaseModel):
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
enabled: bool | None = None
attachment_ids: list[str] | None = None


class ChildChunkUpdateArgs(BaseModel):
@@ -1,4 +1,3 @@
import base64
import hashlib
import os
import uuid
@@ -124,15 +123,6 @@ class FileService:

return file_size <= file_size_limit

def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
)
if not upload_file:
raise NotFound("File not found")
blob = storage.load_once(upload_file.key)
return base64.b64encode(blob).decode()

def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
if len(text_name) > 200:
text_name = text_name[:200]
@@ -1,4 +1,3 @@
import json
import logging
import time
from typing import Any
@@ -6,7 +5,6 @@ from typing import Any
from core.app.app_config.entities import ModelConfig
from core.model_runtime.entities import LLMMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -34,7 +32,6 @@ class HitTestingService:
account: Account,
retrieval_model: Any, # FIXME drop this any
external_retrieval_model: dict,
attachment_ids: list | None = None,
limit: int = 10,
):
start = time.perf_counter()
@@ -44,7 +41,7 @@ class HitTestingService:
retrieval_model = dataset.retrieval_model or default_retrieval_model
document_ids_filter = None
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
if metadata_filtering_conditions and query:
if metadata_filtering_conditions:
dataset_retrieval = DatasetRetrieval()

from core.app.app_config.entities import MetadataFilteringCondition
@@ -69,7 +66,6 @@ class HitTestingService:
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
dataset_id=dataset.id,
query=query,
attachment_ids=attachment_ids,
top_k=retrieval_model.get("top_k", 4),
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
@@ -84,24 +80,17 @@ class HitTestingService:

end = time.perf_counter()
logger.debug("Hit testing retrieve in %s seconds", end - start)
dataset_queries = []
if query:
content = {"content_type": QueryType.TEXT_QUERY, "content": query}
dataset_queries.append(content)
if attachment_ids:
for attachment_id in attachment_ids:
content = {"content_type": QueryType.IMAGE_QUERY, "content": attachment_id}
dataset_queries.append(content)
if dataset_queries:
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=json.dumps(dataset_queries),
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)
db.session.add(dataset_query)

dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=query,
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)

db.session.add(dataset_query)
db.session.commit()

return cls.compact_retrieve_response(query, all_documents)
@@ -178,15 +167,10 @@ class HitTestingService:

@classmethod
def hit_testing_args_check(cls, args):
query = args.get("query")
attachment_ids = args.get("attachment_ids")
query = args["query"]

if not attachment_ids and not query:
raise ValueError("Query or attachment_ids is required")
if query and len(query) > 250:
raise ValueError("Query cannot exceed 250 characters")
if attachment_ids and not isinstance(attachment_ids, list):
raise ValueError("Attachment_ids must be a list")
if not query or len(query) > 250:
raise ValueError("Query is required and cannot exceed 250 characters")

@staticmethod
def escape_query_for_search(query: str) -> str:
@@ -70,28 +70,9 @@ class ModelProviderService:
continue

provider_config = provider_configuration.custom_configuration.provider
models = provider_configuration.custom_configuration.models
model_config = provider_configuration.custom_configuration.models
can_added_models = provider_configuration.custom_configuration.can_added_models

# IMPORTANT: Never expose decrypted credentials in the provider list API.
# Sanitize custom model configurations by dropping the credentials payload.
sanitized_model_config = []
if models:
from core.entities.provider_entities import CustomModelConfiguration # local import to avoid cycles

for model in models:
sanitized_model_config.append(
CustomModelConfiguration(
model=model.model,
model_type=model.model_type,
credentials=None, # strip secrets from list view
current_credential_id=model.current_credential_id,
current_credential_name=model.current_credential_name,
available_model_credentials=model.available_model_credentials,
unadded_to_model_list=model.unadded_to_model_list,
)
)

provider_response = ProviderResponse(
tenant_id=tenant_id,
provider=provider_configuration.provider.provider,
@@ -114,7 +95,7 @@ class ModelProviderService:
current_credential_id=getattr(provider_config, "current_credential_id", None),
current_credential_name=getattr(provider_config, "current_credential_name", None),
available_credentials=getattr(provider_config, "available_credentials", []),
custom_models=sanitized_model_config,
custom_models=model_config,
can_added_models=can_added_models,
),
system_configuration=SystemConfigurationResponse(
@@ -4,14 +4,11 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, Document
from core.rag.models.document import Document
from extensions.ext_database import db
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode

@@ -24,10 +21,9 @@ class VectorService:
cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
):
documents: list[Document] = []
multimodal_documents: list[AttachmentDocument] = []

for segment in segments:
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if doc_form == IndexType.PARENT_CHILD_INDEX:
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
if not dataset_document:
logger.warning(
@@ -74,29 +70,12 @@ class VectorService:
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.TEXT,
},
)
documents.append(rag_document)
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_document: AttachmentDocument = AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
multimodal_documents.append(multimodal_document)
index_processor: BaseIndexProcessor = IndexProcessorFactory(doc_form).init_index_processor()

if len(documents) > 0:
index_processor.load(dataset, documents, None, with_keywords=True, keywords_list=keywords_list)
if len(multimodal_documents) > 0:
index_processor.load(dataset, [], multimodal_documents, with_keywords=False)
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)

@classmethod
def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
@@ -151,7 +130,6 @@ class VectorService:
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.TEXT,
},
)
# use full doc mode to generate segment's child chunk
@@ -248,92 +226,3 @@ class VectorService:
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
vector = Vector(dataset=dataset)
vector.delete_by_ids([child_chunk.index_node_id])

@classmethod
def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset):
if dataset.indexing_technique != "high_quality":
return

attachments = segment.attachments
old_attachment_ids = [attachment["id"] for attachment in attachments] if attachments else []

# Check if there's any actual change needed
if set(attachment_ids) == set(old_attachment_ids):
return

try:
vector = Vector(dataset=dataset)
if dataset.is_multimodal:
# Delete old vectors if they exist
if old_attachment_ids:
vector.delete_by_ids(old_attachment_ids)

# Delete existing segment attachment bindings in one operation
db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete(
synchronize_session=False
)

if not attachment_ids:
db.session.commit()
return

# Bulk fetch upload files - only fetch needed fields
upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()

if not upload_file_list:
db.session.commit()
return

# Create a mapping for quick lookup
upload_file_map = {upload_file.id: upload_file for upload_file in upload_file_list}

# Prepare batch operations
bindings = []
documents = []

# Create common metadata base to avoid repetition
base_metadata = {
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
}

# Process attachments in the order specified by attachment_ids
for attachment_id in attachment_ids:
upload_file = upload_file_map.get(attachment_id)
if not upload_file:
logger.warning("Upload file not found for attachment_id: %s", attachment_id)
continue

# Create segment attachment binding
bindings.append(
SegmentAttachmentBinding(
tenant_id=segment.tenant_id,
dataset_id=segment.dataset_id,
document_id=segment.document_id,
segment_id=segment.id,
attachment_id=upload_file.id,
)
)

# Create document for vector indexing
documents.append(
Document(page_content=upload_file.name, metadata={**base_metadata, "doc_id": upload_file.id})
)

# Bulk insert all bindings at once
if bindings:
db.session.add_all(bindings)

# Add documents to vector store if any
if documents and dataset.is_multimodal:
vector.add_texts(documents, duplicate_check=True)

# Single commit for all operations
db.session.commit()

except Exception:
logger.exception("Failed to update multimodal vector for segment %s", segment.id)
db.session.rollback()
raise
@@ -4,10 +4,9 @@ import time
import click
from celery import shared_task

from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@@ -56,7 +55,6 @@ def add_document_to_index_task(dataset_document_id: str):
)

documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@@ -67,7 +65,7 @@ def add_document_to_index_task(dataset_document_id: str):
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -83,25 +81,11 @@ def add_document_to_index_task(dataset_document_id: str):
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)

index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
index_processor.load(dataset, documents)

# delete auto disable log
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
@@ -18,7 +18,6 @@ from models.dataset import (
DatasetQuery,
Document,
DocumentSegment,
SegmentAttachmentBinding,
)
from models.model import UploadFile

@@ -59,20 +58,14 @@ def clean_dataset_task(
)
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id)
).all()

# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexType

doc_form = IndexStructureType.PARAGRAPH_INDEX
doc_form = IndexType.PARAGRAPH_INDEX
logger.info(
click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
)
@@ -97,7 +90,6 @@ def clean_dataset_task(

for document in documents:
db.session.delete(document)
# delete document file

for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
@@ -115,19 +107,6 @@ def clean_dataset_task(
)
db.session.delete(image_file)
db.session.delete(segment)
# delete segment attachments
if attachments_with_bindings:
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(attachment_file)
db.session.delete(binding)

db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
@@ -9,7 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.model import UploadFile

logger = logging.getLogger(__name__)
@@ -36,16 +36,6 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
raise Exception("Document has no dataset")

segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
SegmentAttachmentBinding.dataset_id == dataset_id,
SegmentAttachmentBinding.document_id == document_id,
)
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@@ -79,19 +69,6 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
db.session.delete(file)
db.session.commit()
# delete segment attachments
if attachments_with_bindings:
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(attachment_file)
db.session.delete(binding)

# delete dataset metadata binding
db.session.query(DatasetMetadataBinding).where(
@@ -4,10 +4,9 @@ import time
import click
from celery import shared_task # type: ignore

from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -29,7 +28,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):

if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
@@ -120,7 +119,6 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@@ -131,7 +129,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -147,25 +145,9 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
@@ -1,14 +1,14 @@
import logging
import time
from typing import Literal

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)

@shared_task(queue="dataset")
def deal_dataset_vector_index_task(dataset_id: str, action: str):
def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]):
"""
Async deal dataset from index
:param dataset_id: dataset_id
@@ -32,7 +32,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):

if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
@@ -119,7 +119,6 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@@ -130,7 +129,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -146,25 +145,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
@@ -6,15 +6,14 @@ from celery import shared_task

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, SegmentAttachmentBinding
from models.model import UploadFile
from models.dataset import Dataset, Document

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def delete_segment_from_index_task(
index_node_ids: list, dataset_id: str, document_id: str, segment_ids: list, child_node_ids: list | None = None
index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None
):
"""
Async Remove segment from index
@@ -50,21 +49,6 @@ def delete_segment_from_index_task(
delete_child_chunks=True,
precomputed_child_node_ids=child_node_ids,
)
if dataset.is_multimodal:
# delete segment attachment binding
segment_attachment_bindings = (
db.session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
for binding in segment_attachment_bindings:
db.session.delete(binding)
# delete upload file
db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
db.session.commit()

end_at = time.perf_counter()
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
Some files were not shown because too many files have changed in this diff.