refactor: select in console datasets document controller (#34029)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
tmimmanuel
2026-03-25 04:47:25 +01:00
committed by GitHub
parent 4c32acf857
commit d87263f7c3
55 changed files with 233 additions and 195 deletions

View File

@@ -21,7 +21,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.file import helpers as file_helpers
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
@@ -228,7 +228,7 @@ class DatasetService:
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
if embedding_model_provider and embedding_model_name:
# check if embedding model setting is valid
@@ -254,7 +254,10 @@ class DatasetService:
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
dataset = Dataset(name=name, indexing_technique=indexing_technique)
dataset = Dataset(
name=name,
indexing_technique=IndexTechniqueType(indexing_technique) if indexing_technique else None,
)
# dataset = Dataset(name=name, provider=provider, config=config)
dataset.description = description
dataset.created_by = account.id
@@ -349,7 +352,7 @@ class DatasetService:
@staticmethod
def check_dataset_model_setting(dataset):
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@@ -717,13 +720,13 @@ class DatasetService:
if "indexing_technique" not in data:
return None
if dataset.indexing_technique != data["indexing_technique"]:
if data["indexing_technique"] == "economy":
if data["indexing_technique"] == IndexTechniqueType.ECONOMY:
# Remove embedding model configuration for economy mode
filtered_data["embedding_model"] = None
filtered_data["embedding_model_provider"] = None
filtered_data["collection_binding_id"] = None
return "remove"
elif data["indexing_technique"] == "high_quality":
elif data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
# Configure embedding model for high quality mode
DatasetService._configure_embedding_model_for_high_quality(data, filtered_data)
return "add"
@@ -953,8 +956,8 @@ class DatasetService:
dataset = session.merge(dataset)
if not has_published:
dataset.chunk_structure = knowledge_configuration.chunk_structure
dataset.indexing_technique = knowledge_configuration.indexing_technique
if knowledge_configuration.indexing_technique == "high_quality":
dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, # ignore type error
@@ -976,7 +979,7 @@ class DatasetService:
embedding_model_name,
)
dataset.collection_binding_id = dataset_collection_binding.id
elif knowledge_configuration.indexing_technique == "economy":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
else:
raise ValueError("Invalid index method")
@@ -991,9 +994,9 @@ class DatasetService:
action = None
if dataset.indexing_technique != knowledge_configuration.indexing_technique:
# if update indexing_technique
if knowledge_configuration.indexing_technique == "economy":
if knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
elif knowledge_configuration.indexing_technique == "high_quality":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
action = "add"
# get embedding model setting
try:
@@ -1018,7 +1021,7 @@ class DatasetService:
)
dataset.is_multimodal = is_multimodal
dataset.collection_binding_id = dataset_collection_binding.id
dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
@@ -1029,7 +1032,7 @@ class DatasetService:
else:
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
# Skip embedding model checks if not provided in the update request
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
skip_embedding_update = False
try:
# Handle existing model provider
@@ -1089,7 +1092,7 @@ class DatasetService:
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
elif dataset.indexing_technique == "economy":
elif dataset.indexing_technique == IndexTechniqueType.ECONOMY:
if dataset.keyword_number != knowledge_configuration.keyword_number:
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
@@ -1907,8 +1910,8 @@ class DocumentService:
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is invalid")
dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality":
dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique)
if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
@@ -2689,7 +2692,7 @@ class DocumentService:
dataset_collection_binding_id = None
retrieval_model = None
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
assert knowledge_config.embedding_model_provider
assert knowledge_config.embedding_model
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
@@ -2712,7 +2715,7 @@ class DocumentService:
tenant_id=tenant_id,
name="",
data_source_type=knowledge_config.data_source.info_list.data_source_type,
indexing_technique=knowledge_config.indexing_technique,
indexing_technique=IndexTechniqueType(knowledge_config.indexing_technique),
created_by=account.id,
embedding_model=knowledge_config.embedding_model,
embedding_model_provider=knowledge_config.embedding_model_provider,
@@ -3125,7 +3128,7 @@ class SegmentService:
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
@@ -3208,7 +3211,7 @@ class SegmentService:
try:
with redis_client.lock(lock_name, timeout=600):
embedding_model = None
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
@@ -3230,7 +3233,7 @@ class SegmentService:
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY and embedding_model:
# calc embedding use tokens
if document.doc_form == IndexStructureType.QA_INDEX:
tokens = embedding_model.get_text_embedding_num_tokens(
@@ -3345,7 +3348,7 @@ class SegmentService:
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# regenerate child chunks
# get embedding model instance
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
model_manager = ModelManager()
@@ -3382,7 +3385,7 @@ class SegmentService:
# When user manually provides summary, allow saving even if summary_index_setting doesn't exist
# summary_index_setting is only needed for LLM generation, not for manual summary vectorization
# Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# Query existing summary from database
from models.dataset import DocumentSegmentSummary
@@ -3409,7 +3412,7 @@ class SegmentService:
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
@@ -3449,7 +3452,7 @@ class SegmentService:
db.session.commit()
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# get embedding model instance
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
model_manager = ModelManager()
@@ -3481,7 +3484,7 @@ class SegmentService:
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
# Handle summary index when content changed
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
from models.dataset import DocumentSegmentSummary
existing_summary = (

View File

@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
from core.helper import ssrf_proxy
from core.helper.name_generator import generate_incremental_name
from core.plugin.entities.plugin import PluginDependency
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
@@ -311,13 +312,13 @@ class RagPipelineDslService:
"icon_background": icon_background,
"icon_url": icon_url,
},
indexing_technique=knowledge_configuration.indexing_technique,
indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique),
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
if knowledge_configuration.indexing_technique == "high_quality":
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
@@ -343,7 +344,7 @@ class RagPipelineDslService:
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
@@ -443,18 +444,18 @@ class RagPipelineDslService:
"icon_background": icon_background,
"icon_url": icon_url,
},
indexing_technique=knowledge_configuration.indexing_technique,
indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique),
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
else:
dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.indexing_technique == "high_quality":
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
@@ -480,7 +481,7 @@ class RagPipelineDslService:
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
@@ -772,7 +773,7 @@ class RagPipelineDslService:
)
case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE:
knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"])
if knowledge_index_entity.indexing_technique == "high_quality":
if knowledge_index_entity.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if knowledge_index_entity.embedding_model_provider:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(

View File

@@ -9,7 +9,7 @@ from flask_login import current_user
from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
@@ -105,29 +105,29 @@ class RagPipelineTransformService:
if doc_form == IndexStructureType.PARAGRAPH_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.file-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.file-general-economy.yml
with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case DataSourceType.NOTION_IMPORT:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.notion-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.notion-general-economy.yml
with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case DataSourceType.WEBSITE_CRAWL:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.website-crawl-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.website-crawl-general-economy.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
@@ -170,11 +170,11 @@ class RagPipelineTransformService:
):
knowledge_configuration_dict = node.get("data", {})
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
knowledge_configuration.embedding_model = dataset.embedding_model
knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
if retrieval_model:
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH
knowledge_configuration.retrieval_model = retrieval_model
else:

View File

@@ -12,6 +12,7 @@ from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.models.document import Document
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
@@ -140,7 +141,7 @@ class SummaryIndexService:
session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one.
If not provided, creates a new session and commits automatically.
"""
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.warning(
"Summary vectorization skipped for dataset %s: indexing_technique is not high_quality",
dataset.id,
@@ -724,7 +725,7 @@ class SummaryIndexService:
List of created DocumentSegmentSummary instances
"""
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.info(
"Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'",
dataset.id,
@@ -851,7 +852,7 @@ class SummaryIndexService:
)
# Remove from vector database (but keep records)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
if summary_node_ids:
try:
@@ -889,7 +890,7 @@ class SummaryIndexService:
segment_ids: List of segment IDs to enable summaries for. If None, enable all.
"""
# Only enable summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return
with session_factory.create_session() as session:
@@ -981,7 +982,7 @@ class SummaryIndexService:
return
# Delete from vector database
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
if summary_node_ids:
vector = Vector(dataset)
@@ -1012,7 +1013,7 @@ class SummaryIndexService:
Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality
"""
# Only update summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return None
# When user manually provides summary, allow saving even if summary_index_setting doesn't exist

View File

@@ -4,7 +4,7 @@ from core.model_manager import ModelInstance, ModelManager
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, Document
@@ -45,7 +45,7 @@ class VectorService:
if not processing_rule:
raise ValueError("No processing rule found.")
# get embedding model instance
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
model_manager = ModelManager()
@@ -112,7 +112,7 @@ class VectorService:
"dataset_id": segment.dataset_id,
},
)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
vector = Vector(dataset=dataset)
vector.delete_by_ids([segment.index_node_id])
@@ -197,7 +197,7 @@ class VectorService:
"dataset_id": child_segment.dataset_id,
},
)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# save vector index
vector = Vector(dataset=dataset)
vector.add_texts([child_document], duplicate_check=True)
@@ -237,7 +237,7 @@ class VectorService:
delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks:
delete_node_ids.append(delete_child_chunk.index_node_id)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
vector = Vector(dataset=dataset)
if delete_node_ids:
@@ -252,7 +252,7 @@ class VectorService:
@classmethod
def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset):
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return
attachments = segment.attachments