Compare commits

...

74 Commits

Author SHA1 Message Date
-LAN-
2058186f22 chore: Bump version references to 1.11.1 (#29568) 2025-12-12 14:42:25 +08:00
Taka Sasaki
bece2f101c fix: return None from retrieve_tokens when access_token is empty (#29516) 2025-12-12 13:49:11 +08:00
Jyong
04d09c2d77 fix: hit-test failed when attachment id is not exist (#29563) 2025-12-12 13:45:00 +08:00
Jyong
db42f467c8 fix: docx extractor external image failed (#29558) 2025-12-12 13:41:51 +08:00
wangxiaolei
ac40309850 perf: optimize save_document_with_dataset_id (#29550) 2025-12-12 13:14:45 +08:00
Nie Ronghua
12e39365fa perf(core/rag): optimize Excel extractor performance and memory usage (#29551)
Co-authored-by: 01393547 <nieronghua@sf-express.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-12 12:15:03 +08:00
Maries
d48300d08c fix: remove validate=True to fix flask-restx AttributeError (#29552)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-12 12:01:20 +08:00
Pleasure1234
761f8c8043 fix: set response content type with charset in helper (#29534) 2025-12-12 11:50:35 +08:00
Coding On Star
05f63c88c6 feat: integrate Amplitude API key into layout and provider components (#29546)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2025-12-12 11:49:12 +08:00
Stream
8daf9ce98d test(trigger): add container integration tests for trigger (#29527)
Signed-off-by: Stream <Stream_2@qq.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-12 11:31:34 +08:00
Shua Chen
61ee1b9094 fix: truncate auto-populated description to prevent 400-char limit error (#28681) 2025-12-12 11:19:53 +08:00
NFish
87c4b4c576 fix: nextjs security update (#29545) 2025-12-12 11:05:48 +08:00
wangxiaolei
193c8e2362 fix: fix available_credentials is empty (#29521) 2025-12-12 09:51:55 +08:00
NFish
4d57460356 fix: upgrade react and react-dom to 19.2.3,fix cve errors (#29532) 2025-12-12 09:38:32 +08:00
-LAN-
063b39ada5 fix: conversation rename payload validation (#29510) 2025-12-11 18:05:41 +09:00
Joel
6419ce02c7 test: add testcase for config prompt components (#29491)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-11 16:56:20 +08:00
yyh
1a877bb4d0 chore: add .nvmrc for Node 22 alignment (#29495) 2025-12-11 16:29:30 +08:00
-LAN-
281e9d4f51 fix: chat api in explore page reject blank conversation id (#29500) 2025-12-11 16:26:42 +08:00
github-actions[bot]
a195b410d1 chore(i18n): translate i18n files and update type definitions (#29499)
Co-authored-by: iamjoel <2120155+iamjoel@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2025-12-11 16:08:52 +08:00
Joel
91e5db3e83 chore: Advance the timing of the dataset payment prompt (#29497)
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-11 15:49:42 +08:00
Coding On Star
f20a2d1586 chore: add placeholder for Amplitude API key in .env.example (#29489)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2025-12-11 15:21:52 +08:00
wangxiaolei
6e802a343e perf: remove the n+1 query (#29483) 2025-12-11 15:18:27 +08:00
yyh
a30cbe3c95 test: add debug-with-multiple-model spec (#29490) 2025-12-11 15:05:37 +08:00
Coding On Star
7344adf65e feat: add Amplitude API key to Docker entrypoint script (#29477)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2025-12-11 14:44:12 +08:00
非法操作
fcadee9413 fix: flask db downgrade not work (#29465) 2025-12-11 14:30:09 +08:00
Jyong
69a22af1c9 fix: optimize database query when retrieval knowledge in App (#29467) 2025-12-11 13:50:46 +08:00
CrabSAMA
aac6f44562 fix: workflow end node validate error (#29473)
Co-authored-by: Novice <novice12185727@gmail.com>
2025-12-11 13:47:37 +08:00
-LAN-
2e1efd62e1 revert: "fix(ops): add streaming metrics and LLM span for agent-chat traces" (#29469) 2025-12-11 12:53:37 +08:00
crazywoola
1847609926 fix: failed to delete model (#29456) 2025-12-11 12:05:44 +08:00
Hengdong Gong
91f6d25dae fix: knowledge dataset description field validation error #29404 (#29405)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-11 11:17:08 +08:00
yyh
acdbcdb6f8 chore: update packageManager version in package.json to pnpm@10.25.0 (#29407) 2025-12-11 09:51:30 +08:00
dependabot[bot]
a9627ba60a chore(deps): bump types-shapely from 2.0.0.20250404 to 2.1.0.20250917 in /api (#29441)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-11 09:49:19 +08:00
wangxiaolei
266d1c70ac fix: fix custom model credentials display as plaintext (#29425) 2025-12-11 09:48:45 +08:00
wangxiaolei
d152d63e7d chore: update remove_leading_symbols pattern, keep 【 (#29419)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-11 09:47:39 +08:00
-LAN-
b4afc7e435 fix: Can not blank conversation ID validation in chat payloads (#29436)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-11 09:47:10 +08:00
-LAN-
2d496e7e08 ci: enforce semantic pull request titles (#29438)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-12-11 09:45:55 +08:00
AuditAIH
693877e5e4 Fix: Prevent binary content from being stored in process_data for HTTP nodes (#27532)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-11 02:52:40 +08:00
NeatGuyCoding
8cab3e5a1e minor fix: get_tools wrong condition (#27253)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-12-11 02:33:14 +08:00
Jyong
18082752a0 fix knowledge pipeline run multimodal document failed (#29431) 2025-12-10 20:42:51 +08:00
-LAN-
813a734f27 chore: bump dify release to 1.11.0 (#29355) 2025-12-10 19:54:25 +08:00
Wu Tianwei
94244ed8f6 fix: handle potential undefined values in query_attachment_selector across multiple components (#29429) 2025-12-10 19:30:21 +08:00
yyh
ec3a52f012 Fix immediate window open defaults and error handling (#29417)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-10 19:12:14 +08:00
Wu Tianwei
ea063a1139 fix(i18n): remove unused credentialSelector translations from dataset-pipeline files (#29423) 2025-12-10 19:04:34 +08:00
Jyong
784008997b fix parent-child check when child chunk is not exist (#29426) 2025-12-10 18:45:43 +08:00
Coding On Star
0c2a354115 Using SonarJS to analyze components' complexity (#29412)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: 姜涵煦 <hanxujiang@jianghanxudeMacBook-Pro.local>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-10 17:25:54 +08:00
yyh
e477e6c928 fix: harden async window open placeholder logic (#29393) 2025-12-10 16:46:48 +08:00
Wu Tianwei
bafd093fa9 fix: Add dataset file upload restrictions (#29397)
Co-authored-by: kurokobo <kuro664@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2025-12-10 16:41:05 +08:00
Jyong
88b20bc6d0 fix dataset multimodal field not update (#29403) 2025-12-10 15:18:38 +08:00
Coding On Star
12d019cd31 fix: improve compatibility of @headlessui/react with happy-dom by ensuring HTMLElement.prototype.focus is writable (#29399)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-10 14:40:48 +08:00
Jyong
b49e2646ff fix: session unbound during parent-child retrieval (#29396) 2025-12-10 14:08:55 +08:00
github-actions[bot]
e8720de9ad chore(i18n): translate i18n files and update type definitions (#29395)
Co-authored-by: zhsama <33454514+zhsama@users.noreply.github.com>
2025-12-10 13:52:54 +08:00
GuanMu
0867c1800b refactor: simplify plugin task handling and improve UI feedback (#26293) 2025-12-10 13:34:05 +08:00
Coding On Star
681c06186e add @testing-library/user-event and create tests for external-knowledge-base/ (#29323)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-10 12:46:52 +08:00
yyh
f722fdfa6d fix: prevent popup blocker from blocking async window.open (#29391) 2025-12-10 12:46:01 +08:00
wangxiaolei
c033030d8c fix: 'list' object has no attribute 'find' (#29384)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-10 12:45:53 +08:00
Nie Ronghua
51330c0ee6 fix(App.deleted_tools): incorrect compare between UUID and map with string-typed key. (#29340)
Co-authored-by: 01393547 <nieronghua@sf-express.com>
Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
2025-12-10 10:47:45 +08:00
非法操作
7df360a292 fix: workflow log missing trigger icon (#29379) 2025-12-10 10:15:21 +08:00
wangxiaolei
e205182e1f fix: Parent instance <DocumentSegment at 0x7955b5572c90> is not bound… (#29377) 2025-12-10 10:01:45 +08:00
wangxiaolei
4a88c8fd19 chore: set is_multimodal db define default = false (#29362) 2025-12-10 09:44:47 +08:00
znn
1b9165624f adding llm_usage and error_type (#26546) 2025-12-10 09:19:13 +08:00
-LAN-
56f8bdd724 Update refactor issue template and remove tracker (#29357) 2025-12-09 22:03:21 +08:00
Nan LI
efa1b452da feat: Add startup parameters for language-specific Weaviate tokenizer (#29347)
Co-authored-by: Jing <jingguo92@gmail.com>
2025-12-09 21:00:19 +08:00
-LAN-
bcbc07e99c Add MCP backend codeowners (#29354) 2025-12-09 20:45:57 +08:00
Joel
d79d0a47a7 chore: not set empty tool config to default value (#29338) 2025-12-09 17:14:04 +08:00
Joel
f5d676f3f1 fix: agent app add tool hasn't add default params config (#29330) 2025-12-09 16:17:27 +08:00
非法操作
8f7173b69b fix: admin dislike feedback lose content (#29327) 2025-12-09 16:07:59 +08:00
github-actions[bot]
8275533418 chore(i18n): translate i18n files and update type definitions (#29329)
Co-authored-by: iamjoel <2120155+iamjoel@users.noreply.github.com>
2025-12-09 15:57:35 +08:00
yyh
c1c1fd0509 feat: make billing management entry prominent and enable current plan portal (#29321) 2025-12-09 15:43:51 +08:00
wangxiaolei
c24835ca87 chore: update the error message (#29325) 2025-12-09 15:29:04 +08:00
yyh
022cfbd186 refactor: remove isMobile prop from Chat and TryToAsk components (#29319) 2025-12-09 15:11:05 +08:00
Jyong
9affc546c6 Feat/support multimodal embedding (#29115)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-09 14:41:46 +08:00
wangxiaolei
77cf8f6c27 chore: remove python sdk from dify repo (#29318) 2025-12-09 14:13:14 +08:00
yyh
18601d8b38 Refactor datasets service toward TanStack Query (#29008)
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
2025-12-09 13:48:23 +08:00
kurokobo
57d244de69 feat: introduce init container to automatically fix storage permissions (#29297)
Co-authored-by: 朱通通 <zhutong66@163.com>
2025-12-09 13:48:23 +08:00
306 changed files with 11257 additions and 10868 deletions

8 .github/CODEOWNERS vendored
View File

@@ -9,6 +9,14 @@
# Backend (default owner, more specific rules below will override)
api/ @QuantumGhost
# Backend - MCP
api/core/mcp/ @Nov1c444
api/core/entities/mcp_provider.py @Nov1c444
api/services/tools/mcp_tools_manage_service.py @Nov1c444
api/controllers/mcp/ @Nov1c444
api/controllers/console/app/mcp_server.py @Nov1c444
api/tests/**/*mcp* @Nov1c444
# Backend - Workflow - Engine (Core graph execution engine)
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
api/core/workflow/runtime/ @laipz8200 @QuantumGhost

View File

@@ -1,8 +1,6 @@
name: "✨ Refactor"
description: Refactor existing code for improved readability and maintainability.
title: "[Chore/Refactor] "
labels:
- refactor
name: "✨ Refactor or Chore"
description: Refactor existing code or perform maintenance chores to improve readability and reliability.
title: "[Refactor/Chore] "
body:
- type: checkboxes
attributes:
@@ -11,7 +9,7 @@ body:
options:
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true
- label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
- label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
required: true
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
@@ -25,14 +23,14 @@ body:
id: description
attributes:
label: Description
placeholder: "Describe the refactor you are proposing."
placeholder: "Describe the refactor or chore you are proposing."
validations:
required: true
- type: textarea
id: motivation
attributes:
label: Motivation
placeholder: "Explain why this refactor is necessary."
placeholder: "Explain why this refactor or chore is necessary."
validations:
required: false
- type: textarea

View File

@@ -1,13 +0,0 @@
name: "👾 Tracker"
description: For inner usages, please do not use this template.
title: "[Tracker] "
labels:
- tracker
body:
- type: textarea
id: content
attributes:
label: Blockers
placeholder: "- [ ] ..."
validations:
required: true

View File

@@ -0,0 +1,21 @@
name: Semantic Pull Request
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
jobs:
lint:
name: Validate PR title
permissions:
pull-requests: read
runs-on: ubuntu-latest
steps:
- name: Check title
uses: amannn/action-semantic-pull-request@v6.1.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
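The new workflow delegates title checking to `amannn/action-semantic-pull-request`. As a rough sketch of what "semantic" means here — assuming the action's default Conventional Commits types, since the workflow above does not configure any — a title check boils down to:

```python
import re

# Conventional Commits types accepted by amannn/action-semantic-pull-request
# by default (assumption: the workflow above does not override them).
_TYPES = ("build", "chore", "ci", "docs", "feat", "fix", "perf",
          "refactor", "revert", "style", "test")
_PATTERN = re.compile(r"^(?:%s)(?:\([\w./-]+\))?!?: \S.*" % "|".join(_TYPES))

def is_semantic_title(title: str) -> bool:
    """Return True when a PR title matches `type(scope)?: subject`."""
    return _PATTERN.match(title) is not None
```

Against the commit list above, titles like `perf(core/rag): optimize Excel extractor performance` pass, while `Fix: Prevent binary content...` (capitalized type) or `adding llm_usage and error_type` (no type prefix) would be rejected once this check gates merges.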

1 .nvmrc Normal file
View File

@@ -0,0 +1 @@
22.11.0

View File

@@ -654,3 +654,9 @@ TENANT_ISOLATED_TASK_CONCURRENCY=1
# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0
# Multimodal knowledgebase limit
SINGLE_CHUNK_ATTACHMENT_LIMIT=10
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
IMAGE_FILE_BATCH_LIMIT=10
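These `.env` entries back the new `FileUploadConfig` fields shown later in the diff, where pydantic-settings applies the defaults and the `PositiveInt`/`NonNegativeInt` sign constraints. A stdlib-only approximation of that behaviour (the helper name is hypothetical, not Dify code):

```python
import os

def read_limit(name: str, default: int, *, positive: bool = False) -> int:
    """Read an integer limit from the environment, falling back to a default.

    Approximates the PositiveInt / NonNegativeInt constraints on the
    FileUploadConfig fields (positive=True additionally rejects zero).
    """
    raw = os.environ.get(name)
    value = default if raw is None else int(raw)
    floor = 1 if positive else 0
    if value < floor:
        raise ValueError(f"{name} must be >= {floor}, got {value}")
    return value

# Defaults mirror the diff: 10 attachments per chunk, 60 s download timeout.
single_chunk_attachment_limit = read_limit("SINGLE_CHUNK_ATTACHMENT_LIMIT", 10, positive=True)
attachment_image_download_timeout = read_limit("ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT", 60)
```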

View File

@@ -360,6 +360,26 @@ class FileUploadConfig(BaseSettings):
default=10,
)
IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a image batch upload operation",
default=10,
)
SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a single chunk attachment",
default=10,
)
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed image file size for attachments in megabytes",
default=2,
)
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
description="Timeout for downloading image attachments in seconds",
default=60,
)
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "

View File

@@ -61,6 +61,7 @@ class ChatMessagesQuery(BaseModel):
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
@field_validator("message_id")
@classmethod
@@ -324,6 +325,7 @@ class MessageFeedbackApi(Resource):
db.session.delete(feedback)
elif args.rating and feedback:
feedback.rating = args.rating
feedback.content = args.content
elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
@@ -335,6 +337,7 @@ class MessageFeedbackApi(Resource):
conversation_id=message.conversation_id,
message_id=message.id,
rating=rating_value,
content=args.content,
from_source="admin",
from_account_id=current_user.id,
)
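The feedback endpoint now persists `content` in both the update and the create branch (the "admin dislike feedback lose content" fix, #29327). Its branch logic, reduced to a plain function for illustration — a sketch, not the actual ORM code; the guard on the delete branch is cut off in the hunk and assumed to be "feedback exists, rating cleared":

```python
def apply_feedback(existing_feedback, rating, content=None):
    """Mirror MessageFeedbackApi's branches: delete, update, create, or reject.

    Returns an (action, payload) pair instead of touching a database session.
    """
    if existing_feedback and not rating:
        # Clearing the rating on an existing feedback deletes it.
        return ("delete", None)
    if existing_feedback and rating:
        # The fix: content is now carried along on update, not dropped.
        return ("update", {"rating": rating, "content": content})
    if not existing_feedback and not rating:
        raise ValueError("rating cannot be None when feedback not exists")
    return ("create", {"rating": rating, "content": content})
```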

View File

@@ -114,7 +114,7 @@ class AppTriggersApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
@console_ns.expect(console_ns.models[ParserEnable.__name__])
@setup_required
@login_required
@account_initialization_required

View File

@@ -151,6 +151,7 @@ class DatasetUpdatePayload(BaseModel):
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
icon_info: dict[str, Any] | None = None
is_multimodal: bool | None = False
@field_validator("indexing_technique")
@classmethod
@@ -421,19 +422,18 @@ class DatasetApi(Resource):
raise NotFound("Dataset not found.")
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
payload_data = payload.model_dump(exclude_unset=True)
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
payload.indexing_technique == "high_quality"
and payload.embedding_model_provider is not None
and payload.embedding_model is not None
):
DatasetService.check_embedding_model_setting(
is_multimodal = DatasetService.check_is_multimodal_model(
dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
)
payload.is_multimodal = is_multimodal
payload_data = payload.model_dump(exclude_unset=True)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, payload.permission, payload.partial_member_list

View File

@@ -424,6 +424,10 @@ class DatasetInitApi(Resource):
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_config.embedding_model,
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
)
knowledge_config.is_multimodal = is_multimodal
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."

View File

@@ -51,6 +51,7 @@ class SegmentCreatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None
class SegmentUpdatePayload(BaseModel):
@@ -58,6 +59,7 @@ class SegmentUpdatePayload(BaseModel):
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
class BatchImportPayload(BaseModel):

View File

@@ -1,7 +1,7 @@
import logging
from typing import Any
from flask_restx import marshal
from flask_restx import marshal, reqparse
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -33,6 +33,7 @@ class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
external_retrieval_model: dict[str, Any] | None = None
attachment_ids: list[str] | None = None
class DatasetsHitTestingBase:
@@ -54,16 +55,28 @@ class DatasetsHitTestingBase:
def hit_testing_args_check(args: dict[str, Any]):
HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args():
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, required=False, location="json")
.add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
return parser.parse_args()
@staticmethod
def perform_hit_testing(dataset, args):
assert isinstance(current_user, Account)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
query=args.get("query"),
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
retrieval_model=args.get("retrieval_model"),
external_retrieval_model=args.get("external_retrieval_model"),
attachment_ids=args.get("attachment_ids"),
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
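Note the switch from `args["query"]` to `args.get("query")`: with `attachment_ids` added as an alternative input, `query` becomes optional, and bracket access on a missing key would raise `KeyError` before the service is even called. In plain dict terms (field names match the diff; the function itself is illustrative):

```python
def extract_hit_testing_args(payload: dict) -> dict:
    """Pull optional hit-testing inputs without KeyError on missing keys."""
    return {
        "query": payload.get("query"),
        "attachment_ids": payload.get("attachment_ids"),
        "retrieval_model": payload.get("retrieval_model"),
        "external_retrieval_model": payload.get("external_retrieval_model"),
    }
```

An attachment-only request (`{"attachment_ids": [...]}`) now yields `query=None` instead of a 500 from an unhandled `KeyError`.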

View File

@@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource):
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
@console_ns.expect(console_ns.models[Parser.__name__])
@setup_required
@login_required
@account_initialization_required

View File

@@ -2,7 +2,7 @@ import logging
from typing import Any, Literal
from uuid import UUID
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
import services
@@ -52,10 +52,24 @@ class ChatMessagePayload(BaseModel):
inputs: dict[str, Any]
query: str
files: list[dict[str, Any]] | None = None
conversation_id: UUID | None = None
parent_message_id: UUID | None = None
conversation_id: str | None = None
parent_message_id: str | None = None
retriever_from: str = Field(default="explore_app")
@field_validator("conversation_id", "parent_message_id", mode="before")
@classmethod
def normalize_uuid(cls, value: str | UUID | None) -> str | None:
"""
Accept blank IDs and validate UUID format when provided.
"""
if not value:
return None
try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("must be a valid UUID") from exc
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
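The before-mode validator widens the fields from `UUID` to `str | None` so that blank strings from older clients normalize to `None` instead of failing Pydantic's strict UUID parsing. Its core logic, restated with only the stdlib (`helper.uuid_value` is assumed to canonicalize roughly like `uuid.UUID` does):

```python
import uuid

def normalize_uuid(value):
    """Blank or missing IDs become None; anything else must parse as a UUID."""
    if not value:
        return None
    try:
        # Stand-in for helper.uuid_value from the diff.
        return str(uuid.UUID(str(value)))
    except ValueError as exc:
        raise ValueError("must be a valid UUID") from exc
```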

View File

@@ -3,7 +3,7 @@ from uuid import UUID
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -30,9 +30,16 @@ class ConversationListQuery(BaseModel):
class ConversationRenamePayload(BaseModel):
name: str
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
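The after-mode validator makes `name` conditionally required: it may be omitted only when `auto_generate` is true. The same rule as a standalone check (a sketch of the validator's logic, not the Pydantic model itself):

```python
def validate_rename(name, auto_generate=False):
    """Reject a missing or blank name unless auto-generation was requested."""
    if not auto_generate and (name is None or not name.strip()):
        raise ValueError("name is required when auto_generate is false")
    return name
```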

View File

@@ -45,6 +45,9 @@ class FileApi(Resource):
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
"image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
"single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
"attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
}, 200
@setup_required

View File

@@ -230,7 +230,7 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 200
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True)
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -282,9 +282,10 @@ class ModelProviderModelCredentialApi(Resource):
tenant_id=tenant_id, provider_name=provider
)
else:
model_type = args.model_type
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
)
return jsonable_encoder(

View File

@@ -4,7 +4,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
@@ -52,11 +52,23 @@ class ChatRequestPayload(BaseModel):
query: str
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: UUID | None = None
conversation_id: str | None = Field(default=None, description="Conversation UUID")
retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
@field_validator("conversation_id", mode="before")
@classmethod
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
if not value:
return None
try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("conversation_id must be a valid UUID") from exc
register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)

View File

@@ -4,7 +4,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from flask_restx._http import HTTPStatus
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
@@ -37,9 +37,16 @@ class ConversationListQuery(BaseModel):
class ConversationRenamePayload(BaseModel):
name: str = Field(description="New conversation name")
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")

View File

@@ -62,8 +62,7 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@@ -73,7 +72,7 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow, WorkflowNodeExecutionModel
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -581,7 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
with self._database_session() as session:
# Save message
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
self._save_message(session=session, graph_runtime_state=resolved_state)
yield workflow_finish_resp
elif event.stopped_by in (
@@ -591,7 +590,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# When hitting input-moderation or annotation-reply, the workflow will not start
with self._database_session() as session:
# Save message
self._save_message(session=session, trace_manager=trace_manager)
self._save_message(session=session)
yield self._message_end_to_stream_response()
@@ -600,7 +599,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
event: QueueAdvancedChatMessageEndEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
@@ -618,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
self._save_message(session=session, graph_runtime_state=resolved_state)
yield self._message_end_to_stream_response()
@@ -772,13 +770,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if self._conversation_name_generate_thread:
logger.debug("Conversation name generation running as daemon thread")
def _save_message(
self,
*,
session: Session,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
):
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer
@@ -817,14 +809,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
metadata = self._task_state.metadata.model_dump()
message.message_metadata = json.dumps(jsonable_encoder(metadata))
# Extract model provider and model_id from workflow node executions for tracing
if message.workflow_run_id:
model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
if model_info:
message.model_provider = model_info.get("provider")
message.model_id = model_info.get("model")
message_files = [
MessageFile(
message_id=message.id,
@@ -842,68 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)
# Trigger MESSAGE_TRACE for tracing integrations
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
)
)
def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
"""
Extract model provider and model_id from workflow node executions.
Returns dict with 'provider' and 'model' keys, or None if not found.
"""
try:
# Query workflow node executions for LLM or Agent nodes
stmt = (
select(WorkflowNodeExecutionModel)
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
.order_by(WorkflowNodeExecutionModel.created_at.desc())
.limit(1)
)
node_execution = session.scalar(stmt)
if not node_execution:
return None
# Try to extract from execution_metadata for agent nodes
if node_execution.execution_metadata:
try:
metadata = json.loads(node_execution.execution_metadata)
agent_log = metadata.get("agent_log", [])
# Look for the first agent thought with provider info
for log_entry in agent_log:
entry_metadata = log_entry.get("metadata", {})
provider_str = entry_metadata.get("provider")
if provider_str:
# Parse format like "langgenius/deepseek/deepseek"
parts = provider_str.split("/")
if len(parts) >= 3:
return {"provider": parts[1], "model": parts[2]}
elif len(parts) == 2:
return {"provider": parts[0], "model": parts[1]}
except (json.JSONDecodeError, KeyError, AttributeError) as e:
logger.debug("Failed to parse execution_metadata: %s", e)
# Try to extract from process_data for llm nodes
if node_execution.process_data:
try:
process_data = json.loads(node_execution.process_data)
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider and model:
return {"provider": provider, "model": model}
except (json.JSONDecodeError, KeyError) as e:
logger.debug("Failed to parse process_data: %s", e)
return None
except Exception as e:
logger.warning("Failed to extract model info from workflow: %s", e)
return None
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state


@@ -83,6 +83,7 @@ class AppRunner:
context: str | None = None,
memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
"""
Organize prompt messages
@@ -111,6 +112,7 @@ class AppRunner:
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
context_files=context_files,
)
else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))


@@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
@@ -146,6 +147,7 @@ class ChatAppRunner(AppRunner):
# get context from datasets
context = None
context_files: list[File] = []
if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
@@ -156,7 +158,7 @@ class ChatAppRunner(AppRunner):
)
dataset_retrieval = DatasetRetrieval(application_generate_entity)
context = dataset_retrieval.retrieve(
context, retrieved_files = dataset_retrieval.retrieve(
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
@@ -171,7 +173,11 @@ class ChatAppRunner(AppRunner):
memory=memory,
message_id=message.id,
inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
),
)
context_files = retrieved_files or []
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@@ -186,6 +192,7 @@ class ChatAppRunner(AppRunner):
context=context,
memory=memory,
image_detail_config=image_detail_config,
context_files=context_files,
)
# check hosting moderation


@@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
CompletionAppGenerateEntity,
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.moderation.base import ModerationError
@@ -102,6 +103,7 @@ class CompletionAppRunner(AppRunner):
# get context from datasets
context = None
context_files: list[File] = []
if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
@@ -116,7 +118,7 @@ class CompletionAppRunner(AppRunner):
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
dataset_retrieval = DatasetRetrieval(application_generate_entity)
context = dataset_retrieval.retrieve(
context, retrieved_files = dataset_retrieval.retrieve(
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
@@ -130,7 +132,11 @@ class CompletionAppRunner(AppRunner):
hit_callback=hit_callback,
message_id=message.id,
inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
),
)
context_files = retrieved_files or []
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@@ -144,6 +150,7 @@ class CompletionAppRunner(AppRunner):
query=query,
context=context,
image_detail_config=image_detail_config,
context_files=context_files,
)
# check hosting moderation


@@ -40,9 +40,6 @@ class EasyUITaskState(TaskState):
"""
llm_result: LLMResult
first_token_time: float | None = None
last_token_time: float | None = None
is_streaming_response: bool = False
class WorkflowTaskState(TaskState):


@@ -332,12 +332,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
# Track streaming response times
if self._task_state.first_token_time is None:
self._task_state.first_token_time = time.perf_counter()
self._task_state.is_streaming_response = True
self._task_state.last_token_time = time.perf_counter()
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
@@ -404,18 +398,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency
# Add streaming metrics to usage if available
if self._task_state.is_streaming_response and self._task_state.first_token_time:
start_time = self.start_at
first_token_time = self._task_state.first_token_time
last_token_time = self._task_state.last_token_time or first_token_time
usage.time_to_first_token = round(first_token_time - start_time, 3)
usage.time_to_generate = round(last_token_time - first_token_time, 3)
# Update metadata with the complete usage info
self._task_state.metadata.usage = usage
message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager:


@@ -7,7 +7,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
@@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
document_id,
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,


@@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field
class PreviewDetail(BaseModel):
@@ -20,7 +20,7 @@ class IndexingEstimate(BaseModel):
class PipelineDataset(BaseModel):
id: str
name: str
description: str
description: str | None = Field(default="", description="knowledge dataset description")
chunk_structure: str


@@ -213,12 +213,23 @@ class MCPProviderEntity(BaseModel):
return None
def retrieve_tokens(self) -> OAuthTokens | None:
"""OAuth tokens if available"""
"""Retrieve OAuth tokens if authentication is complete.
Returns:
OAuthTokens if the provider has been authenticated, None otherwise.
"""
if not self.credentials:
return None
credentials = self.decrypt_credentials()
access_token = credentials.get("access_token", "")
# Return None if access_token is empty to avoid generating invalid "Authorization: Bearer " header.
# Note: We don't check for whitespace-only strings here because:
# 1. OAuth servers don't return whitespace-only access tokens in practice
# 2. Even if they did, the server would return 401, triggering the OAuth flow correctly
if not access_token:
return None
return OAuthTokens(
access_token=credentials.get("access_token", ""),
access_token=access_token,
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
refresh_token=credentials.get("refresh_token", ""),


@@ -7,7 +7,7 @@ import time
import uuid
from typing import Any
from flask import current_app
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError
@@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
@@ -36,6 +36,7 @@ from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from libs.datetime_utils import naive_utc_now
from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
@@ -89,8 +90,17 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
if not current_user:
raise ValueError("No current user found")
current_user.set_tenant_id(dataset.tenant_id)
documents = self._transform(
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
index_processor,
dataset,
text_docs,
requeried_document.doc_language,
processing_rule.to_dict(),
current_user=current_user,
)
# save segment
self._load_segments(dataset, requeried_document, documents)
@@ -136,7 +146,7 @@ class IndexingRunner:
for document_segment in document_segments:
db.session.delete(document_segment)
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
@@ -152,8 +162,17 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
if not current_user:
raise ValueError("No current user found")
current_user.set_tenant_id(dataset.tenant_id)
documents = self._transform(
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
index_processor,
dataset,
text_docs,
requeried_document.doc_language,
processing_rule.to_dict(),
current_user=current_user,
)
# save segment
self._load_segments(dataset, requeried_document, documents)
@@ -209,7 +228,7 @@ class IndexingRunner:
"dataset_id": document_segment.dataset_id,
},
)
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = document_segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -302,6 +321,7 @@ class IndexingRunner:
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
documents = index_processor.transform(
text_docs,
current_user=None,
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule.to_dict(),
tenant_id=tenant_id,
@@ -551,7 +571,10 @@ class IndexingRunner:
indexing_start_at = time.perf_counter()
tokens = 0
create_keyword_thread = None
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
if (
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
):
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
@@ -590,7 +613,7 @@ class IndexingRunner:
for future in futures:
tokens += future.result()
if (
dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
and create_keyword_thread is not None
):
@@ -635,7 +658,13 @@ class IndexingRunner:
db.session.commit()
def _process_chunk(
self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance
self,
flask_app: Flask,
index_processor: BaseIndexProcessor,
chunk_documents: list[Document],
dataset: Dataset,
dataset_document: DatasetDocument,
embedding_model_instance: ModelInstance | None,
):
with flask_app.app_context():
# check document is paused
@@ -646,8 +675,15 @@ class IndexingRunner:
page_content_list = [document.page_content for document in chunk_documents]
tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))
multimodal_documents = []
for document in chunk_documents:
if document.attachments and dataset.is_multimodal:
multimodal_documents.extend(document.attachments)
# load index
index_processor.load(dataset, chunk_documents, with_keywords=False)
index_processor.load(
dataset, chunk_documents, multimodal_documents=multimodal_documents, with_keywords=False
)
document_ids = [document.metadata["doc_id"] for document in chunk_documents]
db.session.query(DocumentSegment).where(
@@ -710,6 +746,7 @@ class IndexingRunner:
text_docs: list[Document],
doc_language: str,
process_rule: dict,
current_user: Account | None = None,
) -> list[Document]:
# get embedding model instance
embedding_model_instance = None
@@ -729,6 +766,7 @@ class IndexingRunner:
documents = index_processor.transform(
text_docs,
current_user,
embedding_model_instance=embedding_model_instance,
process_rule=process_rule,
tenant_id=dataset.tenant_id,
@@ -737,14 +775,16 @@ class IndexingRunner:
return documents
def _load_segments(self, dataset, dataset_document, documents):
def _load_segments(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]):
# save node to document segment
doc_store = DatasetDocumentStore(
dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
)
# add document segments
doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)
doc_store.add_documents(
docs=documents, save_child=dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX
)
# update document status to indexing
cur_time = naive_utc_now()


@@ -554,11 +554,16 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_raw = cast(str, response.message.content)
generated_raw = response.message.get_text_content()
first_brace = generated_raw.find("{")
last_brace = generated_raw.rfind("}")
return {**json.loads(generated_raw[first_brace : last_brace + 1])}
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
json_str = generated_raw[first_brace : last_brace + 1]
data = json_repair.loads(json_str)
if not isinstance(data, dict):
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
return data
except InvokeError as e:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}
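The brace-extraction logic above can be sketched with the standard library (the diff itself uses `json_repair`, which additionally tolerates malformed JSON); `extract_json_object` is a hypothetical name for illustration.

```python
import json


def extract_json_object(raw: str) -> dict:
    """Pull the outermost {...} object out of an LLM response that may
    contain surrounding prose."""
    first_brace = raw.find("{")
    last_brace = raw.rfind("}")
    if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
        raise ValueError(f"Could not find a valid JSON object in response: {raw}")
    data = json.loads(raw[first_brace : last_brace + 1])
    if not isinstance(data, dict):
        # Defensive check mirroring the diff: the slice starts at "{", but
        # verify the parsed value really is an object before returning it.
        raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
    return data
```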


@@ -10,9 +10,9 @@ from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
@@ -200,7 +200,7 @@ class ModelInstance:
def invoke_text_embedding(
self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
) -> TextEmbeddingResult:
) -> EmbeddingResult:
"""
Invoke text embedding model
@@ -212,7 +212,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
TextEmbeddingResult,
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
@@ -223,6 +223,34 @@ class ModelInstance:
),
)
def invoke_multimodal_embedding(
self,
multimodel_documents: list[dict],
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> EmbeddingResult:
"""
Invoke multimodal embedding model
:param multimodel_documents: multimodal documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
user=user,
input_type=input_type,
),
)
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
"""
Get number of tokens for text embedding
@@ -276,6 +304,40 @@ class ModelInstance:
),
)
def invoke_multimodal_rerank(
self,
query: dict,
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke multimodal rerank model
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)
def invoke_moderation(self, text: str, user: str | None = None) -> bool:
"""
Invoke moderation model
@@ -461,6 +523,32 @@ class ModelManager:
model=default_model_entity.model,
)
def check_model_support_vision(self, tenant_id: str, provider: str, model: str, model_type: ModelType) -> bool:
"""
Check if model supports vision
:param tenant_id: tenant id
:param provider: provider name
:param model: model name
:param model_type: model type of the model to check
:return: True if model supports vision, False otherwise
"""
model_instance = self.get_model_instance(tenant_id, provider, model_type, model)
model_type_instance = model_instance.model_type_instance
match model_type:
case ModelType.LLM:
model_type_instance = cast(LargeLanguageModel, model_type_instance)
case ModelType.TEXT_EMBEDDING:
model_type_instance = cast(TextEmbeddingModel, model_type_instance)
case ModelType.RERANK:
model_type_instance = cast(RerankModel, model_type_instance)
case _:
raise ValueError(f"Model type {model_type} is not supported")
model_schema = model_type_instance.get_model_schema(model, model_instance.credentials)
if not model_schema:
return False
if model_schema.features and ModelFeature.VISION in model_schema.features:
return True
return False
class LBModelManager:
def __init__(


@@ -19,7 +19,7 @@ class EmbeddingUsage(ModelUsage):
latency: float
class TextEmbeddingResult(BaseModel):
class EmbeddingResult(BaseModel):
"""
Model class for embedding result.
"""
@@ -27,3 +27,13 @@ class TextEmbeddingResult(BaseModel):
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage
class FileEmbeddingResult(BaseModel):
"""
Model class for file embedding result.
"""
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage


@@ -50,3 +50,43 @@ class RerankModel(AIModel):
)
except Exception as e:
raise self._transform_invoke_error(e)
def invoke_multimodal_rerank(
self,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke multimodal rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_multimodal_rerank(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
except Exception as e:
raise self._transform_invoke_error(e)


@@ -2,7 +2,7 @@ from pydantic import ConfigDict
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
@@ -20,16 +20,18 @@ class TextEmbeddingModel(AIModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: list[str] | None = None,
multimodel_documents: list[dict] | None = None,
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
) -> EmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param multimodel_documents: multimodal documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
@@ -38,16 +40,29 @@ class TextEmbeddingModel(AIModel):
try:
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
if texts:
return plugin_model_manager.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
if multimodel_documents:
return plugin_model_manager.invoke_multimodal_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
documents=multimodel_documents,
input_type=input_type,
)
raise ValueError("No texts or multimodal documents provided")
except Exception as e:
raise self._transform_invoke_error(e)


@@ -222,59 +222,6 @@ class TencentSpanBuilder:
links=links,
)
@staticmethod
def build_message_llm_span(
trace_info: MessageTraceInfo, trace_id: int, parent_span_id: int, user_id: str
) -> SpanData:
"""Build LLM span for message traces with detailed LLM attributes."""
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
# Extract model information from `metadata`` or `message_data`
trace_metadata = trace_info.metadata or {}
message_data = trace_info.message_data or {}
model_provider = trace_metadata.get("ls_provider") or (
message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
)
model_name = trace_metadata.get("ls_model_name") or (
message_data.get("model_id", "") if isinstance(message_data, dict) else ""
)
inputs_str = str(trace_info.inputs or "")
outputs_str = str(trace_info.outputs or "")
attributes = {
GEN_AI_SESSION_ID: trace_metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: str(model_name),
GEN_AI_PROVIDER: str(model_provider),
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens or 0),
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens or 0),
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens or 0),
GEN_AI_PROMPT: inputs_str,
GEN_AI_COMPLETION: outputs_str,
INPUT_VALUE: inputs_str,
OUTPUT_VALUE: outputs_str,
}
if trace_info.is_streaming_request:
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "llm"),
name="GENERATION",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes=attributes,
status=status,
)
@staticmethod
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build tool span."""


@@ -107,12 +107,8 @@ class TencentDataTrace(BaseTraceInstance):
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
self.trace_client.add_span(message_span)
# Add LLM child span with detailed attributes
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
llm_span = TencentSpanBuilder.build_message_llm_span(trace_info, trace_id, parent_span_id, str(user_id))
self.trace_client.add_span(llm_span)
self.trace_client.add_span(message_span)
self._record_message_llm_metrics(trace_info)


@@ -6,7 +6,7 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
@@ -243,14 +243,14 @@ class PluginModelClient(BasePluginClient):
credentials: dict,
texts: list[str],
input_type: str,
) -> TextEmbeddingResult:
) -> EmbeddingResult:
"""
Invoke text embedding
"""
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
type_=TextEmbeddingResult,
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
@@ -275,6 +275,48 @@ class PluginModelClient(BasePluginClient):
raise ValueError("Failed to invoke text embedding")
def invoke_multimodal_embedding(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
documents: list[dict],
input_type: str,
) -> EmbeddingResult:
"""
Invoke multimodal embedding
"""
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
"provider": provider,
"model_type": "text-embedding",
"model": model,
"credentials": credentials,
"documents": documents,
"input_type": input_type,
},
}
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("Failed to invoke multimodal embedding")
def get_text_embedding_num_tokens(
self,
tenant_id: str,
@@ -361,6 +403,51 @@ class PluginModelClient(BasePluginClient):
raise ValueError("Failed to invoke rerank")
def invoke_multimodal_rerank(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
) -> RerankResult:
"""
Invoke multimodal rerank
"""
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
"provider": provider,
"model_type": "rerank",
"model": model,
"credentials": credentials,
"query": query,
"docs": docs,
"score_threshold": score_threshold,
"top_n": top_n,
},
}
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("Failed to invoke multimodal rerank")
def invoke_tts(
self,
tenant_id: str,


@@ -49,6 +49,7 @@ class SimplePromptTransform(PromptTransform):
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
inputs = {key: str(value) for key, value in inputs.items()}
@@ -64,6 +65,7 @@ class SimplePromptTransform(PromptTransform):
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
context_files=context_files,
)
else:
prompt_messages, stops = self._get_completion_model_prompt_messages(
@@ -76,6 +78,7 @@ class SimplePromptTransform(PromptTransform):
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
context_files=context_files,
)
return prompt_messages, stops
@@ -187,6 +190,7 @@ class SimplePromptTransform(PromptTransform):
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
prompt_messages: list[PromptMessage] = []
@@ -216,9 +220,9 @@ class SimplePromptTransform(PromptTransform):
)
if query:
prompt_messages.append(self._get_last_user_message(query, files, image_detail_config))
prompt_messages.append(self._get_last_user_message(query, files, image_detail_config, context_files))
else:
prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config))
prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config, context_files))
return prompt_messages, None
@@ -233,6 +237,7 @@ class SimplePromptTransform(PromptTransform):
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
# get prompt
prompt, prompt_rules = self._get_prompt_str_and_rules(
@@ -275,20 +280,27 @@ class SimplePromptTransform(PromptTransform):
if stops is not None and len(stops) == 0:
stops = None
return [self._get_last_user_message(prompt, files, image_detail_config)], stops
return [self._get_last_user_message(prompt, files, image_detail_config, context_files)], stops
def _get_last_user_message(
self,
prompt: str,
files: Sequence["File"],
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> UserPromptMessage:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
if files:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
if context_files:
for file in context_files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
if prompt_message_contents:
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
prompt_message = UserPromptMessage(content=prompt_message_contents)

View File

@@ -2,6 +2,7 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
from core.rag.rerank.rerank_base import BaseRerankRunner
@@ -30,9 +31,10 @@ class DataPostProcessor:
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
if self.rerank_runner:
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type)
if self.reorder_runner:
documents = self.reorder_runner.run(documents)

View File

@@ -1,23 +1,30 @@
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
@@ -37,14 +44,15 @@ class RetrievalService:
retrieval_method: RetrievalMethod,
dataset_id: str,
query: str,
top_k: int,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
document_ids_filter: list[str] | None = None,
attachment_ids: list[str] | None = None,
):
if not query:
if not query and not attachment_ids:
return []
dataset = cls._get_dataset(dataset_id)
if not dataset:
@@ -56,69 +64,52 @@ class RetrievalService:
# Run retrieval tasks concurrently in a thread pool
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
futures = []
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
retrieval_service = RetrievalService()
if query:
futures.append(
executor.submit(
cls.keyword_search,
retrieval_service._retrieve,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
all_documents=all_documents,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_semantic_search(retrieval_method):
futures.append(
executor.submit(
cls.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
retrieval_method=retrieval_method,
dataset=dataset,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
reranking_mode=reranking_mode,
weights=weights,
document_ids_filter=document_ids_filter,
attachment_id=None,
all_documents=all_documents,
exceptions=exceptions,
)
)
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
futures.append(
executor.submit(
cls.full_text_index_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
if attachment_ids:
for attachment_id in attachment_ids:
futures.append(
executor.submit(
retrieval_service._retrieve,
flask_app=current_app._get_current_object(), # type: ignore
retrieval_method=retrieval_method,
dataset=dataset,
query=None,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
reranking_mode=reranking_mode,
weights=weights,
document_ids_filter=document_ids_filter,
attachment_id=attachment_id,
all_documents=all_documents,
exceptions=exceptions,
)
)
)
concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
if exceptions:
raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
all_documents = cls._deduplicate_documents(all_documents)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k,
)
return all_documents
@classmethod
@@ -223,6 +214,7 @@ class RetrievalService:
retrieval_method: RetrievalMethod,
exceptions: list,
document_ids_filter: list[str] | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
):
with flask_app.app_context():
try:
@@ -231,14 +223,30 @@ class RetrievalService:
raise ValueError("dataset not found")
vector = Vector(dataset=dataset)
documents = vector.search_by_vector(
query,
search_type="similarity_score_threshold",
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
documents = []
if query_type == QueryType.TEXT_QUERY:
documents.extend(
vector.search_by_vector(
query,
search_type="similarity_score_threshold",
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
)
if query_type == QueryType.IMAGE_QUERY:
if not dataset.is_multimodal:
return
documents.extend(
vector.search_by_file(
file_id=query,
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
)
if documents:
if (
@@ -250,14 +258,37 @@ class RetrievalService:
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
if dataset.is_multimodal:
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model.get("reranking_provider_name") or "",
model=reranking_model.get("reranking_model_name") or "",
model_type=ModelType.RERANK,
)
if is_support_vision:
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
query_type=query_type,
)
)
else:
# reranking model does not support vision; return the original documents unranked
all_documents.extend(documents)
else:
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
query_type=query_type,
)
)
)
else:
all_documents.extend(documents)
except Exception as e:
@@ -339,103 +370,161 @@ class RetrievalService:
records = []
include_segment_ids = set()
segment_child_map = {}
# Process documents
for document in documents:
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = db.session.scalar(child_chunk_stmt)
if not child_chunk:
segment_file_map = {}
with Session(bind=db.engine, expire_on_commit=False) as session:
# Process documents
for document in documents:
segment_id = None
attachment_info = None
child_chunk = None
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
segment = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == child_chunk.segment_id,
)
.options(
load_only(
DocumentSegment.id,
DocumentSegment.content,
DocumentSegment.answer,
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# Handle parent-child documents
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
else:
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = session.scalar(child_chunk_stmt)
if not child_chunk:
continue
segment_id = child_chunk.segment_id
if not segment_id:
continue
segment = (
session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
.first()
)
.first()
)
if not segment:
continue
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
if segment.id in segment_child_map:
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
segment_child_map[segment.id] = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
if attachment_info:
if segment.id in segment_file_map:
segment_file_map[segment.id].append(attachment_info)
else:
segment_file_map[segment.id] = [attachment_info]
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
# Handle normal documents
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
# Handle normal documents
segment = None
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
segment = session.scalar(document_segment_stmt)
if segment:
segment_file_map[segment.id] = [attachment_info]
else:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = session.scalar(document_segment_stmt)
if not segment:
continue
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score"), # type: ignore
}
records.append(record)
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score"), # type: ignore
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if attachment_info:
attachment_infos = segment_file_map.get(segment.id, [])
if attachment_info not in attachment_infos:
attachment_infos.append(attachment_info)
segment_file_map[segment.id] = attachment_infos
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"]
if record["segment"].id in segment_file_map:
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
result = []
for record in records:
@@ -447,6 +536,11 @@ class RetrievalService:
if not isinstance(child_chunks, list):
child_chunks = None
# Extract files, ensuring it's a list or None
files = record.get("files")
if not isinstance(files, list):
files = None
# Extract score, ensuring it's a float or None
score_value = record.get("score")
score = (
@@ -456,10 +550,149 @@ class RetrievalService:
)
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks, score=score, files=files
)
result.append(retrieval_segment)
return result
except Exception as e:
db.session.rollback()
raise e
def _retrieve(
self,
flask_app: Flask,
retrieval_method: RetrievalMethod,
dataset: Dataset,
query: str | None = None,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
document_ids_filter: list[str] | None = None,
attachment_id: str | None = None,
all_documents: list[Document] | None = None,
exceptions: list[str] | None = None,
):
# avoid shared mutable default arguments
all_documents = all_documents if all_documents is not None else []
exceptions = exceptions if exceptions is not None else []
if not query and not attachment_id:
return
with flask_app.app_context():
all_documents_item: list[Document] = []
# Run retrieval tasks concurrently in a thread pool
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
futures = []
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query:
futures.append(
executor.submit(
self.keyword_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
top_k=top_k,
all_documents=all_documents_item,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_semantic_search(retrieval_method):
if query:
futures.append(
executor.submit(
self.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents_item,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
query_type=QueryType.TEXT_QUERY,
)
)
if attachment_id:
futures.append(
executor.submit(
self.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=attachment_id,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents_item,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
query_type=QueryType.IMAGE_QUERY,
)
)
if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query:
futures.append(
executor.submit(
self.full_text_index_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents_item,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
if exceptions:
raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
if attachment_id and reranking_mode == RerankMode.WEIGHTED_SCORE:
all_documents.extend(all_documents_item)
all_documents_item = self._deduplicate_documents(all_documents_item)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
# determine the query type before query is overwritten with attachment_id
query_type = QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY
query = query or attachment_id
if not query:
return
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=query_type,
)
all_documents.extend(all_documents_item)
@classmethod
def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> dict[str, Any] | None:
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
if upload_file:
attachment_binding = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
.first()
)
if attachment_binding:
attachment_info = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
"mime_type": upload_file.mime_type,
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
"size": upload_file.size,
}
return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
return None
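The hybrid-search path above calls `cls._deduplicate_documents(all_documents)`, but that helper is not part of this diff. A plausible sketch of what it does — keeping one entry per `doc_id` and retaining the highest-scoring copy — written against plain dicts rather than the actual `Document` model (shape assumed, not confirmed by the source):

```python
def deduplicate_documents(documents: list[dict]) -> list[dict]:
    # Keep one document per doc_id; when duplicates appear across
    # keyword/semantic/full-text results, retain the higher score.
    best: dict[str, dict] = {}
    for doc in documents:
        doc_id = doc["metadata"].get("doc_id")
        if doc_id is None:
            continue
        score = doc["metadata"].get("score", 0.0)
        kept = best.get(doc_id)
        if kept is None or score > kept["metadata"].get("score", 0.0):
            best[doc_id] = doc
    return list(best.values())


docs = [
    {"metadata": {"doc_id": "a", "score": 0.4}},
    {"metadata": {"doc_id": "a", "score": 0.9}},
    {"metadata": {"doc_id": "b", "score": 0.5}},
]
result = deduplicate_documents(docs)
```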

View File

@@ -1,3 +1,4 @@
import base64
import logging
import time
from abc import ABC, abstractmethod
@@ -12,10 +13,13 @@ from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Dataset, Whitelist
from models.model import UploadFile
logger = logging.getLogger(__name__)
@@ -203,6 +207,47 @@ class Vector:
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
def create_multimodal(self, file_documents: list | None = None, **kwargs):
if file_documents:
start = time.time()
logger.info("start embedding %s files %s", len(file_documents), start)
batch_size = 1000
total_batches = (len(file_documents) + batch_size - 1) // batch_size
for i in range(0, len(file_documents), batch_size):
batch = file_documents[i : i + batch_size]
batch_start = time.time()
logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch))
# Batch query all upload files to avoid N+1 queries
attachment_ids = [doc.metadata["doc_id"] for doc in batch]
stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids))
upload_files = db.session.scalars(stmt).all()
upload_file_map = {str(f.id): f for f in upload_files}
file_base64_list = []
real_batch = []
for document in batch:
attachment_id = document.metadata["doc_id"]
doc_type = document.metadata["doc_type"]
upload_file = upload_file_map.get(attachment_id)
if upload_file:
blob = storage.load_once(upload_file.key)
file_base64_str = base64.b64encode(blob).decode()
file_base64_list.append(
{
"content": file_base64_str,
"content_type": doc_type,
"file_id": attachment_id,
}
)
real_batch.append(document)
batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list)
logger.info(
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
)
self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs)
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
@@ -223,6 +268,22 @@ class Vector:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
return []
blob = storage.load_once(upload_file.key)
file_base64_str = base64.b64encode(blob).decode()
multimodal_vector = self._embeddings.embed_multimodal_query(
{
"content": file_base64_str,
"content_type": DocType.IMAGE,
"file_id": file_id,
}
)
return self._vector_processor.search_by_vector(multimodal_vector, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
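Both `search_by_file` and `create_multimodal` above wrap raw file bytes in the same base64 payload dict before handing it to the embedding layer. A minimal sketch of that encode step (the `"image"` content type stands in for `DocType.IMAGE`, an assumption for this isolated example):

```python
import base64


def build_multimodal_payload(blob: bytes, file_id: str, content_type: str = "image") -> dict:
    # Same payload shape as the dicts passed to embed_multimodal_query
    # and embed_multimodal_documents in the diff above.
    return {
        "content": base64.b64encode(blob).decode(),
        "content_type": content_type,
        "file_id": file_id,
    }


payload = build_multimodal_payload(b"fake-image-bytes", "file-123")
```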

View File

@@ -5,9 +5,9 @@ from sqlalchemy import func, select
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
class DatasetDocumentStore:
@@ -120,6 +120,9 @@ class DatasetDocumentStore:
db.session.add(segment_document)
db.session.flush()
self.add_multimodel_documents_binding(
segment_id=segment_document.id, multimodel_documents=doc.attachments
)
if save_child:
if doc.children:
for position, child in enumerate(doc.children, start=1):
@@ -144,6 +147,9 @@ class DatasetDocumentStore:
segment_document.index_node_hash = doc.metadata.get("doc_hash")
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
self.add_multimodel_documents_binding(
segment_id=segment_document.id, multimodel_documents=doc.attachments
)
if save_child and doc.children:
# delete the existing child chunks
db.session.query(ChildChunk).where(
@@ -233,3 +239,15 @@ class DatasetDocumentStore:
document_segment = db.session.scalar(stmt)
return document_segment
def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None):
if multimodel_documents:
for multimodel_document in multimodel_documents:
binding = SegmentAttachmentBinding(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_id,
attachment_id=multimodel_document.metadata["doc_id"],
)
db.session.add(binding)

View File

@@ -104,6 +104,88 @@ class CacheEmbedding(Embeddings):
return text_embeddings
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
"""Embed file documents."""
# use doc embedding cache or store if not exists
multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]
embedding_queue_indices = []
for i, multimodel_document in enumerate(multimodel_documents):
file_id = multimodel_document["file_id"]
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
)
.first()
)
if embedding:
multimodel_embeddings[i] = embedding.get_embedding()
else:
embedding_queue_indices.append(i)
# NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
if embedding_queue_indices:
embedding_queue_multimodel_documents = [multimodel_documents[i] for i in embedding_queue_indices]
embedding_queue_embeddings = []
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(
self._model_instance.model, self._model_instance.credentials
)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
else 1
)
for i in range(0, len(embedding_queue_multimodel_documents), max_chunks):
batch_multimodel_documents = embedding_queue_multimodel_documents[i : i + max_chunks]
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=batch_multimodel_documents,
user=self._user,
input_type=EmbeddingInputType.DOCUMENT,
)
for vector in embedding_result.embeddings:
try:
# FIXME: type ignore for numpy here
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
if np.isnan(normalized_embedding).any():
# for issue #11827 float values are not json compliant
logger.warning("Normalized embedding is nan: %s", normalized_embedding)
continue
embedding_queue_embeddings.append(normalized_embedding)
except IntegrityError:
db.session.rollback()
except Exception:
logger.exception("Failed transform embedding")
cache_embeddings = []
try:
for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
multimodel_embeddings[i] = n_embedding
file_id = multimodel_documents[i]["file_id"]
if file_id not in cache_embeddings:
embedding_cache = Embedding(
model_name=self._model_instance.model,
hash=file_id,
provider_name=self._model_instance.provider,
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
)
embedding_cache.set_embedding(n_embedding)
db.session.add(embedding_cache)
cache_embeddings.append(file_id)
db.session.commit()
except IntegrityError:
db.session.rollback()
except Exception as ex:
db.session.rollback()
logger.exception("Failed to embed documents")
raise ex
return multimodel_embeddings
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
@@ -146,3 +228,46 @@ class CacheEmbedding(Embeddings):
raise ex
return embedding_results # type: ignore
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
"""Embed multimodal documents."""
# use doc embedding cache or store if not exists
file_id = multimodel_document["file_id"]
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)
decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float")
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]
# FIXME: type ignore for numpy here
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is NaN, please try again")
except Exception as ex:
if dify_config.DEBUG:
logger.exception("Failed to embed multimodal document '%s'", multimodel_document["file_id"])
raise ex
try:
# encode embedding to base64
embedding_vector = np.array(embedding_results)
vector_bytes = embedding_vector.tobytes()
# Transform to Base64
encoded_vector = base64.b64encode(vector_bytes)
# Transform to string
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 600, encoded_str)
except Exception as ex:
if dify_config.DEBUG:
logger.exception(
"Failed to add embedding to redis for the multimodal document '%s'", multimodel_document["file_id"]
)
raise ex
return embedding_results # type: ignore
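`embed_multimodal_query` caches the query vector in Redis as a base64 string of the raw float64 bytes and later restores it with `np.frombuffer`. The roundtrip in isolation (a sketch of the encode/decode pair above, without the Redis client):

```python
import base64

import numpy as np


def encode_vector(vector: list[float]) -> str:
    # np.array of Python floats defaults to float64, matching the
    # dtype="float" used on the decode side, so the roundtrip is exact.
    return base64.b64encode(np.array(vector).tobytes()).decode("utf-8")


def decode_vector(encoded: str) -> list[float]:
    return [float(x) for x in np.frombuffer(base64.b64decode(encoded), dtype="float")]


vec = [0.1, -0.5, 0.9]
restored = decode_vector(encode_vector(vec))
```

Because both sides use float64, no precision is lost; the decoded list compares equal to the original.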

View File

@@ -9,11 +9,21 @@ class Embeddings(ABC):
"""Embed search docs."""
raise NotImplementedError
@abstractmethod
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
"""Embed file documents."""
raise NotImplementedError
@abstractmethod
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
raise NotImplementedError
@abstractmethod
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
"""Embed multimodal query."""
raise NotImplementedError
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Asynchronous Embed search docs."""
raise NotImplementedError

View File

@@ -19,3 +19,4 @@ class RetrievalSegments(BaseModel):
segment: DocumentSegment
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None

View File

@@ -21,3 +21,4 @@ class RetrievalSourceMetadata(BaseModel):
page: int | None = None
doc_metadata: dict[str, Any] | None = None
title: str | None = None
files: list[dict[str, Any]] | None = None

View File

@@ -1,7 +1,7 @@
"""Abstract interface for document loader implementations."""
import os
from typing import cast
from typing import TypedDict
import pandas as pd
from openpyxl import load_workbook
@@ -10,6 +10,12 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
class Candidate(TypedDict):
idx: int
count: int
map: dict[int, str]
class ExcelExtractor(BaseExtractor):
"""Load Excel files.
@@ -30,32 +36,38 @@ class ExcelExtractor(BaseExtractor):
file_extension = os.path.splitext(self._file_path)[-1].lower()
if file_extension == ".xlsx":
wb = load_workbook(self._file_path, data_only=True)
for sheet_name in wb.sheetnames:
sheet = wb[sheet_name]
data = sheet.values
cols = next(data, None)
if cols is None:
continue
df = pd.DataFrame(data, columns=cols)
df.dropna(how="all", inplace=True)
for index, row in df.iterrows():
page_content = []
for col_index, (k, v) in enumerate(row.items()):
if pd.notna(v):
cell = sheet.cell(
row=cast(int, index) + 2, column=col_index + 1
) # +2 to account for header and 1-based index
if cell.hyperlink:
value = f"[{v}]({cell.hyperlink.target})"
page_content.append(f'"{k}":"{value}"')
else:
page_content.append(f'"{k}":"{v}"')
documents.append(
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
)
wb = load_workbook(self._file_path, read_only=True, data_only=True)
try:
for sheet_name in wb.sheetnames:
sheet = wb[sheet_name]
header_row_idx, column_map, max_col_idx = self._find_header_and_columns(sheet)
if not column_map:
continue
start_row = header_row_idx + 1
for row in sheet.iter_rows(min_row=start_row, max_col=max_col_idx, values_only=False):
if all(cell.value is None for cell in row):
continue
page_content = []
for col_idx, cell in enumerate(row):
value = cell.value
if col_idx in column_map:
col_name = column_map[col_idx]
if hasattr(cell, "hyperlink") and cell.hyperlink:
target = getattr(cell.hyperlink, "target", None)
if target:
value = f"[{value}]({target})"
if value is None:
value = ""
elif not isinstance(value, str):
value = str(value)
value = value.strip().replace('"', '\\"')
page_content.append(f'"{col_name}":"{value}"')
if page_content:
documents.append(
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
)
finally:
wb.close()
elif file_extension == ".xls":
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
@@ -63,9 +75,9 @@ class ExcelExtractor(BaseExtractor):
df = excel_file.parse(sheet_name=excel_sheet_name)
df.dropna(how="all", inplace=True)
for _, row in df.iterrows():
for _, series_row in df.iterrows():
page_content = []
for k, v in row.items():
for k, v in series_row.items():
if pd.notna(v):
page_content.append(f'"{k}":"{v}"')
documents.append(
@@ -75,3 +87,61 @@ class ExcelExtractor(BaseExtractor):
raise ValueError(f"Unsupported file extension: {file_extension}")
return documents
def _find_header_and_columns(self, sheet, scan_rows=10) -> tuple[int, dict[int, str], int]:
"""
Scan first N rows to find the most likely header row.
Returns:
header_row_idx: 1-based index of the header row
column_map: Dict mapping 0-based column index to column name
max_col_idx: 1-based index of the last valid column (for iter_rows boundary)
"""
# Candidate header rows: {"idx": row_index, "count": non_empty_count, "map": column_map}
candidates: list[Candidate] = []
# Limit scan to avoid performance issues on huge files
# We iterate manually to control the read scope
for current_row_idx, row in enumerate(sheet.iter_rows(min_row=1, max_row=scan_rows, values_only=True), start=1):
# Filter out empty cells and build a temp map for this row
# col_idx is 0-based
row_map = {}
for col_idx, cell_value in enumerate(row):
if cell_value is not None and str(cell_value).strip():
row_map[col_idx] = str(cell_value).strip().replace('"', '\\"')
if not row_map:
continue
non_empty_count = len(row_map)
# Header selection heuristic:
# - Prefer the first row with at least 2 non-empty columns.
# - Fallback: choose the row with the most non-empty columns
#   (tie-breaker: smaller row index).
candidates.append({"idx": current_row_idx, "count": non_empty_count, "map": row_map})
if not candidates:
return 0, {}, 0
# Choose the best candidate header row.
best_candidate: Candidate | None = None
# Strategy: prefer the first row with >= 2 non-empty columns; otherwise fallback.
for cand in candidates:
if cand["count"] >= 2:
best_candidate = cand
break
# Fallback: if no row has >= 2 columns, or all have 1, just take the one with max columns
if not best_candidate:
# Sort by count desc, then index asc
candidates.sort(key=lambda x: (-x["count"], x["idx"]))
best_candidate = candidates[0]
# Determine max_col_idx (1-based for openpyxl)
# It is the index of the last valid column in our map + 1
max_col_idx = max(best_candidate["map"].keys()) + 1
return best_candidate["idx"], best_candidate["map"], max_col_idx
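The selection logic in `_find_header_and_columns` can be sketched independently of openpyxl. Given rows of raw cell values, it prefers the first row with at least two non-empty cells and otherwise falls back to the densest row (function and sample data below are illustrative):

```python
def find_header(rows: list[tuple]) -> tuple[int, dict[int, str]]:
    """Return (1-based header row index, {0-based column index: name})."""
    candidates = []
    for row_idx, row in enumerate(rows, start=1):
        row_map = {
            col_idx: str(value).strip()
            for col_idx, value in enumerate(row)
            if value is not None and str(value).strip()
        }
        if row_map:
            candidates.append((row_idx, row_map))
    if not candidates:
        return 0, {}
    # Prefer the first row with >= 2 non-empty cells...
    for row_idx, row_map in candidates:
        if len(row_map) >= 2:
            return row_idx, row_map
    # ...otherwise take the densest row (earliest on ties).
    candidates.sort(key=lambda c: (-len(c[1]), c[0]))
    return candidates[0]

# A title-only first row is skipped in favor of the real header below it.
idx, cols = find_header([("Report",), ("Name", "Age"), ("Ann", 30)])
```

This is why a sheet whose first row is a merged title cell still gets its real column names: the single-cell title row loses to the wider row beneath it.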

View File

@@ -84,22 +84,46 @@ class WordExtractor(BaseExtractor):
image_count = 0
image_map = {}
for rel in doc.part.rels.values():
for rId, rel in doc.part.rels.items():
if "image" in rel.target_ref:
image_count += 1
if rel.is_external:
url = rel.target_ref
response = ssrf_proxy.get(url)
if not self._is_valid_url(url):
continue
try:
response = ssrf_proxy.get(url)
except Exception as e:
logger.warning("Failed to download image from URL: %s: %s", url, str(e))
continue
if response.status_code == 200:
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
image_ext = mimetypes.guess_extension(response.headers.get("Content-Type", ""))
if image_ext is None:
continue
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + image_ext
mime_type, _ = mimetypes.guess_type(file_key)
storage.save(file_key, response.content)
else:
continue
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=0,
extension=str(image_ext),
mime_type=mime_type or "",
created_by=self.user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=self.user_id,
used_at=naive_utc_now(),
)
db.session.add(upload_file)
db.session.commit()
# Use rId as key for external images since target_part is undefined
image_map[rId] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
else:
image_ext = rel.target_ref.split(".")[-1]
if image_ext is None:
@@ -110,26 +134,28 @@ class WordExtractor(BaseExtractor):
mime_type, _ = mimetypes.guess_type(file_key)
storage.save(file_key, rel.target_part.blob)
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=0,
extension=str(image_ext),
mime_type=mime_type or "",
created_by=self.user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=self.user_id,
used_at=naive_utc_now(),
)
db.session.add(upload_file)
db.session.commit()
image_map[rel.target_part] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=0,
extension=str(image_ext),
mime_type=mime_type or "",
created_by=self.user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=self.user_id,
used_at=naive_utc_now(),
)
db.session.add(upload_file)
db.session.commit()
# Use target_part as key for internal images
image_map[rel.target_part] = (
f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
)
return image_map
@@ -186,11 +212,17 @@ class WordExtractor(BaseExtractor):
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
if not image_id:
continue
image_part = paragraph.part.rels[image_id].target_part
if image_part in image_map:
image_link = image_map[image_part]
paragraph_content.append(image_link)
rel = paragraph.part.rels.get(image_id)
if rel is None:
continue
# For external images, use image_id as key; for internal, use target_part
if rel.is_external:
if image_id in image_map:
paragraph_content.append(image_map[image_id])
else:
image_part = rel.target_part
if image_part in image_map:
paragraph_content.append(image_map[image_part])
else:
paragraph_content.append(run.text)
return "".join(paragraph_content).strip()
@@ -227,6 +259,18 @@ class WordExtractor(BaseExtractor):
def parse_paragraph(paragraph):
paragraph_content = []
def append_image_link(image_id, has_drawing):
"""Helper to append image link from image_map based on relationship type."""
rel = doc.part.rels[image_id]
if rel.is_external:
if image_id in image_map and not has_drawing:
paragraph_content.append(image_map[image_id])
else:
image_part = rel.target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
for run in paragraph.runs:
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
# Process drawing type images
@@ -243,10 +287,18 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
)
if embed_id:
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
has_drawing = True
paragraph_content.append(image_map[image_part])
rel = doc.part.rels.get(embed_id)
if rel is not None and rel.is_external:
# External image: use embed_id as key
if embed_id in image_map:
has_drawing = True
paragraph_content.append(image_map[embed_id])
else:
# Internal image: use target_part as key
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
has_drawing = True
paragraph_content.append(image_map[image_part])
# Process pict type images
shape_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
@@ -261,9 +313,7 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
image_part = doc.part.rels[image_id].target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
append_image_link(image_id, has_drawing)
# Find imagedata element in VML
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
if image_data is not None:
@@ -271,9 +321,7 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
image_part = doc.part.rels[image_id].target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
append_image_link(image_id, has_drawing)
if run.text.strip():
paragraph_content.append(run.text.strip())
return "".join(paragraph_content) if paragraph_content else ""

View File

@@ -0,0 +1,6 @@
from enum import StrEnum
class DocType(StrEnum):
TEXT = "text"
IMAGE = "image"

View File

@@ -1,7 +1,12 @@
from enum import StrEnum
class IndexType(StrEnum):
class IndexStructureType(StrEnum):
PARAGRAPH_INDEX = "text_model"
QA_INDEX = "qa_model"
PARENT_CHILD_INDEX = "hierarchical_model"
class IndexTechniqueType(StrEnum):
ECONOMY = "economy"
HIGH_QUALITY = "high_quality"

View File

@@ -0,0 +1,6 @@
from enum import StrEnum
class QueryType(StrEnum):
TEXT_QUERY = "text_query"
IMAGE_QUERY = "image_query"

View File

@@ -1,20 +1,34 @@
"""Abstract interface for document loader implementations."""
import cgi
import logging
import mimetypes
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Optional
from urllib.parse import unquote, urlparse
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument, Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import Account, ToolFile
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
if TYPE_CHECKING:
from core.model_manager import ModelInstance
@@ -28,11 +42,18 @@ class BaseIndexProcessor(ABC):
raise NotImplementedError
@abstractmethod
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
raise NotImplementedError
@abstractmethod
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
raise NotImplementedError
@abstractmethod
@@ -96,3 +117,178 @@ class BaseIndexProcessor(ABC):
)
return character_splitter # type: ignore
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
"""
Get the content files from the document.
"""
multi_model_documents: list[AttachmentDocument] = []
text = document.page_content
images = self._extract_markdown_images(text)
if not images:
return multi_model_documents
upload_file_id_list = []
for image in images:
# Collect all upload_file_ids including duplicates to preserve occurrence count
# For data before v0.10.0
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For data after v0.10.0
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For tools directory - direct file formats (e.g., .png, .jpg, etc.)
# Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes)
pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
match = re.search(pattern, image)
if match:
if current_user:
tool_file_id = match.group(1)
upload_file_id = self._download_tool_file(tool_file_id, current_user)
if upload_file_id:
upload_file_id_list.append(upload_file_id)
continue
if current_user:
upload_file_id = self._download_image(image.split(" ")[0], current_user)
if upload_file_id:
upload_file_id_list.append(upload_file_id)
if not upload_file_id_list:
return multi_model_documents
# Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all()
# Create a mapping from ID to UploadFile for quick lookup
upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
# Create a Document for each occurrence (including duplicates)
for upload_file_id in upload_file_id_list:
upload_file = upload_file_map.get(upload_file_id)
if upload_file:
multi_model_documents.append(
AttachmentDocument(
page_content=upload_file.name,
metadata={
"doc_id": upload_file.id,
"doc_hash": "",
"document_id": document.metadata.get("document_id"),
"dataset_id": document.metadata.get("dataset_id"),
"doc_type": DocType.IMAGE,
},
)
)
return multi_model_documents
def _extract_markdown_images(self, text: str) -> list[str]:
"""
Extract the markdown images from the text.
"""
pattern = r"!\[.*?\]\((.*?)\)"
return re.findall(pattern, text)
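The markdown image pattern above is non-greedy on both the alt text and the URL, so adjacent images in one line split correctly. A quick check (the sample text is illustrative):

```python
import re

def extract_markdown_images(text: str) -> list[str]:
    # Capture the URL inside ![alt](url); non-greedy so back-to-back
    # images do not get merged into one match.
    return re.findall(r"!\[.*?\]\((.*?)\)", text)

urls = extract_markdown_images(
    "intro ![a](/files/1/file-preview) mid ![b](/files/2/file-preview)"
)
```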
def _download_image(self, image_url: str, current_user: Account) -> str | None:
"""
Download the image from the URL.
Image size must not exceed the configured attachment image size limit.
"""
from services.file_service import FileService
MAX_IMAGE_SIZE = dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
DOWNLOAD_TIMEOUT = dify_config.ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT
try:
# Download with timeout
response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT)
response.raise_for_status()
# Check Content-Length header if available
content_length = response.headers.get("Content-Length")
if content_length and int(content_length) > MAX_IMAGE_SIZE:
logging.warning("Image from %s exceeds 2MB limit (size: %s bytes)", image_url, content_length)
return None
filename = None
content_disposition = response.headers.get("content-disposition")
if content_disposition:
_, params = cgi.parse_header(content_disposition)
if "filename" in params:
filename = params["filename"]
filename = unquote(filename)
if not filename:
parsed_url = urlparse(image_url)
# unquote to decode percent-encoded non-ASCII characters in the URL path
path = unquote(parsed_url.path)
filename = os.path.basename(path)
if not filename:
filename = "downloaded_image_file"
name, current_ext = os.path.splitext(filename)
content_type = response.headers.get("content-type", "").split(";")[0].strip()
real_ext = mimetypes.guess_extension(content_type)
# Replace a missing or script-like extension with the one implied by Content-Type
if real_ext and (not current_ext or current_ext in [".php", ".jsp", ".asp", ".html"]):
filename = f"{name}{real_ext}"
# Download content with size limit
blob = b""
for chunk in response.iter_bytes(chunk_size=8192):
blob += chunk
if len(blob) > MAX_IMAGE_SIZE:
logging.warning("Image from %s exceeds 2MB limit during download", image_url)
return None
if not blob:
logging.warning("Image from %s is empty", image_url)
return None
upload_file = FileService(db.engine).upload_file(
filename=filename,
content=blob,
mimetype=content_type,
user=current_user,
)
return upload_file.id
except httpx.TimeoutException:
logging.warning("Timeout downloading image from %s after %s seconds", image_url, DOWNLOAD_TIMEOUT)
return None
except httpx.RequestError as e:
logging.warning("Error downloading image from %s: %s", image_url, str(e))
return None
except Exception:
logging.exception("Unexpected error downloading image from %s", image_url)
return None
def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None:
"""
Load the tool file identified by tool_file_id from storage and re-upload it as an UploadFile.
"""
from services.file_service import FileService
tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
upload_file = FileService(db.engine).upload_file(
filename=tool_file.name,
content=blob,
mimetype=tool_file.mimetype,
user=current_user,
)
return upload_file.id
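The filename handling in `_download_image` swaps a missing or script-like extension for the one implied by the response's Content-Type. Because `and` binds tighter than `or`, the original condition already behaves this way, but a parenthesized sketch makes the intent explicit (helper name below is illustrative):

```python
import mimetypes
import os

DANGEROUS_EXTS = [".php", ".jsp", ".asp", ".html"]

def safe_filename(filename: str, content_type: str) -> str:
    """Swap in the Content-Type extension when the current one is absent or unsafe."""
    name, current_ext = os.path.splitext(filename)
    real_ext = mimetypes.guess_extension(content_type.split(";")[0].strip())
    if real_ext and (not current_ext or current_ext in DANGEROUS_EXTS):
        return f"{name}{real_ext}"
    return filename

renamed = safe_filename("upload.php", "image/png")  # unsafe ext replaced
kept = safe_filename("photo.jpg", "image/png")      # safe ext preserved
```

This prevents a server-supplied filename like `avatar.php` from being stored with an executable extension while leaving legitimate image names untouched.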

View File

@@ -1,6 +1,6 @@
"""Abstract interface for document loader implementations."""
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
@@ -19,11 +19,11 @@ class IndexProcessorFactory:
if not self._index_type:
raise ValueError("Index type must be specified.")
if self._index_type == IndexType.PARAGRAPH_INDEX:
if self._index_type == IndexStructureType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor()
elif self._index_type == IndexType.QA_INDEX:
elif self._index_type == IndexStructureType.QA_INDEX:
return QAIndexProcessor()
elif self._index_type == IndexType.PARENT_CHILD_INDEX:
elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
else:
raise ValueError(f"Index type {self._index_type} is not supported.")

View File

@@ -11,14 +11,17 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
@@ -33,7 +36,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
@@ -69,6 +72,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if document_node.metadata is not None:
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
multimodal_documents = (
self._get_content_files(document_node, current_user) if document_node.metadata else None
)
if multimodal_documents:
document_node.attachments = multimodal_documents
# delete Splitter character
page_content = remove_leading_symbols(document_node.page_content).strip()
if len(page_content) > 0:
@@ -77,10 +85,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
all_documents.extend(split_documents)
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
with_keywords = False
if with_keywords:
keywords_list = kwargs.get("keywords_list")
@@ -134,8 +151,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
documents: list[Any] = []
all_multimodal_documents: list[Any] = []
if isinstance(chunks, list):
documents = []
for content in chunks:
metadata = {
"dataset_id": dataset.id,
@@ -144,26 +162,68 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
"doc_hash": helper.generate_text_hash(content),
}
doc = Document(page_content=content, metadata=metadata)
attachments = self._get_content_files(doc)
if attachments:
doc.attachments = attachments
all_multimodal_documents.extend(attachments)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)
else:
raise ValueError("Chunks is not a list")
multimodal_general_structure = MultimodalGeneralStructureChunk.model_validate(chunks)
for general_chunk in multimodal_general_structure.general_chunks:
metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
"doc_id": str(uuid.uuid4()),
"doc_hash": helper.generate_text_hash(general_chunk.content),
}
doc = Document(page_content=general_chunk.content, metadata=metadata)
if general_chunk.files:
attachments = []
for file in general_chunk.files:
file_metadata = {
"doc_id": file.id,
"doc_hash": "",
"document_id": document.id,
"dataset_id": dataset.id,
"doc_type": DocType.IMAGE,
}
file_document = AttachmentDocument(
page_content=file.filename or "image_file", metadata=file_metadata
)
attachments.append(file_document)
all_multimodal_documents.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)
if doc.attachments:
all_multimodal_documents.extend(doc.attachments)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if all_multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(all_multimodal_documents)
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
if isinstance(chunks, list):
preview = []
for content in chunks:
preview.append({"content": content})
return {"chunk_structure": IndexType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks)}
return {
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
"preview": preview,
"total_segments": len(chunks),
}
else:
raise ValueError("Chunks is not a list")

View File

@@ -13,14 +13,17 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from libs import helper
from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@@ -35,7 +38,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
@@ -77,6 +80,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
page_content = page_content
if len(page_content) > 0:
document_node.page_content = page_content
multimodel_documents = self._get_content_files(document_node, current_user)
if multimodel_documents:
document_node.attachments = multimodel_documents
# parse document to child nodes
child_nodes = self._split_child_nodes(
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
@@ -87,6 +93,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
elif rules.parent_mode == ParentMode.FULL_DOC:
page_content = "\n".join([document.page_content for document in documents])
document = Document(page_content=page_content, metadata=documents[0].metadata)
multimodel_documents = self._get_content_files(document)
if multimodel_documents:
document.attachments = multimodel_documents
# parse document to child nodes
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
@@ -104,7 +113,14 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
for document in documents:
@@ -114,6 +130,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
Document.model_validate(child_document.model_dump()) for child_document in child_documents
]
vector.create(formatted_child_documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
@@ -244,6 +262,24 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
}
child_documents.append(ChildDocument(page_content=child, metadata=child_metadata))
doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents)
if parent_child.files and len(parent_child.files) > 0:
attachments = []
for file in parent_child.files:
file_metadata = {
"doc_id": file.id,
"doc_hash": "",
"document_id": document.id,
"dataset_id": dataset.id,
"doc_type": DocType.IMAGE,
}
file_document = AttachmentDocument(page_content=file.filename or "", metadata=file_metadata)
attachments.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)
documents.append(doc)
if documents:
# update document parent mode
@@ -267,12 +303,17 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
doc_store.add_documents(docs=documents, save_child=True)
if dataset.indexing_technique == "high_quality":
all_child_documents = []
all_multimodal_documents = []
for doc in documents:
if doc.children:
all_child_documents.extend(doc.children)
if doc.attachments:
all_multimodal_documents.extend(doc.attachments)
vector = Vector(dataset)
if all_child_documents:
vector = Vector(dataset)
vector.create(all_child_documents)
if all_multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(all_multimodal_documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
parent_childs = ParentChildStructureChunk.model_validate(chunks)
@@ -280,7 +321,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return {
"chunk_structure": IndexType.PARENT_CHILD_INDEX,
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
"parent_mode": parent_childs.parent_mode,
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks),

View File

@@ -18,12 +18,13 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document, QAStructureChunk
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
@@ -41,7 +42,7 @@ class QAIndexProcessor(BaseIndexProcessor):
)
return text_docs
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
preview = kwargs.get("preview")
process_rule = kwargs.get("process_rule")
if not process_rule:
@@ -116,7 +117,7 @@ class QAIndexProcessor(BaseIndexProcessor):
try:
# Skip the first row
df = pd.read_csv(file) # type: ignore
text_docs = []
for _, row in df.iterrows():
data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
@@ -128,10 +129,19 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError(str(e))
return text_docs
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
vector = Vector(dataset)
@@ -197,7 +207,7 @@ class QAIndexProcessor(BaseIndexProcessor):
for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
return {
"chunk_structure": IndexType.QA_INDEX,
"chunk_structure": IndexStructureType.QA_INDEX,
"qa_preview": preview,
"total_segments": len(qa_chunks.qa_chunks),
}

View File

@@ -4,6 +4,8 @@ from typing import Any
from pydantic import BaseModel, Field
from core.file import File
class ChildDocument(BaseModel):
"""Class for storing a piece of text and associated metadata."""
@@ -15,7 +17,19 @@ class ChildDocument(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: dict[str, Any] = Field(default_factory=dict)
class AttachmentDocument(BaseModel):
"""Class for storing a piece of text and associated metadata."""
page_content: str
provider: str | None = "dify"
vector: list[float] | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
class Document(BaseModel):
@@ -28,12 +42,31 @@ class Document(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: dict[str, Any] = Field(default_factory=dict)
provider: str | None = "dify"
children: list[ChildDocument] | None = None
attachments: list[AttachmentDocument] | None = None
class GeneralChunk(BaseModel):
"""
General Chunk.
"""
content: str
files: list[File] | None = None
class MultimodalGeneralStructureChunk(BaseModel):
"""
Multimodal General Structure Chunk.
"""
general_chunks: list[GeneralChunk]
class GeneralStructureChunk(BaseModel):
"""
@@ -50,6 +83,7 @@ class ParentChildChunk(BaseModel):
parent_content: str
child_contents: list[str]
files: list[File] | None = None
class ParentChildStructureChunk(BaseModel):

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
@@ -12,6 +13,7 @@ class BaseRerankRunner(ABC):
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
Run rerank model

View File

@@ -1,6 +1,15 @@
import base64
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_base import BaseRerankRunner
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import UploadFile
class RerankModelRunner(BaseRerankRunner):
@@ -14,6 +23,7 @@ class RerankModelRunner(BaseRerankRunner):
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
Run rerank model
@@ -24,6 +34,56 @@ class RerankModelRunner(BaseRerankRunner):
:param user: unique user id if needed
:return:
"""
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
provider=self.rerank_model_instance.provider,
model=self.rerank_model_instance.model,
model_type=ModelType.RERANK,
)
if not is_support_vision:
if query_type == QueryType.TEXT_QUERY:
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
else:
return documents
else:
rerank_result, unique_documents = self.fetch_multimodal_rerank(
query, documents, score_threshold, top_n, user, query_type
)
rerank_documents = []
for result in rerank_result.docs:
if score_threshold is None or result.score >= score_threshold:
# format document
rerank_document = Document(
page_content=result.text,
metadata=unique_documents[result.index].metadata,
provider=unique_documents[result.index].provider,
)
if rerank_document.metadata is not None:
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
def fetch_text_rerank(
self,
query: str,
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> tuple[RerankResult, list[Document]]:
"""
Fetch text rerank
:param query: search query
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
docs = []
doc_ids = set()
unique_documents = []
@@ -33,33 +93,99 @@ class RerankModelRunner(BaseRerankRunner):
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT:
doc_ids.add(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
rerank_result = self.rerank_model_instance.invoke_rerank(
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
)
return rerank_result, unique_documents
def fetch_multimodal_rerank(
self,
query: str,
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> tuple[RerankResult, list[Document]]:
"""
Fetch multimodal rerank
:param query: search query
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:param query_type: query type
:return: rerank result
"""
docs = []
doc_ids = set()
unique_documents = []
for document in documents:
if (
document.provider == "dify"
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
if document.metadata.get("doc_type") == DocType.IMAGE:
# Query file info within db.session context to ensure thread-safe access
upload_file = (
db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first()
)
if upload_file:
blob = storage.load_once(upload_file.key)
document_file_base64 = base64.b64encode(blob).decode()
document_file_dict = {
"content": document_file_base64,
"content_type": document.metadata["doc_type"],
}
docs.append(document_file_dict)
else:
document_text_dict = {
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
docs.append(document_text_dict)
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(
{
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
)
unique_documents.append(document)
documents = unique_documents
if query_type == QueryType.TEXT_QUERY:
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
return rerank_result, unique_documents
elif query_type == QueryType.IMAGE_QUERY:
# Query file info within db.session context to ensure thread-safe access
upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first()
if upload_file:
blob = storage.load_once(upload_file.key)
file_query = base64.b64encode(blob).decode()
file_query_dict = {
"content": file_query,
"content_type": DocType.IMAGE,
}
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
)
return rerank_result, unique_documents
else:
raise ValueError(f"Upload file not found for query: {query}")
else:
raise ValueError(f"Query type {query_type} is not supported")

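The branch added in `run()` reduces to a small decision: vision-capable rerank models take the multimodal path for any query type, while text-only models rerank text queries and pass image queries through untouched. A sketch of that dispatch (assuming `QueryType` has `TEXT_QUERY`/`IMAGE_QUERY` members, as the imports suggest):

```python
from enum import Enum

class QueryType(Enum):
    # stand-in for core.rag.index_processor.constant.query_type.QueryType
    TEXT_QUERY = "text_query"
    IMAGE_QUERY = "image_query"

def pick_rerank_path(supports_vision: bool, query_type: QueryType) -> str:
    """Mirror the dispatch in RerankModelRunner.run():
    text-only models handle text queries and skip reranking otherwise;
    vision-capable models always use the multimodal path."""
    if not supports_vision:
        return "text" if query_type is QueryType.TEXT_QUERY else "passthrough"
    return "multimodal"
```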
View File

@@ -7,6 +7,8 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.entity.weight import VectorSetting, Weights
from core.rag.rerank.rerank_base import BaseRerankRunner
@@ -24,6 +26,7 @@ class WeightRerankRunner(BaseRerankRunner):
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
Run rerank model
@@ -43,8 +46,10 @@ class WeightRerankRunner(BaseRerankRunner):
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
# weight rerank only supports text documents
if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT:
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
else:
if document not in unique_documents:
unique_documents.append(document)

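The weight-rerank filter above (keep only text chunks from the `dify` provider, deduped by `doc_id`, and pass external documents through) can be sketched standalone with plain dicts; the field names mirror the `Document` metadata used in the hunk:

```python
def filter_text_documents(documents: list[dict]) -> list[dict]:
    """Keep dify-provided text chunks (deduped by doc_id), drop dify
    image chunks, and keep external documents, mirroring the filter
    in WeightRerankRunner.run()."""
    doc_ids: set = set()
    unique_documents: list[dict] = []
    for document in documents:
        metadata = document.get("metadata") or {}
        if document.get("provider") == "dify" and metadata.get("doc_id") not in doc_ids:
            # weight rerank only supports text documents
            if not metadata.get("doc_type") or metadata.get("doc_type") == "text":
                doc_ids.add(metadata["doc_id"])
                unique_documents.append(document)
        elif document not in unique_documents:
            unique_documents.append(document)
    return unique_documents
```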
View File

@@ -8,6 +8,7 @@ from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
DatasetEntity,
@@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.file import File, FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@@ -37,7 +39,9 @@ from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -52,10 +56,12 @@ from core.rag.retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
@@ -99,7 +105,8 @@ class DatasetRetrieval:
message_id: str,
memory: TokenBufferMemory | None = None,
inputs: Mapping[str, Any] | None = None,
vision_enabled: bool = False,
) -> tuple[str | None, list[File] | None]:
"""
Retrieve dataset.
:param app_id: app_id
@@ -118,7 +125,7 @@ class DatasetRetrieval:
"""
dataset_ids = config.dataset_ids
if len(dataset_ids) == 0:
return None, []
retrieve_config = config.retrieve_config
# check model is support tool calling
@@ -136,7 +143,7 @@ class DatasetRetrieval:
)
if not model_schema:
return None, []
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
@@ -182,8 +189,8 @@ class DatasetRetrieval:
tenant_id,
user_id,
user_from,
query,
available_datasets,
model_instance,
model_config,
planning_strategy,
@@ -213,6 +220,7 @@ class DatasetRetrieval:
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
document_context_list: list[DocumentContext] = []
context_files: list[File] = []
retrieval_resource_list: list[RetrievalSourceMetadata] = []
# deal with external documents
for item in external_documents:
@@ -248,6 +256,31 @@ class DatasetRetrieval:
score=record.score,
)
)
if vision_enabled:
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == segment.id,
)
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=segment.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
context_files.append(attachment_info)
if show_retrieve_source:
for record in records:
segment = record.segment
@@ -288,8 +321,10 @@ class DatasetRetrieval:
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
return str(
"\n".join([document_context.content for document_context in document_context_list])
), context_files
return "", context_files
def single_retrieve(
self,
@@ -297,8 +332,8 @@ class DatasetRetrieval:
tenant_id: str,
user_id: str,
user_from: str,
query: str,
available_datasets: list,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
@@ -336,7 +371,7 @@ class DatasetRetrieval:
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
self._record_usage(router_usage)
timer = None
if dataset_id:
# get retrieval model config
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
@@ -406,10 +441,19 @@ class DatasetRetrieval:
weights=retrieval_model_config.get("weights", None),
document_ids_filter=document_ids_filter,
)
self._on_query(query, None, [dataset_id], app_id, user_from, user_id)
if results:
thread = threading.Thread(
target=self._on_retrieval_end,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"documents": results,
"message_id": message_id,
"timer": timer,
},
)
thread.start()
return results
return []
@@ -421,7 +465,7 @@ class DatasetRetrieval:
user_id: str,
user_from: str,
available_datasets: list,
query: str | None,
top_k: int,
score_threshold: float,
reranking_mode: str,
@@ -431,10 +475,11 @@ class DatasetRetrieval:
message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None,
metadata_condition: MetadataCondition | None = None,
attachment_ids: list[str] | None = None,
):
if not available_datasets:
return []
all_threads = []
all_documents: list[Document] = []
dataset_ids = [dataset.id for dataset in available_datasets]
index_type_check = all(
@@ -467,102 +512,187 @@ class DatasetRetrieval:
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
with measure_time() as timer:
if query:
query_thread = threading.Thread(
target=self._multiple_retrieve_thread,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"available_datasets": available_datasets,
"metadata_condition": metadata_condition,
"metadata_filter_document_ids": metadata_filter_document_ids,
"all_documents": all_documents,
"tenant_id": tenant_id,
"reranking_enable": reranking_enable,
"reranking_mode": reranking_mode,
"reranking_model": reranking_model,
"weights": weights,
"top_k": top_k,
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
},
)
all_threads.append(query_thread)
query_thread.start()
if attachment_ids:
for attachment_id in attachment_ids:
attachment_thread = threading.Thread(
target=self._multiple_retrieve_thread,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"available_datasets": available_datasets,
"metadata_condition": metadata_condition,
"metadata_filter_document_ids": metadata_filter_document_ids,
"all_documents": all_documents,
"tenant_id": tenant_id,
"reranking_enable": reranking_enable,
"reranking_mode": reranking_mode,
"reranking_model": reranking_model,
"weights": weights,
"top_k": top_k,
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
},
)
all_threads.append(attachment_thread)
attachment_thread.start()
for thread in all_threads:
thread.join()
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
if all_documents:
# add thread to call _on_retrieval_end
retrieval_end_thread = threading.Thread(
target=self._on_retrieval_end,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"documents": all_documents,
"message_id": message_id,
"timer": timer,
},
)
retrieval_end_thread.start()
return all_documents
def _on_retrieval_end(
self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
):
"""Handle retrieval end."""
with flask_app.app_context():
dify_documents = [document for document in documents if document.provider == "dify"]
if not dify_documents:
self._send_trace_task(message_id, documents, timer)
return
with Session(db.engine) as session:
# Collect all document_ids and batch fetch DatasetDocuments
document_ids = {
doc.metadata["document_id"]
for doc in dify_documents
if doc.metadata and "document_id" in doc.metadata
}
if not document_ids:
self._send_trace_task(message_id, documents, timer)
return
dataset_docs_stmt = select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))
dataset_docs = session.scalars(dataset_docs_stmt).all()
dataset_doc_map = {str(doc.id): doc for doc in dataset_docs}
# Categorize documents by type and collect necessary IDs
parent_child_text_docs: list[tuple[Document, DatasetDocument]] = []
parent_child_image_docs: list[tuple[Document, DatasetDocument]] = []
normal_text_docs: list[tuple[Document, DatasetDocument]] = []
normal_image_docs: list[tuple[Document, DatasetDocument]] = []
for doc in dify_documents:
if not doc.metadata or "document_id" not in doc.metadata:
continue
dataset_doc = dataset_doc_map.get(doc.metadata["document_id"])
if not dataset_doc:
continue
is_image = doc.metadata.get("doc_type") == DocType.IMAGE
is_parent_child = dataset_doc.doc_form == IndexStructureType.PARENT_CHILD_INDEX
if is_parent_child:
if is_image:
parent_child_image_docs.append((doc, dataset_doc))
else:
parent_child_text_docs.append((doc, dataset_doc))
else:
if is_image:
normal_image_docs.append((doc, dataset_doc))
else:
normal_text_docs.append((doc, dataset_doc))
segment_ids_to_update: set[str] = set()
# Process PARENT_CHILD_INDEX text documents - batch fetch ChildChunks
if parent_child_text_docs:
index_node_ids = [doc.metadata["doc_id"] for doc, _ in parent_child_text_docs if doc.metadata]
if index_node_ids:
child_chunks_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(index_node_ids))
child_chunks = session.scalars(child_chunks_stmt).all()
child_chunk_map = {chunk.index_node_id: chunk.segment_id for chunk in child_chunks}
for doc, _ in parent_child_text_docs:
if doc.metadata:
segment_id = child_chunk_map.get(doc.metadata["doc_id"])
if segment_id:
segment_ids_to_update.add(str(segment_id))
# Process non-PARENT_CHILD_INDEX text documents - batch fetch DocumentSegments
if normal_text_docs:
index_node_ids = [doc.metadata["doc_id"] for doc, _ in normal_text_docs if doc.metadata]
if index_node_ids:
segments_stmt = select(DocumentSegment).where(DocumentSegment.index_node_id.in_(index_node_ids))
segments = session.scalars(segments_stmt).all()
segment_map = {seg.index_node_id: seg.id for seg in segments}
for doc, _ in normal_text_docs:
if doc.metadata:
segment_id = segment_map.get(doc.metadata["doc_id"])
if segment_id:
segment_ids_to_update.add(str(segment_id))
# Process IMAGE documents - batch fetch SegmentAttachmentBindings
all_image_docs = parent_child_image_docs + normal_image_docs
if all_image_docs:
attachment_ids = [
doc.metadata["doc_id"]
for doc, _ in all_image_docs
if doc.metadata and doc.metadata.get("doc_id")
]
if attachment_ids:
bindings_stmt = select(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.attachment_id.in_(attachment_ids)
)
bindings = session.scalars(bindings_stmt).all()
segment_ids_to_update.update(str(binding.segment_id) for binding in bindings)
# Batch update hit_count for all segments
if segment_ids_to_update:
session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
session.commit()
self._send_trace_task(message_id, documents, timer)
def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
"""Send trace task if trace manager is available."""
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
@@ -573,25 +703,40 @@ class DatasetRetrieval:
)
)
def _on_query(
self,
query: str | None,
attachment_ids: list[str] | None,
dataset_ids: list[str],
app_id: str,
user_from: str,
user_id: str,
):
"""
Handle query.
"""
if not query and not attachment_ids:
return
dataset_queries = []
for dataset_id in dataset_ids:
contents = []
if query:
contents.append({"content_type": QueryType.TEXT_QUERY, "content": query})
if attachment_ids:
for attachment_id in attachment_ids:
contents.append({"content_type": QueryType.IMAGE_QUERY, "content": attachment_id})
if contents:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=json.dumps(contents),
source="app",
source_app_id=app_id,
created_by_role=user_from,
created_by=user_id,
)
dataset_queries.append(dataset_query)
if dataset_queries:
db.session.add_all(dataset_queries)
db.session.commit()
def _retriever(
@@ -603,6 +748,7 @@ class DatasetRetrieval:
all_documents: list,
document_ids_filter: list[str] | None = None,
metadata_condition: MetadataCondition | None = None,
attachment_ids: list[str] | None = None,
):
with flask_app.app_context():
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
@@ -611,7 +757,7 @@ class DatasetRetrieval:
if not dataset:
return []
if dataset.provider == "external" and query:
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
@@ -663,6 +809,7 @@ class DatasetRetrieval:
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
attachment_ids=attachment_ids,
)
all_documents.extend(documents)
@@ -1222,3 +1369,86 @@ class DatasetRetrieval:
usage = LLMUsage.empty_usage()
return full_text, usage
def _multiple_retrieve_thread(
self,
flask_app: Flask,
available_datasets: list,
metadata_condition: MetadataCondition | None,
metadata_filter_document_ids: dict[str, list[str]] | None,
all_documents: list[Document],
tenant_id: str,
reranking_enable: bool,
reranking_mode: str,
reranking_model: dict | None,
weights: dict[str, Any] | None,
top_k: int,
score_threshold: float,
query: str | None,
attachment_id: str | None,
):
with flask_app.app_context():
threads = []
all_documents_item: list[Document] = []
index_type = None
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": flask_app,
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents_item,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
"attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
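The per-dataset fan-out above starts one thread per dataset, collects results into a shared list, and joins before post-processing. A minimal standalone sketch of that pattern (`retriever` here is an illustrative stand-in for `self._retriever`, not the real method):

```python
import threading

def retriever(dataset_id: str, results: list) -> None:
    # stand-in for the per-dataset retrieval call; appends into the shared list
    # (list.append is atomic under CPython's GIL, so no lock is needed here)
    results.append(f"doc-from-{dataset_id}")

results: list[str] = []
threads: list[threading.Thread] = []
for dataset_id in ["a", "b", "c"]:
    t = threading.Thread(target=retriever, kwargs={"dataset_id": dataset_id, "results": results})
    threads.append(t)
    t.start()
for t in threads:
    t.join()  # all retrievals complete before reranking/truncation

assert sorted(results) == ["doc-from-a", "doc-from-b", "doc-from-c"]
```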


@@ -0,0 +1,65 @@
{
"$id": "https://dify.ai/schemas/v1/multimodal_general_structure.json",
"$schema": "http://json-schema.org/draft-07/schema#",
"version": "1.0.0",
"type": "object",
"title": "Multimodal General Structure",
"description": "Schema for multimodal general structure (v1)",
"properties": {
"general_chunks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The content"
},
"files": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "file name"
},
"size": {
"type": "number",
"description": "file size"
},
"extension": {
"type": "string",
"description": "file extension"
},
"type": {
"type": "string",
"description": "file type"
},
"mime_type": {
"type": "string",
"description": "file mime type"
},
"transfer_method": {
"type": "string",
"description": "file transfer method"
},
"url": {
"type": "string",
"description": "file url"
},
"related_id": {
"type": "string",
"description": "file related id"
}
}
},
"description": "List of files"
}
},
"required": ["content"]
},
"description": "List of content and files"
}
}
}
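A quick stdlib check mirroring this schema's one required field per chunk (`validate_chunk` is an illustrative helper, not part of Dify; a full validator such as the third-party `jsonschema` package would enforce the whole schema):

```python
import json

def validate_chunk(chunk: dict) -> bool:
    # mirrors the schema: each general chunk requires a string "content";
    # "files" is optional
    return isinstance(chunk.get("content"), str)

doc = json.loads('{"general_chunks": [{"content": "hello", "files": []}]}')
assert all(validate_chunk(c) for c in doc["general_chunks"])
assert not validate_chunk({"files": []})  # missing required "content"
```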


@@ -0,0 +1,78 @@
{
"$id": "https://dify.ai/schemas/v1/multimodal_parent_child_structure.json",
"$schema": "http://json-schema.org/draft-07/schema#",
"version": "1.0.0",
"type": "object",
"title": "Multimodal Parent-Child Structure",
"description": "Schema for multimodal parent-child structure (v1)",
"properties": {
"parent_mode": {
"type": "string",
"description": "The mode of parent-child relationship"
},
"parent_child_chunks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"parent_content": {
"type": "string",
"description": "The parent content"
},
"files": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "file name"
},
"size": {
"type": "number",
"description": "file size"
},
"extension": {
"type": "string",
"description": "file extension"
},
"type": {
"type": "string",
"description": "file type"
},
"mime_type": {
"type": "string",
"description": "file mime type"
},
"transfer_method": {
"type": "string",
"description": "file transfer method"
},
"url": {
"type": "string",
"description": "file url"
},
"related_id": {
"type": "string",
"description": "file related id"
}
},
"required": ["name", "size", "extension", "type", "mime_type", "transfer_method", "url", "related_id"]
},
"description": "List of files"
},
"child_contents": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of child contents"
}
},
"required": ["parent_content", "child_contents"]
},
"description": "List of parent-child chunk pairs"
}
},
"required": ["parent_mode", "parent_child_chunks"]
}


@@ -25,6 +25,24 @@ def sign_tool_file(tool_file_id: str, extension: str) -> str:
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
def sign_upload_file(upload_file_id: str, extension: str) -> str:
"""
sign file to get a temporary url for plugin access
"""
# Use internal URL for plugin/tool file access in Docker environments
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
file_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
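The `sign_upload_file` helper above and its `verify_tool_file_signature` counterpart share one HMAC scheme. A self-contained sketch of that sign/verify pairing, with a stand-in `SECRET_KEY` in place of `dify_config.SECRET_KEY`:

```python
import base64
import hashlib
import hmac
import os
import time

SECRET_KEY = b"example-secret"  # stand-in for dify_config.SECRET_KEY

def sign(file_id: str) -> tuple[str, str, str]:
    timestamp = str(int(time.time()))
    nonce = os.urandom(16).hex()
    data = f"image-preview|{file_id}|{timestamp}|{nonce}"
    digest = hmac.new(SECRET_KEY, data.encode(), hashlib.sha256).digest()
    return timestamp, nonce, base64.urlsafe_b64encode(digest).decode()

def verify(file_id: str, timestamp: str, nonce: str, signature: str) -> bool:
    data = f"image-preview|{file_id}|{timestamp}|{nonce}"
    expected = hmac.new(SECRET_KEY, data.encode(), hashlib.sha256).digest()
    encoded = base64.urlsafe_b64encode(expected).decode()
    # constant-time comparison guards against timing attacks
    return hmac.compare_digest(encoded, signature)

ts, nonce, sig = sign("file-123")
assert verify("file-123", ts, nonce, sig)
assert not verify("file-456", ts, nonce, sig)  # wrong file id fails
```

A production verifier would additionally reject stale timestamps; that check is omitted here for brevity.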
def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature


@@ -13,5 +13,5 @@ def remove_leading_symbols(text: str) -> str:
"""
# Match Unicode ranges for punctuation and symbols
# FIXME: this pattern is a confusing quick fix for #11868; consider refactoring it later
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
return re.sub(pattern, "", text)
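The revised pattern above narrows the stripped ranges so that, for example, U+2026 (the horizontal ellipsis) is no longer removed, while ASCII dots and most CJK punctuation still are. A quick check of the new behavior:

```python
import re

# the pattern introduced by this change, copied verbatim
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'

def remove_leading_symbols(text: str) -> str:
    return re.sub(pattern, "", text)

assert remove_leading_symbols("...dots stripped") == "dots stripped"
assert remove_leading_symbols("…ellipsis kept") == "…ellipsis kept"  # U+2026 excluded from the ranges
assert remove_leading_symbols("】done") == "done"  # U+3011 still stripped
```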


@@ -221,7 +221,7 @@ class WorkflowToolProviderController(ToolProviderController):
session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
WorkflowToolProvider.id == self.provider_id,
)
.first()
)


@@ -3,6 +3,7 @@ from datetime import datetime
from pydantic import Field
from core.file import File
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.pause_reason import PauseReason
@@ -14,6 +15,7 @@ from .base import NodeEventBase
class RunRetrieverResourceEvent(NodeEventBase):
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
context_files: list[File] | None = Field(default=None, description="context files")
class ModelInvokeCompletedEvent(NodeEventBase):


@@ -59,7 +59,7 @@ class OutputVariableEntity(BaseModel):
"""
variable: str
value_type: OutputVariableType
value_type: OutputVariableType = OutputVariableType.ANY
value_selector: Sequence[str]
@field_validator("value_type", mode="before")


@@ -412,16 +412,20 @@ class Executor:
body_string += f"--{boundary}\r\n"
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
# decode content safely
try:
body_string += content.decode("utf-8")
except UnicodeDecodeError:
body_string += content.decode("utf-8", errors="replace")
body_string += "\r\n"
# Do not decode binary content; use a placeholder with file metadata instead.
# Includes filename, size, and MIME type for better logging context.
body_string += (
f"<file_content_binary: '{file_entry[1][0] or 'unknown'}', "
f"type='{file_entry[1][2] if len(file_entry[1]) > 2 else 'unknown'}', "
f"size={len(content)} bytes>\r\n"
)
body_string += f"--{boundary}--\r\n"
elif self.node_data.body:
if self.content:
# If content is bytes, do not decode it; show a placeholder with size.
# Provides content size information for binary data without exposing the raw bytes.
if isinstance(self.content, bytes):
body_string = self.content.decode("utf-8", errors="replace")
body_string = f"<binary_content: size={len(self.content)} bytes>"
else:
body_string = self.content
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":


@@ -114,7 +114,8 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
"""
type: str = "knowledge-retrieval"
query_variable_selector: list[str]
query_variable_selector: list[str] | None | str = None
query_attachment_selector: list[str] | None | str = None
dataset_ids: list[str]
retrieval_mode: Literal["single", "multiple"]
multiple_retrieval_config: MultipleRetrievalConfig | None = None


@@ -25,6 +25,8 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import (
ArrayFileSegment,
FileSegment,
StringSegment,
)
from core.variables.segments import ArrayObjectSegment
@@ -119,20 +121,41 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
return "1"
def _run(self) -> NodeRunResult:
# extract variables
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={},
error="Query variable is not string type.",
)
query = variable.value
variables = {"query": query}
if not query:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
process_data={},
outputs={},
metadata={},
llm_usage=LLMUsage.empty_usage(),
)
variables: dict[str, Any] = {}
# extract variables
if self._node_data.query_variable_selector:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error="Query variable is not string type.",
)
query = variable.value
variables["query"] = query
if self._node_data.query_attachment_selector:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_attachment_selector)
if not isinstance(variable, ArrayFileSegment) and not isinstance(variable, FileSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error="Attachments variable is not array file or file type.",
)
if isinstance(variable, ArrayFileSegment):
variables["attachments"] = variable.value
else:
variables["attachments"] = [variable.value]
# TODO(-LAN-): Move this check outside.
# check rate limit
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
@@ -161,7 +184,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
# retrieve knowledge
usage = LLMUsage.empty_usage()
try:
results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -198,12 +221,16 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
db.session.close()
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, query: str
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
available_datasets = []
dataset_ids = node_data.dataset_ids
query = variables.get("query")
attachments = variables.get("attachments")
metadata_filter_document_ids = None
metadata_condition = None
metadata_usage = LLMUsage.empty_usage()
# Subquery: Count the number of available documents for each dataset
subquery = (
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
@@ -234,13 +261,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if not dataset:
continue
available_datasets.append(dataset)
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
usage = self._merge_usage(usage, metadata_usage)
if query:
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
usage = self._merge_usage(usage, metadata_usage)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
# fetch model config
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
@@ -272,7 +300,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
@@ -319,6 +347,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
)
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
@@ -327,7 +356,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
source = {
source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = {
"metadata": {
"_source": "knowledge",
"dataset_id": item.metadata.get("dataset_id"),
@@ -384,6 +413,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
"doc_metadata": document.doc_metadata,
},
"title": document.name,
"files": list(record.files) if record.files else None,
}
if segment.answer:
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
@@ -393,13 +423,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
key=self._score, # type: ignore[arg-type, return-value]
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position
item["metadata"]["position"] = position # type: ignore[index]
return retrieval_resource_list, usage
def _score(self, item: dict[str, Any]) -> float:
meta = item.get("metadata")
if isinstance(meta, dict):
s = meta.get("score")
if isinstance(s, (int, float)):
return float(s)
return 0.0
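The new `_score` helper guards against missing metadata and non-numeric scores before sorting. A standalone sketch of the sort it feeds (written as a free function for illustration):

```python
from typing import Any

def score(item: dict[str, Any]) -> float:
    # defensive: tolerate missing metadata and non-numeric score values
    meta = item.get("metadata")
    if isinstance(meta, dict):
        s = meta.get("score")
        if isinstance(s, (int, float)):
            return float(s)
    return 0.0

resources = [
    {"metadata": {"score": 0.4}},
    {"metadata": {"score": None}},  # falls back to 0.0
    {"metadata": {"score": 0.9}},
    {},                             # no metadata at all
]
ranked = sorted(resources, key=score, reverse=True)
for position, item in enumerate(ranked, start=1):
    item.setdefault("metadata", {})["position"] = position

assert ranked[0]["metadata"]["score"] == 0.9
assert ranked[0]["metadata"]["position"] == 1
```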
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
@@ -659,7 +697,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
variable_mapping = {}
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
if typed_node_data.query_variable_selector:
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
if typed_node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
return variable_mapping
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:


@@ -7,8 +7,10 @@ import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import select
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
@@ -44,6 +46,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.variables import (
ArrayFileSegment,
ArraySegment,
@@ -72,6 +75,9 @@ from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from . import llm_utils
from .entities import (
@@ -179,12 +185,17 @@ class LLMNode(Node[LLMNodeData]):
# fetch context value
generator = self._fetch_context(node_data=self.node_data)
context = None
context_files: list[File] = []
for event in generator:
context = event.context
context_files = event.context_files or []
yield event
if context:
node_inputs["#context#"] = context
if context_files:
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
# fetch model config
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self.node_data.model,
@@ -220,6 +231,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
# handle invoke result
@@ -322,6 +334,7 @@ class LLMNode(Node[LLMNodeData]):
inputs=node_inputs,
process_data=process_data,
error_type=type(e).__name__,
llm_usage=usage,
)
)
except Exception as e:
@@ -332,6 +345,8 @@ class LLMNode(Node[LLMNodeData]):
error=str(e),
inputs=node_inputs,
process_data=process_data,
error_type=type(e).__name__,
llm_usage=usage,
)
)
@@ -654,10 +669,13 @@ class LLMNode(Node[LLMNodeData]):
context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
if context_value_variable:
if isinstance(context_value_variable, StringSegment):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
yield RunRetrieverResourceEvent(
retriever_resources=[], context=context_value_variable.value, context_files=[]
)
elif isinstance(context_value_variable, ArraySegment):
context_str = ""
original_retriever_resource: list[RetrievalSourceMetadata] = []
context_files: list[File] = []
for item in context_value_variable.value:
if isinstance(item, str):
context_str += item + "\n"
@@ -670,9 +688,34 @@ class LLMNode(Node[LLMNodeData]):
retriever_resource = self._convert_to_original_retriever_resource(item)
if retriever_resource:
original_retriever_resource.append(retriever_resource)
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == retriever_resource.segment_id,
)
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
context_files.append(attachment_info)
yield RunRetrieverResourceEvent(
retriever_resources=original_retriever_resource, context=context_str.strip()
retriever_resources=original_retriever_resource,
context=context_str.strip(),
context_files=context_files,
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None:
@@ -700,6 +743,7 @@ class LLMNode(Node[LLMNodeData]):
content=context_dict.get("content"),
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
files=context_dict.get("files"),
)
return source
@@ -741,6 +785,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
context_files: list["File"] | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
@@ -853,6 +898,23 @@ class LLMNode(Node[LLMNodeData]):
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# The context_files
if vision_enabled and context_files:
file_prompts = []
for file in context_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Remove empty messages and filter unsupported content
filtered_prompt_messages = []
for prompt_message in prompt_messages:


@@ -221,6 +221,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,


@@ -97,11 +97,27 @@ dataset_detail_fields = {
"total_documents": fields.Integer,
"total_available_documents": fields.Integer,
"enable_api": fields.Boolean,
"is_multimodal": fields.Boolean,
}
file_info_fields = {
"id": fields.String,
"name": fields.String,
"size": fields.Integer,
"extension": fields.String,
"mime_type": fields.String,
"source_url": fields.String,
}
content_fields = {
"content_type": fields.String,
"content": fields.String,
"file_info": fields.Nested(file_info_fields, allow_null=True),
}
dataset_query_detail_fields = {
"id": fields.String,
"content": fields.String,
"queries": fields.Nested(content_fields),
"source": fields.String,
"source_app_id": fields.String,
"created_by_role": fields.String,


@@ -9,6 +9,8 @@ upload_config_fields = {
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"workflow_file_upload_limit": fields.Integer,
"image_file_batch_limit": fields.Integer,
"single_chunk_attachment_limit": fields.Integer,
}


@@ -43,9 +43,19 @@ child_chunk_fields = {
"score": fields.Float,
}
files_fields = {
"id": fields.String,
"name": fields.String,
"size": fields.Integer,
"extension": fields.String,
"mime_type": fields.String,
"source_url": fields.String,
}
hit_testing_record_fields = {
"segment": fields.Nested(segment_fields),
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"score": fields.Float,
"tsne_position": fields.Raw,
"files": fields.List(fields.Nested(files_fields)),
}


@@ -13,6 +13,15 @@ child_chunk_fields = {
"updated_at": TimestampField,
}
attachment_fields = {
"id": fields.String,
"name": fields.String,
"size": fields.Integer,
"extension": fields.String,
"mime_type": fields.String,
"source_url": fields.String,
}
segment_fields = {
"id": fields.String,
"position": fields.Integer,
@@ -39,4 +48,5 @@ segment_fields = {
"error": fields.String,
"stopped_at": TimestampField,
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"attachments": fields.List(fields.Nested(attachment_fields)),
}


@@ -107,7 +107,7 @@ def email(email):
EmailStr = Annotated[str, AfterValidator(email)]
def uuid_value(value):
def uuid_value(value: Any) -> str:
if value == "":
return str(value)
@@ -215,7 +215,11 @@ def generate_text_hash(text: str) -> str:
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json")
return Response(
response=json.dumps(jsonable_encoder(response)),
status=200,
content_type="application/json; charset=utf-8",
)
else:
def generate() -> Generator:
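The explicit `charset=utf-8` added above matters most once a body contains non-ASCII bytes. With the stdlib `json` module, the default `ensure_ascii=True` escapes non-ASCII into `\uXXXX` sequences, while `ensure_ascii=False` emits raw UTF-8 that clients can only decode correctly if the charset is declared:

```python
import json

payload = {"answer": "日本語"}

ascii_body = json.dumps(payload)                      # default: escape non-ASCII
utf8_body = json.dumps(payload, ensure_ascii=False)   # raw UTF-8 in the body

assert ascii_body == '{"answer": "\\u65e5\\u672c\\u8a9e"}'
assert utf8_body == '{"answer": "日本語"}'
# the declared charset tells the client how to decode these bytes
assert utf8_body.encode("utf-8").decode("utf-8") == utf8_body
```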


@@ -0,0 +1,57 @@
"""support-multi-modal
Revision ID: d57accd375ae
Revises: 03f8dcbc611e
Create Date: 2025-11-12 15:37:12.363670
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd57accd375ae'
down_revision = '7bb281b7a422'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('segment_attachment_bindings',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('segment_id', models.types.StringUUID(), nullable=False),
sa.Column('attachment_id', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.PrimaryKeyConstraint('id', name='segment_attachment_binding_pkey')
)
with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op:
batch_op.create_index(
'segment_attachment_binding_tenant_dataset_document_segment_idx',
['tenant_id', 'dataset_id', 'document_id', 'segment_id'],
unique=False
)
batch_op.create_index('segment_attachment_binding_attachment_idx', ['attachment_id'], unique=False)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('is_multimodal', sa.Boolean(), server_default=sa.text('false'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_column('is_multimodal')
with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op:
batch_op.drop_index('segment_attachment_binding_attachment_idx')
batch_op.drop_index('segment_attachment_binding_tenant_dataset_document_segment_idx')
op.drop_table('segment_attachment_bindings')
# ### end Alembic commands ###


@@ -1,4 +1,4 @@
"""empty message
"""mysql adaptation
Revision ID: 09cfdda155d1
Revises: 669ffd70119c
@@ -97,11 +97,31 @@ def downgrade():
batch_op.alter_column('include_plugins',
existing_type=sa.JSON(),
type_=postgresql.ARRAY(sa.VARCHAR(length=255)),
existing_nullable=False)
existing_nullable=False,
postgresql_using="""
COALESCE(
regexp_replace(
replace(replace(include_plugins::text, '[', '{'), ']', '}'),
'"',
'',
'g'
)::varchar(255)[],
ARRAY[]::varchar(255)[]
)""")
batch_op.alter_column('exclude_plugins',
existing_type=sa.JSON(),
type_=postgresql.ARRAY(sa.VARCHAR(length=255)),
existing_nullable=False)
existing_nullable=False,
postgresql_using="""
COALESCE(
regexp_replace(
replace(replace(exclude_plugins::text, '[', '{'), ']', '}'),
'"',
'',
'g'
)::varchar(255)[],
ARRAY[]::varchar(255)[]
)""")
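The `postgresql_using` expression above rewrites a JSON array literal (`["a", "b"]`) into a Postgres array literal (`{a, b}`) by swapping brackets for braces and stripping double quotes. The same text transformation sketched in Python for clarity (`json_list_to_pg_array_literal` is an illustrative name):

```python
import re

def json_list_to_pg_array_literal(json_text: str) -> str:
    # mirrors the SQL: replace [ ] with { }, then remove double quotes
    braced = json_text.replace("[", "{").replace("]", "}")
    return re.sub('"', "", braced)

assert json_list_to_pg_array_literal('["a", "b"]') == "{a, b}"
assert json_list_to_pg_array_literal("[]") == "{}"
```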
with op.batch_alter_table('external_knowledge_bindings', schema=None) as batch_op:
batch_op.alter_column('external_knowledge_id',


@@ -19,7 +19,9 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file
from extensions.ext_storage import storage
from libs.uuid_utils import uuidv7
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@@ -76,6 +78,7 @@ class Dataset(Base):
pipeline_id = mapped_column(StringUUID, nullable=True)
chunk_structure = mapped_column(sa.String(255), nullable=True)
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
is_multimodal = mapped_column(sa.Boolean, default=False, nullable=False, server_default=db.text("false"))
@property
def total_documents(self):
@@ -728,9 +731,7 @@ class DocumentSegment(Base):
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
error = mapped_column(LongText, nullable=True)
@@ -866,6 +867,47 @@ class DocumentSegment(Base):
return text
@property
def attachments(self) -> list[dict[str, Any]]:
# Use JOIN to fetch attachments in a single query instead of two separate queries
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.tenant_id == self.tenant_id,
SegmentAttachmentBinding.dataset_id == self.dataset_id,
SegmentAttachmentBinding.document_id == self.document_id,
SegmentAttachmentBinding.segment_id == self.id,
)
).all()
if not attachments_with_bindings:
return []
attachment_list = []
for _, attachment in attachments_with_bindings:
upload_file_id = attachment.id
nonce = os.urandom(16).hex()
timestamp = str(int(time.time()))
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
reference_url = dify_config.CONSOLE_API_URL or ""
base_url = f"{reference_url}/files/{upload_file_id}/image-preview"
source_url = f"{base_url}?{params}"
attachment_list.append(
{
"id": attachment.id,
"name": attachment.name,
"size": attachment.size,
"extension": attachment.extension,
"mime_type": attachment.mime_type,
"source_url": source_url,
}
)
return attachment_list
class ChildChunk(Base):
__tablename__ = "child_chunks"
@@ -963,6 +1005,38 @@ class DatasetQuery(TypeBase):
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
)
@property
def queries(self) -> list[dict[str, Any]]:
try:
queries = json.loads(self.content)
if isinstance(queries, list):
for query in queries:
if query["content_type"] == QueryType.IMAGE_QUERY:
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
if file_info:
query["file_info"] = {
"id": file_info.id,
"name": file_info.name,
"size": file_info.size,
"extension": file_info.extension,
"mime_type": file_info.mime_type,
"source_url": sign_upload_file(file_info.id, file_info.extension),
}
else:
query["file_info"] = None
return queries
else:
return [queries]
except JSONDecodeError:
return [
{
"content_type": QueryType.TEXT_QUERY,
"content": self.content,
"file_info": None,
}
]
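The `queries` property's JSON-or-plain-text fallback can be sketched without the ORM lookup (`parse_queries` and the literal `"text_query"` string are stand-ins for the model method and the `QueryType` enum value):

```python
import json
from typing import Any

def parse_queries(content: str) -> list[dict[str, Any]]:
    # mirror the fallback: JSON list -> as-is, JSON scalar -> wrapped,
    # invalid JSON -> treat the raw string as a single text query
    try:
        queries = json.loads(content)
    except json.JSONDecodeError:
        return [{"content_type": "text_query", "content": content, "file_info": None}]
    return queries if isinstance(queries, list) else [queries]

assert parse_queries('[{"content_type": "text_query", "content": "hi"}]')[0]["content"] == "hi"
assert parse_queries("plain question")[0]["content_type"] == "text_query"
```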
class DatasetKeywordTable(TypeBase):
__tablename__ = "dataset_keyword_tables"
@@ -1470,3 +1544,25 @@ class PipelineRecommendedPlugin(TypeBase):
onupdate=func.current_timestamp(),
init=False,
)
class SegmentAttachmentBinding(Base):
__tablename__ = "segment_attachment_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"),
sa.Index(
"segment_attachment_binding_tenant_dataset_document_segment_idx",
"tenant_id",
"dataset_id",
"document_id",
"segment_id",
),
sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())


@@ -111,7 +111,11 @@ class App(Base):
else:
app_model_config = self.app_model_config
if app_model_config:
return app_model_config.pre_prompt
pre_prompt = app_model_config.pre_prompt or ""
# Truncate to 200 characters with ellipsis if using prompt as description
if len(pre_prompt) > 200:
return pre_prompt[:200] + "..."
return pre_prompt
else:
return ""
@@ -259,7 +263,7 @@ class App(Base):
provider_id = tool.get("provider_id", "")
if provider_type == ToolProviderType.API:
if uuid.UUID(provider_id) not in existing_api_providers:
if provider_id not in existing_api_providers:
deleted_tools.append(
{
"type": ToolProviderType.API,
@@ -835,7 +839,29 @@ class Conversation(Base):
@property
def status_count(self):
messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all()
from models.workflow import WorkflowRun
# Get all messages with workflow_run_id for this conversation
messages = db.session.scalars(
select(Message).where(Message.conversation_id == self.id, Message.workflow_run_id.isnot(None))
).all()
if not messages:
return None
# Batch load all workflow runs in a single query, filtered by this conversation's app_id
workflow_run_ids = [msg.workflow_run_id for msg in messages if msg.workflow_run_id]
workflow_runs = {}
if workflow_run_ids:
workflow_runs_query = db.session.scalars(
select(WorkflowRun).where(
WorkflowRun.id.in_(workflow_run_ids),
WorkflowRun.app_id == self.app_id, # Filter by this conversation's app_id
)
).all()
workflow_runs = {run.id: run for run in workflow_runs_query}
status_counts = {
WorkflowExecutionStatus.RUNNING: 0,
WorkflowExecutionStatus.SUCCEEDED: 0,
@@ -845,18 +871,24 @@ class Conversation(Base):
}
for message in messages:
if message.workflow_run:
status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1
# Guard against None to satisfy type checker and avoid invalid dict lookups
if message.workflow_run_id is None:
continue
workflow_run = workflow_runs.get(message.workflow_run_id)
if not workflow_run:
continue
return (
{
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
"failed": status_counts[WorkflowExecutionStatus.FAILED],
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
}
if messages
else None
)
try:
status_counts[WorkflowExecutionStatus(workflow_run.status)] += 1
except (ValueError, KeyError):
# Handle invalid status values gracefully
pass
return {
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
"failed": status_counts[WorkflowExecutionStatus.FAILED],
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
}
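Stripped of the batched `WorkflowRun` loading, the tallying reduces to counting known statuses while silently skipping unknown ones, matching the try/except guard above. A sketch using plain strings in place of the `WorkflowExecutionStatus` enum; the status values shown are assumptions:

```python
from collections import Counter

def tally_run_statuses(statuses: list[str]) -> dict[str, int]:
    # Unknown status strings are ignored rather than raising.
    counts = Counter(s for s in statuses if s in {"succeeded", "failed", "partial-succeeded"})
    return {
        "success": counts["succeeded"],
        "failed": counts["failed"],
        "partial_success": counts["partial-succeeded"],
    }
```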
@property
def first_message(self):
@@ -1255,13 +1287,9 @@ class Message(Base):
"id": self.id,
"app_id": self.app_id,
"conversation_id": self.conversation_id,
"model_provider": self.model_provider,
"model_id": self.model_id,
"inputs": self.inputs,
"query": self.query,
"message_tokens": self.message_tokens,
"answer_tokens": self.answer_tokens,
"provider_response_latency": self.provider_response_latency,
"total_price": self.total_price,
"message": self.message,
"answer": self.answer,
@@ -1283,12 +1311,8 @@ class Message(Base):
id=data["id"],
app_id=data["app_id"],
conversation_id=data["conversation_id"],
model_provider=data.get("model_provider"),
model_id=data["model_id"],
inputs=data["inputs"],
message_tokens=data.get("message_tokens", 0),
answer_tokens=data.get("answer_tokens", 0),
provider_response_latency=data.get("provider_response_latency", 0.0),
total_price=data["total_price"],
query=data["query"],
message=data["message"],


@@ -907,19 +907,29 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@property
def extras(self) -> dict[str, Any]:
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager
extras: dict[str, Any] = {}
if self.execution_metadata_dict:
if self.node_type == NodeType.TOOL and "tool_info" in self.execution_metadata_dict:
tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"]
execution_metadata = self.execution_metadata_dict
if execution_metadata:
if self.node_type == NodeType.TOOL and "tool_info" in execution_metadata:
tool_info: dict[str, Any] = execution_metadata["tool_info"]
extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self.tenant_id,
provider_type=tool_info["provider_type"],
provider_id=tool_info["provider_id"],
)
elif self.node_type == NodeType.DATASOURCE and "datasource_info" in self.execution_metadata_dict:
datasource_info = self.execution_metadata_dict["datasource_info"]
elif self.node_type == NodeType.DATASOURCE and "datasource_info" in execution_metadata:
datasource_info = execution_metadata["datasource_info"]
extras["icon"] = datasource_info.get("icon")
elif self.node_type == NodeType.TRIGGER_PLUGIN and "trigger_info" in execution_metadata:
trigger_info = execution_metadata["trigger_info"] or {}
provider_id = trigger_info.get("provider_id")
if provider_id:
extras["icon"] = TriggerManager.get_trigger_plugin_icon(
tenant_id=self.tenant_id,
provider_id=provider_id,
)
return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:


@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.10.1"
version = "1.11.1"
requires-python = ">=3.11,<3.13"
dependencies = [
@@ -151,7 +151,7 @@ dev = [
"types-pywin32~=310.0.0",
"types-pyyaml~=6.0.12",
"types-regex~=2024.11.6",
"types-shapely~=2.0.0",
"types-shapely~=2.1.0",
"types-simplejson>=3.20.0",
"types-six>=1.17.0",
"types-tensorflow>=2.18.0",


@@ -0,0 +1,31 @@
import base64
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from extensions.ext_storage import storage
from models.model import UploadFile
PREVIEW_WORDS_LIMIT = 3000
class AttachmentService:
_session_maker: sessionmaker
def __init__(self, session_factory: sessionmaker | Engine | None = None):
if isinstance(session_factory, Engine):
self._session_maker = sessionmaker(bind=session_factory)
elif isinstance(session_factory, sessionmaker):
self._session_maker = session_factory
else:
raise AssertionError("must be a sessionmaker or an Engine.")
def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
)
if not upload_file:
raise NotFound("File not found")
blob = storage.load_once(upload_file.key)
return base64.b64encode(blob).decode()
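The encode step above is the standard bytes-to-ASCII round trip; sketched here together with its inverse for verification:

```python
import base64

def file_to_base64(blob: bytes) -> str:
    # b64encode returns bytes; decode() yields the ASCII string callers expect.
    return base64.b64encode(blob).decode()

def base64_to_file(data: str) -> bytes:
    # Inverse, useful for round-trip checks.
    return base64.b64decode(data)
```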


@@ -118,7 +118,7 @@ class ConversationService:
app_model: App,
conversation_id: str,
user: Union[Account, EndUser] | None,
name: str,
name: str | None,
auto_generate: bool,
):
conversation = cls.get_conversation(app_model, conversation_id, user)


@@ -7,7 +7,7 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal
from typing import Any, Literal, cast
import sqlalchemy as sa
from redis.exceptions import LockNotOwnedError
@@ -19,9 +19,10 @@ from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from enums.cloud_plan import CloudPlan
from events.dataset_event import dataset_was_deleted
@@ -46,6 +47,7 @@ from models.dataset import (
DocumentSegment,
ExternalKnowledgeBindings,
Pipeline,
SegmentAttachmentBinding,
)
from models.model import UploadFile
from models.provider_ids import ModelProviderID
@@ -363,6 +365,27 @@ class DatasetService:
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
@staticmethod
def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
try:
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=model,
)
text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance)
model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials)
if not model_schema:
raise ValueError("Model schema not found")
if model_schema.features and ModelFeature.VISION in model_schema.features:
return True
else:
return False
except LLMBadRequestError:
raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
@staticmethod
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
try:
@@ -402,13 +425,13 @@ class DatasetService:
if not dataset:
raise ValueError("Dataset not found")
# check if dataset name is exists
if DatasetService._has_dataset_same_name(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
name=data.get("name", dataset.name),
):
raise ValueError("Dataset name already exists")
if data.get("name") and data.get("name") != dataset.name:
if DatasetService._has_dataset_same_name(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
name=data.get("name", dataset.name),
):
raise ValueError("Dataset name already exists")
# Verify user has permission to update this dataset
DatasetService.check_dataset_permission(dataset, user)
@@ -650,6 +673,8 @@ class DatasetService:
Returns:
str: Action to perform ('add', 'remove', 'update', or None)
"""
if "indexing_technique" not in data:
return None
if dataset.indexing_technique != data["indexing_technique"]:
if data["indexing_technique"] == "economy":
# Remove embedding model configuration for economy mode
@@ -844,6 +869,12 @@ class DatasetService:
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_configuration.embedding_model or "",
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
@@ -880,6 +911,12 @@ class DatasetService:
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
dataset.collection_binding_id = dataset_collection_binding.id
dataset.indexing_technique = knowledge_configuration.indexing_technique
except LLMBadRequestError:
@@ -937,6 +974,12 @@ class DatasetService:
)
)
dataset.collection_binding_id = dataset_collection_binding.id
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
@@ -1593,6 +1636,20 @@ class DocumentService:
return [], ""
db.session.add(dataset_process_rule)
db.session.flush()
else:
# Fallback when no process_rule provided in knowledge_config:
# 1) reuse dataset.latest_process_rule if present
# 2) otherwise create an automatic rule
dataset_process_rule = getattr(dataset, "latest_process_rule", None)
if not dataset_process_rule:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode="automatic",
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
db.session.add(dataset_process_rule)
db.session.flush()
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
try:
with redis_client.lock(lock_name, timeout=600):
@@ -1604,65 +1661,67 @@ class DocumentService:
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
files = (
db.session.query(UploadFile)
.where(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id.in_(upload_file_list),
)
.all()
)
if len(files) != len(set(upload_file_list)):
raise FileNotExistsError("One or more files not found.")
# raise error if file not found
if not file:
raise FileNotExistsError()
file_name = file.name
file_names = [file.name for file in files]
db_documents = (
db.session.query(Document)
.where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == "upload_file",
Document.enabled == True,
Document.name.in_(file_names),
)
.all()
)
documents_map = {document.name: document for document in db_documents}
for file in files:
data_source_info: dict[str, str | bool] = {
"upload_file_id": file_id,
"upload_file_id": file.id,
}
# check duplicate
if knowledge_config.duplicate:
document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
)
.first()
document = documents_map.get(file.name)
if knowledge_config.duplicate and document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
else:
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file.name,
batch,
)
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
@@ -2305,6 +2364,7 @@ class DocumentService:
embedding_model_provider=knowledge_config.embedding_model_provider,
collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
is_multimodal=knowledge_config.is_multimodal,
)
db.session.add(dataset)
@@ -2685,6 +2745,13 @@ class SegmentService:
if "content" not in args or not args["content"] or not args["content"].strip():
raise ValueError("Content is empty")
if args.get("attachment_ids"):
if not isinstance(args["attachment_ids"], list):
raise ValueError("Attachment IDs are invalid")
single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT
if len(args["attachment_ids"]) > single_chunk_attachment_limit:
raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}")
@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
@@ -2731,11 +2798,23 @@ class SegmentService:
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.commit()
if args["attachment_ids"]:
for attachment_id in args["attachment_ids"]:
binding = SegmentAttachmentBinding(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
segment_id=segment_document.id,
attachment_id=attachment_id,
)
db.session.add(binding)
db.session.commit()
# save vector index
@@ -2899,7 +2978,7 @@ class SegmentService:
document.word_count = max(0, document.word_count + word_count_change)
db.session.add(document)
# update segment index task
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# regenerate child chunks
# get embedding model instance
if dataset.indexing_technique == "high_quality":
@@ -2926,12 +3005,11 @@ class SegmentService:
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("No processing rule found.")
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
if args.enabled or keyword_changed:
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
@@ -2976,7 +3054,7 @@ class SegmentService:
db.session.add(document)
db.session.add(segment)
db.session.commit()
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# get embedding model instance
if dataset.indexing_technique == "high_quality":
# check embedding model setting
@@ -3002,15 +3080,15 @@ class SegmentService:
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("No processing rule found.")
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
# update multimodel vector index
VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
except Exception as e:
logger.exception("update segment index failed")
segment.enabled = False
@@ -3048,7 +3126,9 @@ class SegmentService:
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
delete_segment_from_index_task.delay(
[segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
)
db.session.delete(segment)
# update document word count
@@ -3097,7 +3177,9 @@ class SegmentService:
# Start async cleanup with both parent and child node IDs
if index_node_ids or child_node_ids:
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
delete_segment_from_index_task.delay(
index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids
)
if document.word_count is None:
document.word_count = 0
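The switch above from per-file queries to a single `IN` query hinges on comparing the fetched count against the requested id set and building a name map for the duplicate check. A database-free sketch, with dicts standing in for `UploadFile` rows:

```python
def index_fetched_files(requested_ids: list[str], fetched: list[dict]) -> dict[str, dict]:
    # One batched fetch replaces N queries; a count mismatch means missing files.
    if len(fetched) != len(set(requested_ids)):
        raise ValueError("One or more files not found.")
    # Name-keyed map enables O(1) duplicate-document lookups afterwards.
    return {f["name"]: f for f in fetched}
```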


@@ -29,8 +29,14 @@ def get_current_user():
from models.account import Account
from models.model import EndUser
if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore
raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}")
try:
user_object = current_user._get_current_object()
except AttributeError:
# Handle case where current_user might not be a LocalProxy in test environments
user_object = current_user
if not isinstance(user_object, (Account, EndUser)):
raise TypeError(f"current_user must be Account or EndUser, got {type(user_object).__name__}")
return current_user
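The guard above boils down to "unwrap if it looks like a LocalProxy, else use the object directly". A self-contained sketch with a stand-in proxy class (werkzeug's real `LocalProxy` exposes the same `_get_current_object()` method):

```python
class FakeProxy:
    # Stand-in for werkzeug's LocalProxy.
    def __init__(self, target):
        self._target = target

    def _get_current_object(self):
        return self._target

def unwrap(obj):
    # Plain objects (e.g. in test environments) lack the method;
    # fall back to the object itself.
    try:
        return obj._get_current_object()
    except AttributeError:
        return obj
```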


@@ -124,6 +124,14 @@ class KnowledgeConfig(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
name: str | None = None
is_multimodal: bool = False
class SegmentCreateArgs(BaseModel):
content: str | None = None
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None
class SegmentUpdateArgs(BaseModel):
@@ -132,6 +140,7 @@ class SegmentUpdateArgs(BaseModel):
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
enabled: bool | None = None
attachment_ids: list[str] | None = None
class ChildChunkUpdateArgs(BaseModel):


@@ -1,3 +1,4 @@
import base64
import hashlib
import os
import uuid
@@ -123,6 +124,15 @@ class FileService:
return file_size <= file_size_limit
def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
)
if not upload_file:
raise NotFound("File not found")
blob = storage.load_once(upload_file.key)
return base64.b64encode(blob).decode()
def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
if len(text_name) > 200:
text_name = text_name[:200]


@@ -1,3 +1,4 @@
import json
import logging
import time
from typing import Any
@@ -5,6 +6,7 @@ from typing import Any
from core.app.app_config.entities import ModelConfig
from core.model_runtime.entities import LLMMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -32,6 +34,7 @@ class HitTestingService:
account: Account,
retrieval_model: Any, # FIXME drop this any
external_retrieval_model: dict,
attachment_ids: list | None = None,
limit: int = 10,
):
start = time.perf_counter()
@@ -41,7 +44,7 @@ class HitTestingService:
retrieval_model = dataset.retrieval_model or default_retrieval_model
document_ids_filter = None
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
if metadata_filtering_conditions:
if metadata_filtering_conditions and query:
dataset_retrieval = DatasetRetrieval()
from core.app.app_config.entities import MetadataFilteringCondition
@@ -66,6 +69,7 @@ class HitTestingService:
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
dataset_id=dataset.id,
query=query,
attachment_ids=attachment_ids,
top_k=retrieval_model.get("top_k", 4),
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
@@ -80,17 +84,24 @@ class HitTestingService:
end = time.perf_counter()
logger.debug("Hit testing retrieve in %s seconds", end - start)
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=query,
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)
db.session.add(dataset_query)
dataset_queries = []
if query:
content = {"content_type": QueryType.TEXT_QUERY, "content": query}
dataset_queries.append(content)
if attachment_ids:
for attachment_id in attachment_ids:
content = {"content_type": QueryType.IMAGE_QUERY, "content": attachment_id}
dataset_queries.append(content)
if dataset_queries:
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=json.dumps(dataset_queries),
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)
db.session.add(dataset_query)
db.session.commit()
return cls.compact_retrieve_response(query, all_documents)
@@ -167,10 +178,15 @@ class HitTestingService:
@classmethod
def hit_testing_args_check(cls, args):
query = args["query"]
query = args.get("query")
attachment_ids = args.get("attachment_ids")
if not query or len(query) > 250:
raise ValueError("Query is required and cannot exceed 250 characters")
if not attachment_ids and not query:
raise ValueError("Query or attachment_ids is required")
if query and len(query) > 250:
raise ValueError("Query cannot exceed 250 characters")
if attachment_ids and not isinstance(attachment_ids, list):
raise ValueError("Attachment_ids must be a list")
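The relaxed validation above accepts either input, so its three failure modes are easiest to see in a standalone sketch of the same checks:

```python
def check_hit_testing_args(args: dict) -> None:
    query = args.get("query")
    attachment_ids = args.get("attachment_ids")
    # At least one of the two inputs must be present.
    if not attachment_ids and not query:
        raise ValueError("Query or attachment_ids is required")
    if query and len(query) > 250:
        raise ValueError("Query cannot exceed 250 characters")
    if attachment_ids and not isinstance(attachment_ids, list):
        raise ValueError("Attachment_ids must be a list")
```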
@staticmethod
def escape_query_for_search(query: str) -> str:


@@ -70,9 +70,28 @@ class ModelProviderService:
continue
provider_config = provider_configuration.custom_configuration.provider
model_config = provider_configuration.custom_configuration.models
models = provider_configuration.custom_configuration.models
can_added_models = provider_configuration.custom_configuration.can_added_models
# IMPORTANT: Never expose decrypted credentials in the provider list API.
# Sanitize custom model configurations by dropping the credentials payload.
sanitized_model_config = []
if models:
from core.entities.provider_entities import CustomModelConfiguration # local import to avoid cycles
for model in models:
sanitized_model_config.append(
CustomModelConfiguration(
model=model.model,
model_type=model.model_type,
credentials=None, # strip secrets from list view
current_credential_id=model.current_credential_id,
current_credential_name=model.current_credential_name,
available_model_credentials=model.available_model_credentials,
unadded_to_model_list=model.unadded_to_model_list,
)
)
provider_response = ProviderResponse(
tenant_id=tenant_id,
provider=provider_configuration.provider.provider,
@@ -95,7 +114,7 @@ class ModelProviderService:
current_credential_id=getattr(provider_config, "current_credential_id", None),
current_credential_name=getattr(provider_config, "current_credential_name", None),
available_credentials=getattr(provider_config, "available_credentials", []),
custom_models=model_config,
custom_models=sanitized_model_config,
can_added_models=can_added_models,
),
system_configuration=SystemConfigurationResponse(
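The sanitization rule above is simple: reproduce every model entry but null out the credentials payload before it reaches the list API. A sketch with dicts standing in for `CustomModelConfiguration`:

```python
def sanitize_custom_models(models: list[dict]) -> list[dict]:
    # Shallow-copy each entry and drop the decrypted credentials field;
    # the originals are left untouched.
    return [{**model, "credentials": None} for model in models]
```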


@@ -4,11 +4,14 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode
@@ -21,9 +24,10 @@ class VectorService:
cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
):
documents: list[Document] = []
multimodal_documents: list[AttachmentDocument] = []
for segment in segments:
if doc_form == IndexType.PARENT_CHILD_INDEX:
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
if not dataset_document:
logger.warning(
@@ -70,12 +74,29 @@ class VectorService:
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.TEXT,
},
)
documents.append(rag_document)
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_document: AttachmentDocument = AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
multimodal_documents.append(multimodal_document)
index_processor: BaseIndexProcessor = IndexProcessorFactory(doc_form).init_index_processor()
if len(documents) > 0:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
index_processor.load(dataset, documents, None, with_keywords=True, keywords_list=keywords_list)
if len(multimodal_documents) > 0:
index_processor.load(dataset, [], multimodal_documents, with_keywords=False)
@classmethod
def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
@@ -130,6 +151,7 @@ class VectorService:
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.TEXT,
},
)
# use full doc mode to generate segment's child chunk
@@ -226,3 +248,92 @@ class VectorService:
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
vector = Vector(dataset=dataset)
vector.delete_by_ids([child_chunk.index_node_id])
@classmethod
def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset):
if dataset.indexing_technique != "high_quality":
return
attachments = segment.attachments
old_attachment_ids = [attachment["id"] for attachment in attachments] if attachments else []
# Check if there's any actual change needed
if set(attachment_ids) == set(old_attachment_ids):
return
try:
vector = Vector(dataset=dataset)
if dataset.is_multimodal:
# Delete old vectors if they exist
if old_attachment_ids:
vector.delete_by_ids(old_attachment_ids)
# Delete existing segment attachment bindings in one operation
db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete(
synchronize_session=False
)
if not attachment_ids:
db.session.commit()
return
# Bulk fetch upload files - only fetch needed fields
upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
if not upload_file_list:
db.session.commit()
return
# Create a mapping for quick lookup
upload_file_map = {upload_file.id: upload_file for upload_file in upload_file_list}
# Prepare batch operations
bindings = []
documents = []
# Create common metadata base to avoid repetition
base_metadata = {
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
}
# Process attachments in the order specified by attachment_ids
for attachment_id in attachment_ids:
upload_file = upload_file_map.get(attachment_id)
if not upload_file:
logger.warning("Upload file not found for attachment_id: %s", attachment_id)
continue
# Create segment attachment binding
bindings.append(
SegmentAttachmentBinding(
tenant_id=segment.tenant_id,
dataset_id=segment.dataset_id,
document_id=segment.document_id,
segment_id=segment.id,
attachment_id=upload_file.id,
)
)
# Create document for vector indexing
documents.append(
Document(page_content=upload_file.name, metadata={**base_metadata, "doc_id": upload_file.id})
)
# Bulk insert all bindings at once
if bindings:
db.session.add_all(bindings)
# Add documents to vector store if any
if documents and dataset.is_multimodal:
vector.add_texts(documents, duplicate_check=True)
# Single commit for all operations
db.session.commit()
except Exception:
logger.exception("Failed to update multimodal vector for segment %s", segment.id)
db.session.rollback()
raise

View File

@@ -4,9 +4,10 @@ import time
import click
from celery import shared_task
-from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.constant.doc_type import DocType
+from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from core.rag.models.document import ChildDocument, Document
+from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@@ -55,6 +56,7 @@ def add_document_to_index_task(dataset_document_id: str):
)
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@@ -65,7 +67,7 @@ def add_document_to_index_task(dataset_document_id: str):
"dataset_id": segment.dataset_id,
},
)
-if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -81,11 +83,25 @@ def add_document_to_index_task(dataset_document_id: str):
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
-index_processor.load(dataset, documents)
+index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
# delete auto disable log
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
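
The attachment loop in this task maps each segment attachment dict onto an image-typed document with fixed metadata. A hedged sketch of that transformation using plain dicts, where strings stand in for `AttachmentDocument` and `DocType.IMAGE`:

```python
def attachments_to_documents(segment: dict) -> list[dict]:
    # Each attachment becomes an image-typed document keyed by the upload id;
    # doc_hash is intentionally left empty for attachment entries.
    return [
        {
            "page_content": attachment["name"],
            "metadata": {
                "doc_id": attachment["id"],
                "doc_hash": "",
                "document_id": segment["document_id"],
                "dataset_id": segment["dataset_id"],
                "doc_type": "image",
            },
        }
        for attachment in segment.get("attachments", [])
    ]
```

Segments without attachments simply contribute an empty list, so the multimodal batch stays empty and the extra `load` argument is harmless for text-only datasets.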

View File

@@ -18,6 +18,7 @@ from models.dataset import (
DatasetQuery,
Document,
DocumentSegment,
SegmentAttachmentBinding,
)
from models.model import UploadFile
@@ -58,14 +59,20 @@ def clean_dataset_task(
)
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id)
).all()
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
-from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.constant.index_type import IndexStructureType
-doc_form = IndexType.PARAGRAPH_INDEX
+doc_form = IndexStructureType.PARAGRAPH_INDEX
logger.info(
click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
)
@@ -90,6 +97,7 @@ def clean_dataset_task(
for document in documents:
db.session.delete(document)
# delete document file
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
@@ -107,6 +115,19 @@ def clean_dataset_task(
)
db.session.delete(image_file)
db.session.delete(segment)
# delete segment attachments
if attachments_with_bindings:
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Failed to delete attachment file from storage, attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(attachment_file)
db.session.delete(binding)
db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
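
The cleanup loop above logs and continues when a storage delete fails, so one bad object cannot abort the whole dataset cleanup. A standalone sketch of that best-effort pattern, where `delete_fn` stands in for `storage.delete`:

```python
import logging

logger = logging.getLogger(__name__)


def best_effort_delete(keys: list[str], delete_fn) -> list[str]:
    # Try every key; collect the ones that were actually removed and
    # log (rather than raise) on individual failures.
    deleted = []
    for key in keys:
        try:
            delete_fn(key)
        except Exception:
            logger.exception("Delete failed for key: %s", key)
            continue
        deleted.append(key)
    return deleted
```

Only after the loop finishes do the ORM rows get deleted and committed, which is why the task keeps storage errors out of the database transaction.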

View File

@@ -9,7 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
-from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
+from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding
from models.model import UploadFile
logger = logging.getLogger(__name__)
@@ -36,6 +36,16 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
raise Exception("Document has no dataset")
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
SegmentAttachmentBinding.dataset_id == dataset_id,
SegmentAttachmentBinding.document_id == document_id,
)
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@@ -69,6 +79,19 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
db.session.delete(file)
db.session.commit()
# delete segment attachments
if attachments_with_bindings:
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Failed to delete attachment file from storage, attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(attachment_file)
db.session.delete(binding)
# delete dataset metadata binding
db.session.query(DatasetMetadataBinding).where(

View File

@@ -4,9 +4,10 @@ import time
import click
from celery import shared_task # type: ignore
-from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.constant.doc_type import DocType
+from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from core.rag.models.document import ChildDocument, Document
+from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -28,7 +29,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
if not dataset:
raise Exception("Dataset not found")
-index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
+index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
@@ -119,6 +120,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@@ -129,7 +131,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
"dataset_id": segment.dataset_id,
},
)
-if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -145,9 +147,25 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
-index_processor.load(dataset, documents, with_keywords=False)
+index_processor.load(
+dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)

View File

@@ -1,14 +1,14 @@
import logging
import time
-from typing import Literal
import click
from celery import shared_task
from sqlalchemy import select
-from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.constant.doc_type import DocType
+from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from core.rag.models.document import ChildDocument, Document
+from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
-def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]):
+def deal_dataset_vector_index_task(dataset_id: str, action: str):
"""
Async deal dataset from index
:param dataset_id: dataset_id
@@ -32,7 +32,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
if not dataset:
raise Exception("Dataset not found")
-index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
+index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
@@ -119,6 +119,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@@ -129,7 +130,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
"dataset_id": segment.dataset_id,
},
)
-if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -145,9 +146,25 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
-index_processor.load(dataset, documents, with_keywords=False)
+index_processor.load(
+dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)

View File

@@ -6,14 +6,15 @@ from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
-from models.dataset import Dataset, Document
+from models.dataset import Dataset, Document, SegmentAttachmentBinding
from models.model import UploadFile
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_segment_from_index_task(
-index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None
+index_node_ids: list, dataset_id: str, document_id: str, segment_ids: list, child_node_ids: list | None = None
):
"""
Async Remove segment from index
@@ -49,6 +50,21 @@ def delete_segment_from_index_task(
delete_child_chunks=True,
precomputed_child_node_ids=child_node_ids,
)
if dataset.is_multimodal:
# delete segment attachment binding
segment_attachment_bindings = (
db.session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
for binding in segment_attachment_bindings:
db.session.delete(binding)
# delete upload file
db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
db.session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
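
The segment-deletion path above first collects `attachment_ids` from the bindings, then issues one bulk index clean and one bulk `UploadFile` delete rather than per-row operations. A minimal sketch of that batching idea, with binding rows reduced to `(segment_id, attachment_id)` tuples for illustration:

```python
def collect_attachment_ids(
    bindings: list[tuple[str, str]], segment_ids: set[str]
) -> list[str]:
    # Keep only bindings that belong to the segments being deleted,
    # so a single bulk delete can target exactly those attachments.
    return [
        attachment_id
        for segment_id, attachment_id in bindings
        if segment_id in segment_ids
    ]
```

Gathering the ids first is what lets the task call `delete(synchronize_session=False)` once with an `IN` clause instead of deleting upload files one by one.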

Some files were not shown because too many files have changed in this diff.