fix: enrich Service API segment responses with summary content (#34221)

Co-authored-by: jigangz <jigangz@github.com>
Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com>
This commit is contained in:
jigangz
2026-03-30 03:09:50 -07:00
committed by GitHub
parent 944db46d4f
commit 1aaba80211
2 changed files with 92 additions and 6 deletions

View File

@@ -29,6 +29,31 @@ from services.entities.knowledge_entities.knowledge_entities import SegmentUpdat
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from services.summary_index_service import SummaryIndexService
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
    """Serialize one segment and attach its summary text.

    The segment is marshalled with ``segment_fields`` and then the summary
    record for ``segment.id`` within ``dataset_id`` is looked up through
    ``SummaryIndexService.get_segment_summary``.

    Args:
        segment: Segment object to serialize; its ``id`` is used for the lookup.
        dataset_id: ID of the dataset passed to the summary lookup.

    Returns:
        The marshalled segment as a plain ``dict`` with a ``"summary"`` key set
        to the record's ``summary_content``, or ``None`` when the lookup
        returns nothing.
    """
    payload = dict(marshal(segment, segment_fields))  # type: ignore[arg-type]
    record = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
    if record:
        payload["summary"] = record.summary_content
    else:
        payload["summary"] = None
    return payload
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
    """Serialize several segments and attach each one's summary text.

    All summaries are fetched with a single call to
    ``SummaryIndexService.get_segments_summaries`` rather than one lookup per
    segment; the call is skipped entirely when ``segments`` is empty.

    Args:
        segments: Iterable of segment objects; it is traversed twice, once to
            collect IDs and once to marshal.
        dataset_id: ID of the dataset passed to the summary lookup.

    Returns:
        One ``dict`` per segment, in input order, each marshalled with
        ``segment_fields`` and carrying a ``"summary"`` key that is ``None``
        for segments with no entry in the lookup result.
    """
    ids = [item.id for item in segments]

    content_by_id: dict = {}
    if ids:
        records = SummaryIndexService.get_segments_summaries(segment_ids=ids, dataset_id=dataset_id)
        content_by_id = {key: rec.summary_content for key, rec in records.items()}

    # Marshal first, then set "summary", so a same-named marshalled field is overwritten.
    return [
        {
            **dict(marshal(item, segment_fields)),  # type: ignore[arg-type]
            "summary": content_by_id.get(item.id),
        }
        for item in segments
    ]
class SegmentCreatePayload(BaseModel):
@@ -132,7 +157,7 @@ class SegmentApi(DatasetApiResource):
for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
else:
return {"error": "Segments is required"}, 400
@@ -196,7 +221,7 @@ class SegmentApi(DatasetApiResource):
)
response = {
"data": marshal(segments, segment_fields),
"data": _marshal_segments_with_summary(segments, dataset_id),
"doc_form": document.doc_form,
"total": total,
"has_more": len(segments) == limit,
@@ -296,7 +321,7 @@ class DatasetSegmentApi(DatasetApiResource):
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
@service_api_ns.doc("get_segment")
@service_api_ns.doc(description="Get a specific segment by ID")
@@ -326,7 +351,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@service_api_ns.route(

View File

@@ -768,6 +768,7 @@ class TestSegmentApiGet:
``current_account_with_tenant()`` and ``marshal``.
"""
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService")
@@ -780,6 +781,7 @@ class TestSegmentApiGet:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
mock_tenant,
mock_dataset,
@@ -791,7 +793,8 @@ class TestSegmentApiGet:
mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
mock_marshal.return_value = [{"id": mock_segment.id}]
mock_marshal.return_value = {"id": mock_segment.id}
mock_summary_svc.get_segments_summaries.return_value = {}
# Act
with app.test_request_context(
@@ -872,6 +875,7 @@ class TestSegmentApiPost:
mock_rate_limit.enabled = False
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService")
@@ -888,6 +892,7 @@ class TestSegmentApiPost:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
mock_tenant,
mock_dataset,
@@ -909,7 +914,8 @@ class TestSegmentApiPost:
mock_seg_svc.segment_create_args_validate.return_value = None
mock_seg_svc.multi_create_segment.return_value = [mock_segment]
mock_marshal.return_value = [{"id": mock_segment.id}]
mock_marshal.return_value = {"id": mock_segment.id}
mock_summary_svc.get_segments_summaries.return_value = {}
segments_data = [{"content": "Test segment content", "answer": "Test answer"}]
@@ -1206,6 +1212,7 @@ class TestDatasetSegmentApiUpdate:
mock_rate_limit.enabled = False
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService")
@@ -1224,6 +1231,7 @@ class TestDatasetSegmentApiUpdate:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
mock_tenant,
mock_dataset,
@@ -1240,6 +1248,7 @@ class TestDatasetSegmentApiUpdate:
updated = Mock()
mock_seg_svc.update_segment.return_value = updated
mock_marshal.return_value = {"id": mock_segment.id}
mock_summary_svc.get_segment_summary.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
@@ -1349,6 +1358,7 @@ class TestDatasetSegmentApiGetSingle:
``current_account_with_tenant()`` and ``marshal``.
"""
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService")
@@ -1363,6 +1373,7 @@ class TestDatasetSegmentApiGetSingle:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
mock_tenant,
mock_dataset,
@@ -1376,6 +1387,7 @@ class TestDatasetSegmentApiGetSingle:
mock_doc_svc.get_document.return_value = mock_doc
mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_marshal.return_value = {"id": mock_segment.id}
mock_summary_svc.get_segment_summary.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
@@ -1393,6 +1405,55 @@ class TestDatasetSegmentApiGetSingle:
assert "data" in response
assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService")
@patch("controllers.service_api.dataset.segment.DatasetService")
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
@patch("controllers.service_api.dataset.segment.db")
def test_get_single_segment_includes_summary(
    self,
    mock_db,
    mock_account_fn,
    mock_dataset_svc,
    mock_doc_svc,
    mock_seg_svc,
    mock_marshal,
    mock_summary_svc,
    app,
    mock_tenant,
    mock_dataset,
    mock_segment,
):
    """Test that single segment response includes summary content from SummaryIndexService.

    Also verifies that the summary lookup is made exactly once, for the
    requested segment and dataset, so the test cannot pass when the wrong
    record is queried.
    """
    # Arrange
    mock_account_fn.return_value = (Mock(), mock_tenant.id)
    mock_db.session.scalar.return_value = mock_dataset
    mock_dataset_svc.check_dataset_model_setting.return_value = None
    mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
    mock_doc_svc.get_document.return_value = mock_doc
    mock_seg_svc.get_segment_by_id.return_value = mock_segment
    # The marshalled dict starts with ``summary=None`` so the final assertion
    # proves the value was filled in from the summary service.
    mock_marshal.return_value = {"id": mock_segment.id, "summary": None}

    mock_summary_record = Mock()
    mock_summary_record.summary_content = "This is the segment summary"
    mock_summary_svc.get_segment_summary.return_value = mock_summary_record

    # Act
    with app.test_request_context(
        f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
        method="GET",
    ):
        api = DatasetSegmentApi()
        response, status = api.get(
            tenant_id=mock_tenant.id,
            dataset_id=mock_dataset.id,
            document_id="doc-id",
            segment_id=mock_segment.id,
        )

    # Assert
    assert status == 200
    assert response["data"]["summary"] == "This is the segment summary"
    mock_summary_svc.get_segment_summary.assert_called_once_with(
        segment_id=mock_segment.id,
        dataset_id=mock_dataset.id,
    )
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
@patch("controllers.service_api.dataset.segment.db")
def test_get_single_segment_dataset_not_found(