mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 06:09:24 +08:00
fix: enrich Service API segment responses with summary content (#34221)
Co-authored-by: jigangz <jigangz@github.com> Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com>
This commit is contained in:
@@ -29,6 +29,31 @@ from services.entities.knowledge_entities.knowledge_entities import SegmentUpdat
|
|||||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||||
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
||||||
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
||||||
|
from services.summary_index_service import SummaryIndexService
|
||||||
|
|
||||||
|
|
||||||
|
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
    """Serialize one segment and attach its summary text.

    The segment is marshalled with ``segment_fields`` and the result is
    copied into a plain dict. The summary record for the segment is then
    looked up through ``SummaryIndexService.get_segment_summary`` and its
    ``summary_content`` is stored under the ``"summary"`` key; when the
    lookup yields nothing (or a falsy record) the key is set to ``None``.
    """
    payload = dict(marshal(segment, segment_fields))  # type: ignore[arg-type]
    record = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
    if record:
        payload["summary"] = record.summary_content
    else:
        payload["summary"] = None
    return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
    """Serialize several segments and attach each one's summary text.

    Summaries are fetched with a single
    ``SummaryIndexService.get_segments_summaries`` call covering every
    segment id; the service is not called at all when there are no segments.
    Each returned dict is the marshalled segment plus a ``"summary"`` key
    holding the matching ``summary_content``, or ``None`` when the batch
    result has no entry for that segment id.
    """
    ids = [item.id for item in segments]
    # Skip the service round-trip entirely for an empty input.
    records = (
        SummaryIndexService.get_segments_summaries(segment_ids=ids, dataset_id=dataset_id) if ids else {}
    )
    content_by_id: dict = {chunk_id: rec.summary_content for chunk_id, rec in records.items()}

    return [
        dict(marshal(item, segment_fields)) | {"summary": content_by_id.get(item.id)}  # type: ignore[arg-type]
        for item in segments
    ]
|
||||||
|
|
||||||
|
|
||||||
class SegmentCreatePayload(BaseModel):
|
class SegmentCreatePayload(BaseModel):
|
||||||
@@ -132,7 +157,7 @@ class SegmentApi(DatasetApiResource):
|
|||||||
for args_item in payload.segments:
|
for args_item in payload.segments:
|
||||||
SegmentService.segment_create_args_validate(args_item, document)
|
SegmentService.segment_create_args_validate(args_item, document)
|
||||||
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
|
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
|
||||||
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
|
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
|
||||||
else:
|
else:
|
||||||
return {"error": "Segments is required"}, 400
|
return {"error": "Segments is required"}, 400
|
||||||
|
|
||||||
@@ -196,7 +221,7 @@ class SegmentApi(DatasetApiResource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
"data": marshal(segments, segment_fields),
|
"data": _marshal_segments_with_summary(segments, dataset_id),
|
||||||
"doc_form": document.doc_form,
|
"doc_form": document.doc_form,
|
||||||
"total": total,
|
"total": total,
|
||||||
"has_more": len(segments) == limit,
|
"has_more": len(segments) == limit,
|
||||||
@@ -296,7 +321,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||||||
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
|
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||||
|
|
||||||
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
|
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
|
||||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||||
|
|
||||||
@service_api_ns.doc("get_segment")
|
@service_api_ns.doc("get_segment")
|
||||||
@service_api_ns.doc(description="Get a specific segment by ID")
|
@service_api_ns.doc(description="Get a specific segment by ID")
|
||||||
@@ -326,7 +351,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
|
|
||||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||||
|
|
||||||
|
|
||||||
@service_api_ns.route(
|
@service_api_ns.route(
|
||||||
|
|||||||
@@ -768,6 +768,7 @@ class TestSegmentApiGet:
|
|||||||
``current_account_with_tenant()`` and ``marshal``.
|
``current_account_with_tenant()`` and ``marshal``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||||
@patch("controllers.service_api.dataset.segment.marshal")
|
@patch("controllers.service_api.dataset.segment.marshal")
|
||||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||||
@@ -780,6 +781,7 @@ class TestSegmentApiGet:
|
|||||||
mock_doc_svc,
|
mock_doc_svc,
|
||||||
mock_seg_svc,
|
mock_seg_svc,
|
||||||
mock_marshal,
|
mock_marshal,
|
||||||
|
mock_summary_svc,
|
||||||
app,
|
app,
|
||||||
mock_tenant,
|
mock_tenant,
|
||||||
mock_dataset,
|
mock_dataset,
|
||||||
@@ -791,7 +793,8 @@ class TestSegmentApiGet:
|
|||||||
mock_db.session.scalar.return_value = mock_dataset
|
mock_db.session.scalar.return_value = mock_dataset
|
||||||
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||||
mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
|
mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
|
||||||
mock_marshal.return_value = [{"id": mock_segment.id}]
|
mock_marshal.return_value = {"id": mock_segment.id}
|
||||||
|
mock_summary_svc.get_segments_summaries.return_value = {}
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with app.test_request_context(
|
with app.test_request_context(
|
||||||
@@ -872,6 +875,7 @@ class TestSegmentApiPost:
|
|||||||
mock_rate_limit.enabled = False
|
mock_rate_limit.enabled = False
|
||||||
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
||||||
|
|
||||||
|
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||||
@patch("controllers.service_api.dataset.segment.marshal")
|
@patch("controllers.service_api.dataset.segment.marshal")
|
||||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||||
@@ -888,6 +892,7 @@ class TestSegmentApiPost:
|
|||||||
mock_doc_svc,
|
mock_doc_svc,
|
||||||
mock_seg_svc,
|
mock_seg_svc,
|
||||||
mock_marshal,
|
mock_marshal,
|
||||||
|
mock_summary_svc,
|
||||||
app,
|
app,
|
||||||
mock_tenant,
|
mock_tenant,
|
||||||
mock_dataset,
|
mock_dataset,
|
||||||
@@ -909,7 +914,8 @@ class TestSegmentApiPost:
|
|||||||
|
|
||||||
mock_seg_svc.segment_create_args_validate.return_value = None
|
mock_seg_svc.segment_create_args_validate.return_value = None
|
||||||
mock_seg_svc.multi_create_segment.return_value = [mock_segment]
|
mock_seg_svc.multi_create_segment.return_value = [mock_segment]
|
||||||
mock_marshal.return_value = [{"id": mock_segment.id}]
|
mock_marshal.return_value = {"id": mock_segment.id}
|
||||||
|
mock_summary_svc.get_segments_summaries.return_value = {}
|
||||||
|
|
||||||
segments_data = [{"content": "Test segment content", "answer": "Test answer"}]
|
segments_data = [{"content": "Test segment content", "answer": "Test answer"}]
|
||||||
|
|
||||||
@@ -1206,6 +1212,7 @@ class TestDatasetSegmentApiUpdate:
|
|||||||
mock_rate_limit.enabled = False
|
mock_rate_limit.enabled = False
|
||||||
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
||||||
|
|
||||||
|
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||||
@patch("controllers.service_api.dataset.segment.marshal")
|
@patch("controllers.service_api.dataset.segment.marshal")
|
||||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||||
@@ -1224,6 +1231,7 @@ class TestDatasetSegmentApiUpdate:
|
|||||||
mock_doc_svc,
|
mock_doc_svc,
|
||||||
mock_seg_svc,
|
mock_seg_svc,
|
||||||
mock_marshal,
|
mock_marshal,
|
||||||
|
mock_summary_svc,
|
||||||
app,
|
app,
|
||||||
mock_tenant,
|
mock_tenant,
|
||||||
mock_dataset,
|
mock_dataset,
|
||||||
@@ -1240,6 +1248,7 @@ class TestDatasetSegmentApiUpdate:
|
|||||||
updated = Mock()
|
updated = Mock()
|
||||||
mock_seg_svc.update_segment.return_value = updated
|
mock_seg_svc.update_segment.return_value = updated
|
||||||
mock_marshal.return_value = {"id": mock_segment.id}
|
mock_marshal.return_value = {"id": mock_segment.id}
|
||||||
|
mock_summary_svc.get_segment_summary.return_value = None
|
||||||
|
|
||||||
with app.test_request_context(
|
with app.test_request_context(
|
||||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
|
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
|
||||||
@@ -1349,6 +1358,7 @@ class TestDatasetSegmentApiGetSingle:
|
|||||||
``current_account_with_tenant()`` and ``marshal``.
|
``current_account_with_tenant()`` and ``marshal``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||||
@patch("controllers.service_api.dataset.segment.marshal")
|
@patch("controllers.service_api.dataset.segment.marshal")
|
||||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||||
@@ -1363,6 +1373,7 @@ class TestDatasetSegmentApiGetSingle:
|
|||||||
mock_doc_svc,
|
mock_doc_svc,
|
||||||
mock_seg_svc,
|
mock_seg_svc,
|
||||||
mock_marshal,
|
mock_marshal,
|
||||||
|
mock_summary_svc,
|
||||||
app,
|
app,
|
||||||
mock_tenant,
|
mock_tenant,
|
||||||
mock_dataset,
|
mock_dataset,
|
||||||
@@ -1376,6 +1387,7 @@ class TestDatasetSegmentApiGetSingle:
|
|||||||
mock_doc_svc.get_document.return_value = mock_doc
|
mock_doc_svc.get_document.return_value = mock_doc
|
||||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||||
mock_marshal.return_value = {"id": mock_segment.id}
|
mock_marshal.return_value = {"id": mock_segment.id}
|
||||||
|
mock_summary_svc.get_segment_summary.return_value = None
|
||||||
|
|
||||||
with app.test_request_context(
|
with app.test_request_context(
|
||||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
|
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
|
||||||
@@ -1393,6 +1405,55 @@ class TestDatasetSegmentApiGetSingle:
|
|||||||
assert "data" in response
|
assert "data" in response
|
||||||
assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX
|
assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX
|
||||||
|
|
||||||
|
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService")
@patch("controllers.service_api.dataset.segment.DatasetService")
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
@patch("controllers.service_api.dataset.segment.db")
def test_get_single_segment_includes_summary(
    self,
    mock_db,
    mock_account_fn,
    mock_dataset_svc,
    mock_doc_svc,
    mock_seg_svc,
    mock_marshal,
    mock_summary_svc,
    app,
    mock_tenant,
    mock_dataset,
    mock_segment,
):
    """Single-segment GET returns the summary text supplied by SummaryIndexService.

    ``marshal`` is stubbed to return a dict whose ``summary`` is ``None``, so
    the final assertion only passes if the controller overwrites that value
    with the ``summary_content`` of the record returned by the summary service.
    """
    # Arrange: account, dataset, document and segment lookups all succeed.
    mock_account_fn.return_value = (Mock(), mock_tenant.id)
    mock_db.session.scalar.return_value = mock_dataset
    mock_dataset_svc.check_dataset_model_setting.return_value = None
    mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
    mock_seg_svc.get_segment_by_id.return_value = mock_segment
    mock_marshal.return_value = {"id": mock_segment.id, "summary": None}

    # Arrange: the summary service yields a record carrying the expected text.
    summary_record = Mock(summary_content="This is the segment summary")
    mock_summary_svc.get_segment_summary.return_value = summary_record

    # Act
    request_path = f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}"
    with app.test_request_context(request_path, method="GET"):
        api = DatasetSegmentApi()
        response, status = api.get(
            tenant_id=mock_tenant.id,
            dataset_id=mock_dataset.id,
            document_id="doc-id",
            segment_id=mock_segment.id,
        )

    # Assert
    assert status == 200
    assert response["data"]["summary"] == "This is the segment summary"
|
||||||
|
|
||||||
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
|
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
|
||||||
@patch("controllers.service_api.dataset.segment.db")
|
@patch("controllers.service_api.dataset.segment.db")
|
||||||
def test_get_single_segment_dataset_not_found(
|
def test_get_single_segment_dataset_not_found(
|
||||||
|
|||||||
Reference in New Issue
Block a user