mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 16:36:28 +08:00
refactor: partition Celery task sessions into smaller, discrete execu… (#32085)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -48,6 +48,11 @@ def batch_create_segment_to_index_task(
|
||||
|
||||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||
|
||||
# Initialize variables with default values
|
||||
upload_file_key: str | None = None
|
||||
dataset_config: dict | None = None
|
||||
document_config: dict | None = None
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
@@ -69,86 +74,115 @@ def batch_create_segment_to_index_task(
|
||||
if not upload_file:
|
||||
raise ValueError("UploadFile not found.")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
storage.download(upload_file.key, file_path)
|
||||
dataset_config = {
|
||||
"id": dataset.id,
|
||||
"indexing_technique": dataset.indexing_technique,
|
||||
"tenant_id": dataset.tenant_id,
|
||||
"embedding_model_provider": dataset.embedding_model_provider,
|
||||
"embedding_model": dataset.embedding_model,
|
||||
}
|
||||
|
||||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
content.append(data)
|
||||
if len(content) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
document_config = {
|
||||
"id": dataset_document.id,
|
||||
"doc_form": dataset_document.doc_form,
|
||||
"word_count": dataset_document.word_count or 0,
|
||||
}
|
||||
|
||||
document_segments = []
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
upload_file_key = upload_file.key
|
||||
|
||||
word_count_change = 0
|
||||
if embedding_model:
|
||||
tokens_list = embedding_model.get_text_embedding_num_tokens(
|
||||
texts=[segment["content"] for segment in content]
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Segments batch created index failed")
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
return
|
||||
|
||||
# Ensure required variables are set before proceeding
|
||||
if upload_file_key is None or dataset_config is None or document_config is None:
|
||||
logger.error("Required configuration not set due to session error")
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
return
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file_key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
storage.download(upload_file_key, file_path)
|
||||
|
||||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if document_config["doc_form"] == "qa_model":
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
tokens_list = [0] * len(content)
|
||||
data = {"content": row.iloc[0]}
|
||||
content.append(data)
|
||||
if len(content) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
|
||||
for segment, tokens in zip(content, tokens_list):
|
||||
content = segment["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
max_position = (
|
||||
session.query(func.max(DocumentSegment.position))
|
||||
.where(DocumentSegment.document_id == dataset_document.id)
|
||||
.scalar()
|
||||
)
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
index_node_id=doc_id,
|
||||
index_node_hash=segment_hash,
|
||||
position=max_position + 1 if max_position else 1,
|
||||
content=content,
|
||||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
created_by=user_id,
|
||||
indexing_at=naive_utc_now(),
|
||||
status="completed",
|
||||
completed_at=naive_utc_now(),
|
||||
)
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
segment_document.answer = segment["answer"]
|
||||
segment_document.word_count += len(segment["answer"])
|
||||
word_count_change += segment_document.word_count
|
||||
session.add(segment_document)
|
||||
document_segments.append(segment_document)
|
||||
document_segments = []
|
||||
embedding_model = None
|
||||
if dataset_config["indexing_technique"] == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset_config["tenant_id"],
|
||||
provider=dataset_config["embedding_model_provider"],
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset_config["embedding_model"],
|
||||
)
|
||||
|
||||
word_count_change = 0
|
||||
if embedding_model:
|
||||
tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
|
||||
else:
|
||||
tokens_list = [0] * len(content)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
for segment, tokens in zip(content, tokens_list):
|
||||
content = segment["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
max_position = (
|
||||
session.query(func.max(DocumentSegment.position))
|
||||
.where(DocumentSegment.document_id == document_config["id"])
|
||||
.scalar()
|
||||
)
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
index_node_id=doc_id,
|
||||
index_node_hash=segment_hash,
|
||||
position=max_position + 1 if max_position else 1,
|
||||
content=content,
|
||||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
created_by=user_id,
|
||||
indexing_at=naive_utc_now(),
|
||||
status="completed",
|
||||
completed_at=naive_utc_now(),
|
||||
)
|
||||
if document_config["doc_form"] == "qa_model":
|
||||
segment_document.answer = segment["answer"]
|
||||
segment_document.word_count += len(segment["answer"])
|
||||
word_count_change += segment_document.word_count
|
||||
session.add(segment_document)
|
||||
document_segments.append(segment_document)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
dataset_document = session.get(Document, document_id)
|
||||
if dataset_document:
|
||||
assert dataset_document.word_count is not None
|
||||
dataset_document.word_count += word_count_change
|
||||
session.add(dataset_document)
|
||||
|
||||
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
|
||||
session.commit()
|
||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Segments batch created index failed")
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
if dataset:
|
||||
VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"])
|
||||
|
||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user