mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 02:19:20 +08:00
fix: enrich Service API segment responses with summary content (#34221)
Co-authored-by: jigangz <jigangz@github.com> Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com>
This commit is contained in:
@@ -29,6 +29,31 @@ from services.entities.knowledge_entities.knowledge_entities import SegmentUpdat
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
||||
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
|
||||
"""Marshal a single segment and enrich it with summary content."""
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
|
||||
"""Marshal multiple segments and enrich them with summary content (batch query)."""
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries: dict = {}
|
||||
if segment_ids:
|
||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
||||
|
||||
result = []
|
||||
for segment in segments:
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
result.append(segment_dict)
|
||||
return result
|
||||
|
||||
|
||||
class SegmentCreatePayload(BaseModel):
|
||||
@@ -132,7 +157,7 @@ class SegmentApi(DatasetApiResource):
|
||||
for args_item in payload.segments:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
|
||||
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
|
||||
else:
|
||||
return {"error": "Segments is required"}, 400
|
||||
|
||||
@@ -196,7 +221,7 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
response = {
|
||||
"data": marshal(segments, segment_fields),
|
||||
"data": _marshal_segments_with_summary(segments, dataset_id),
|
||||
"doc_form": document.doc_form,
|
||||
"total": total,
|
||||
"has_more": len(segments) == limit,
|
||||
@@ -296,7 +321,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
@service_api_ns.doc("get_segment")
|
||||
@service_api_ns.doc(description="Get a specific segment by ID")
|
||||
@@ -326,7 +351,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
|
||||
@@ -768,6 +768,7 @@ class TestSegmentApiGet:
|
||||
``current_account_with_tenant()`` and ``marshal``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@@ -780,6 +781,7 @@ class TestSegmentApiGet:
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
@@ -791,7 +793,8 @@ class TestSegmentApiGet:
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
|
||||
mock_marshal.return_value = [{"id": mock_segment.id}]
|
||||
mock_marshal.return_value = {"id": mock_segment.id}
|
||||
mock_summary_svc.get_segments_summaries.return_value = {}
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
@@ -872,6 +875,7 @@ class TestSegmentApiPost:
|
||||
mock_rate_limit.enabled = False
|
||||
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@@ -888,6 +892,7 @@ class TestSegmentApiPost:
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
@@ -909,7 +914,8 @@ class TestSegmentApiPost:
|
||||
|
||||
mock_seg_svc.segment_create_args_validate.return_value = None
|
||||
mock_seg_svc.multi_create_segment.return_value = [mock_segment]
|
||||
mock_marshal.return_value = [{"id": mock_segment.id}]
|
||||
mock_marshal.return_value = {"id": mock_segment.id}
|
||||
mock_summary_svc.get_segments_summaries.return_value = {}
|
||||
|
||||
segments_data = [{"content": "Test segment content", "answer": "Test answer"}]
|
||||
|
||||
@@ -1206,6 +1212,7 @@ class TestDatasetSegmentApiUpdate:
|
||||
mock_rate_limit.enabled = False
|
||||
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@@ -1224,6 +1231,7 @@ class TestDatasetSegmentApiUpdate:
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
@@ -1240,6 +1248,7 @@ class TestDatasetSegmentApiUpdate:
|
||||
updated = Mock()
|
||||
mock_seg_svc.update_segment.return_value = updated
|
||||
mock_marshal.return_value = {"id": mock_segment.id}
|
||||
mock_summary_svc.get_segment_summary.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
|
||||
@@ -1349,6 +1358,7 @@ class TestDatasetSegmentApiGetSingle:
|
||||
``current_account_with_tenant()`` and ``marshal``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@@ -1363,6 +1373,7 @@ class TestDatasetSegmentApiGetSingle:
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
@@ -1376,6 +1387,7 @@ class TestDatasetSegmentApiGetSingle:
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_marshal.return_value = {"id": mock_segment.id}
|
||||
mock_summary_svc.get_segment_summary.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
|
||||
@@ -1393,6 +1405,55 @@ class TestDatasetSegmentApiGetSingle:
|
||||
assert "data" in response
|
||||
assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@patch("controllers.service_api.dataset.segment.DatasetService")
|
||||
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
|
||||
@patch("controllers.service_api.dataset.segment.db")
|
||||
def test_get_single_segment_includes_summary(
|
||||
self,
|
||||
mock_db,
|
||||
mock_account_fn,
|
||||
mock_dataset_svc,
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
mock_segment,
|
||||
):
|
||||
"""Test that single segment response includes summary content from SummaryIndexService."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_marshal.return_value = {"id": mock_segment.id, "summary": None}
|
||||
|
||||
mock_summary_record = Mock()
|
||||
mock_summary_record.summary_content = "This is the segment summary"
|
||||
mock_summary_svc.get_segment_summary.return_value = mock_summary_record
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
|
||||
method="GET",
|
||||
):
|
||||
api = DatasetSegmentApi()
|
||||
response, status = api.get(
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
document_id="doc-id",
|
||||
segment_id=mock_segment.id,
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
assert response["data"]["summary"] == "This is the segment summary"
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
|
||||
@patch("controllers.service_api.dataset.segment.db")
|
||||
def test_get_single_segment_dataset_not_found(
|
||||
|
||||
Reference in New Issue
Block a user