fix: enrich Service API segment responses with summary content (#34221)

Co-authored-by: jigangz <jigangz@github.com>
Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com>
This commit is contained in:
jigangz
2026-03-30 03:09:50 -07:00
committed by GitHub
parent 944db46d4f
commit 1aaba80211
2 changed files with 92 additions and 6 deletions

View File

@@ -29,6 +29,31 @@ from services.entities.knowledge_entities.knowledge_entities import SegmentUpdat
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from services.summary_index_service import SummaryIndexService
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
"""Marshal a single segment and enrich it with summary content."""
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
"""Marshal multiple segments and enrich them with summary content (batch query)."""
segment_ids = [segment.id for segment in segments]
summaries: dict = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
result = []
for segment in segments:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict["summary"] = summaries.get(segment.id)
result.append(segment_dict)
return result
class SegmentCreatePayload(BaseModel): class SegmentCreatePayload(BaseModel):
@@ -132,7 +157,7 @@ class SegmentApi(DatasetApiResource):
for args_item in payload.segments: for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document) SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset) segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
else: else:
return {"error": "Segments is required"}, 400 return {"error": "Segments is required"}, 400
@@ -196,7 +221,7 @@ class SegmentApi(DatasetApiResource):
) )
response = { response = {
"data": marshal(segments, segment_fields), "data": _marshal_segments_with_summary(segments, dataset_id),
"doc_form": document.doc_form, "doc_form": document.doc_form,
"total": total, "total": total,
"has_more": len(segments) == limit, "has_more": len(segments) == limit,
@@ -296,7 +321,7 @@ class DatasetSegmentApi(DatasetApiResource):
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {}) payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset) updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
@service_api_ns.doc("get_segment") @service_api_ns.doc("get_segment")
@service_api_ns.doc(description="Get a specific segment by ID") @service_api_ns.doc(description="Get a specific segment by ID")
@@ -326,7 +351,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@service_api_ns.route( @service_api_ns.route(

View File

@@ -768,6 +768,7 @@ class TestSegmentApiGet:
``current_account_with_tenant()`` and ``marshal``. ``current_account_with_tenant()`` and ``marshal``.
""" """
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal") @patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService") @patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService") @patch("controllers.service_api.dataset.segment.DocumentService")
@@ -780,6 +781,7 @@ class TestSegmentApiGet:
mock_doc_svc, mock_doc_svc,
mock_seg_svc, mock_seg_svc,
mock_marshal, mock_marshal,
mock_summary_svc,
app, app,
mock_tenant, mock_tenant,
mock_dataset, mock_dataset,
@@ -791,7 +793,8 @@ class TestSegmentApiGet:
mock_db.session.scalar.return_value = mock_dataset mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_seg_svc.get_segments.return_value = ([mock_segment], 1) mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
mock_marshal.return_value = [{"id": mock_segment.id}] mock_marshal.return_value = {"id": mock_segment.id}
mock_summary_svc.get_segments_summaries.return_value = {}
# Act # Act
with app.test_request_context( with app.test_request_context(
@@ -872,6 +875,7 @@ class TestSegmentApiPost:
mock_rate_limit.enabled = False mock_rate_limit.enabled = False
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal") @patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService") @patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService") @patch("controllers.service_api.dataset.segment.DocumentService")
@@ -888,6 +892,7 @@ class TestSegmentApiPost:
mock_doc_svc, mock_doc_svc,
mock_seg_svc, mock_seg_svc,
mock_marshal, mock_marshal,
mock_summary_svc,
app, app,
mock_tenant, mock_tenant,
mock_dataset, mock_dataset,
@@ -909,7 +914,8 @@ class TestSegmentApiPost:
mock_seg_svc.segment_create_args_validate.return_value = None mock_seg_svc.segment_create_args_validate.return_value = None
mock_seg_svc.multi_create_segment.return_value = [mock_segment] mock_seg_svc.multi_create_segment.return_value = [mock_segment]
mock_marshal.return_value = [{"id": mock_segment.id}] mock_marshal.return_value = {"id": mock_segment.id}
mock_summary_svc.get_segments_summaries.return_value = {}
segments_data = [{"content": "Test segment content", "answer": "Test answer"}] segments_data = [{"content": "Test segment content", "answer": "Test answer"}]
@@ -1206,6 +1212,7 @@ class TestDatasetSegmentApiUpdate:
mock_rate_limit.enabled = False mock_rate_limit.enabled = False
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal") @patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService") @patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService") @patch("controllers.service_api.dataset.segment.DocumentService")
@@ -1224,6 +1231,7 @@ class TestDatasetSegmentApiUpdate:
mock_doc_svc, mock_doc_svc,
mock_seg_svc, mock_seg_svc,
mock_marshal, mock_marshal,
mock_summary_svc,
app, app,
mock_tenant, mock_tenant,
mock_dataset, mock_dataset,
@@ -1240,6 +1248,7 @@ class TestDatasetSegmentApiUpdate:
updated = Mock() updated = Mock()
mock_seg_svc.update_segment.return_value = updated mock_seg_svc.update_segment.return_value = updated
mock_marshal.return_value = {"id": mock_segment.id} mock_marshal.return_value = {"id": mock_segment.id}
mock_summary_svc.get_segment_summary.return_value = None
with app.test_request_context( with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}", f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
@@ -1349,6 +1358,7 @@ class TestDatasetSegmentApiGetSingle:
``current_account_with_tenant()`` and ``marshal``. ``current_account_with_tenant()`` and ``marshal``.
""" """
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal") @patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService") @patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService") @patch("controllers.service_api.dataset.segment.DocumentService")
@@ -1363,6 +1373,7 @@ class TestDatasetSegmentApiGetSingle:
mock_doc_svc, mock_doc_svc,
mock_seg_svc, mock_seg_svc,
mock_marshal, mock_marshal,
mock_summary_svc,
app, app,
mock_tenant, mock_tenant,
mock_dataset, mock_dataset,
@@ -1376,6 +1387,7 @@ class TestDatasetSegmentApiGetSingle:
mock_doc_svc.get_document.return_value = mock_doc mock_doc_svc.get_document.return_value = mock_doc
mock_seg_svc.get_segment_by_id.return_value = mock_segment mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_marshal.return_value = {"id": mock_segment.id} mock_marshal.return_value = {"id": mock_segment.id}
mock_summary_svc.get_segment_summary.return_value = None
with app.test_request_context( with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}", f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
@@ -1393,6 +1405,55 @@ class TestDatasetSegmentApiGetSingle:
assert "data" in response assert "data" in response
assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
@patch("controllers.service_api.dataset.segment.marshal")
@patch("controllers.service_api.dataset.segment.SegmentService")
@patch("controllers.service_api.dataset.segment.DocumentService")
@patch("controllers.service_api.dataset.segment.DatasetService")
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
@patch("controllers.service_api.dataset.segment.db")
def test_get_single_segment_includes_summary(
self,
mock_db,
mock_account_fn,
mock_dataset_svc,
mock_doc_svc,
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
mock_tenant,
mock_dataset,
mock_segment,
):
"""Test that single segment response includes summary content from SummaryIndexService."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_doc_svc.get_document.return_value = mock_doc
mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_marshal.return_value = {"id": mock_segment.id, "summary": None}
mock_summary_record = Mock()
mock_summary_record.summary_content = "This is the segment summary"
mock_summary_svc.get_segment_summary.return_value = mock_summary_record
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
method="GET",
):
api = DatasetSegmentApi()
response, status = api.get(
tenant_id=mock_tenant.id,
dataset_id=mock_dataset.id,
document_id="doc-id",
segment_id=mock_segment.id,
)
assert status == 200
assert response["data"]["summary"] == "This is the segment summary"
@patch("controllers.service_api.dataset.segment.current_account_with_tenant") @patch("controllers.service_api.dataset.segment.current_account_with_tenant")
@patch("controllers.service_api.dataset.segment.db") @patch("controllers.service_api.dataset.segment.db")
def test_get_single_segment_dataset_not_found( def test_get_single_segment_dataset_not_found(