From 1aaba80211d85cc801f1d4dc2043df9034555359 Mon Sep 17 00:00:00 2001 From: jigangz <115519042+jigangz@users.noreply.github.com> Date: Mon, 30 Mar 2026 03:09:50 -0700 Subject: [PATCH] fix: enrich Service API segment responses with summary content (#34221) Co-authored-by: jigangz Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com> --- .../service_api/dataset/segment.py | 33 ++++++++-- .../dataset/test_dataset_segment.py | 65 ++++++++++++++++++- 2 files changed, 92 insertions(+), 6 deletions(-) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index b4cc9874b63..5b16da81e08 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -29,6 +29,31 @@ from services.entities.knowledge_entities.knowledge_entities import SegmentUpdat from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError +from services.summary_index_service import SummaryIndexService + + +def _marshal_segment_with_summary(segment, dataset_id: str) -> dict: + """Marshal a single segment and enrich it with summary content.""" + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] + summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) + segment_dict["summary"] = summary.summary_content if summary else None + return segment_dict + + +def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]: + """Marshal multiple segments and enrich them with summary content (batch query).""" + segment_ids = [segment.id for segment in segments] + summaries: dict = {} + if segment_ids: + summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id) + summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()} + + result = [] + for segment in segments: + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] + segment_dict["summary"] = summaries.get(segment.id) + result.append(segment_dict) + return result class SegmentCreatePayload(BaseModel): @@ -132,7 +157,7 @@ class SegmentApi(DatasetApiResource): for args_item in payload.segments: SegmentService.segment_create_args_validate(args_item, document) segments = SegmentService.multi_create_segment(payload.segments, document, dataset) - return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200 else: return {"error": "Segments is required"}, 400 @@ -196,7 +221,7 @@ class SegmentApi(DatasetApiResource): ) response = { - "data": marshal(segments, segment_fields), + "data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form, "total": total, "has_more": len(segments) == limit, @@ -296,7 +321,7 @@ class DatasetSegmentApi(DatasetApiResource): payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {}) updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset) - return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200 @service_api_ns.doc("get_segment") @service_api_ns.doc(description="Get a specific segment by ID") @@ -326,7 +351,7 @@ class DatasetSegmentApi(DatasetApiResource): if not segment: raise NotFound("Segment not found.") - return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200 @service_api_ns.route( diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index 7f5d6b08390..e9c3e6d3769 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -768,6 +768,7 @@ class TestSegmentApiGet: ``current_account_with_tenant()`` and ``marshal``. """ + @patch("controllers.service_api.dataset.segment.SummaryIndexService") @patch("controllers.service_api.dataset.segment.marshal") @patch("controllers.service_api.dataset.segment.SegmentService") @patch("controllers.service_api.dataset.segment.DocumentService") @@ -780,6 +781,7 @@ class TestSegmentApiGet: mock_doc_svc, mock_seg_svc, mock_marshal, + mock_summary_svc, app, mock_tenant, mock_dataset, @@ -791,7 +793,8 @@ class TestSegmentApiGet: mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_seg_svc.get_segments.return_value = ([mock_segment], 1) - mock_marshal.return_value = [{"id": mock_segment.id}] + mock_marshal.return_value = {"id": mock_segment.id} + mock_summary_svc.get_segments_summaries.return_value = {} # Act with app.test_request_context( @@ -872,6 +875,7 @@ class TestSegmentApiPost: mock_rate_limit.enabled = False mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit + @patch("controllers.service_api.dataset.segment.SummaryIndexService") @patch("controllers.service_api.dataset.segment.marshal") @patch("controllers.service_api.dataset.segment.SegmentService") @patch("controllers.service_api.dataset.segment.DocumentService") @@ -888,6 +892,7 @@ class TestSegmentApiPost: mock_doc_svc, mock_seg_svc, mock_marshal, + mock_summary_svc, app, mock_tenant, mock_dataset, @@ -909,7 +914,8 @@ class TestSegmentApiPost: mock_seg_svc.segment_create_args_validate.return_value = None mock_seg_svc.multi_create_segment.return_value = [mock_segment] - mock_marshal.return_value = [{"id": mock_segment.id}] + mock_marshal.return_value = {"id": mock_segment.id} + mock_summary_svc.get_segments_summaries.return_value = {} segments_data = [{"content": "Test segment content", "answer": "Test answer"}] @@ -1206,6 +1212,7 @@ class TestDatasetSegmentApiUpdate: mock_rate_limit.enabled = False mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit + @patch("controllers.service_api.dataset.segment.SummaryIndexService") @patch("controllers.service_api.dataset.segment.marshal") @patch("controllers.service_api.dataset.segment.SegmentService") @patch("controllers.service_api.dataset.segment.DocumentService") @@ -1224,6 +1231,7 @@ class TestDatasetSegmentApiUpdate: mock_doc_svc, mock_seg_svc, mock_marshal, + mock_summary_svc, app, mock_tenant, mock_dataset, @@ -1240,6 +1248,7 @@ class TestDatasetSegmentApiUpdate: updated = Mock() mock_seg_svc.update_segment.return_value = updated mock_marshal.return_value = {"id": mock_segment.id} + mock_summary_svc.get_segment_summary.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}", @@ -1349,6 +1358,7 @@ class TestDatasetSegmentApiGetSingle: ``current_account_with_tenant()`` and ``marshal``. """ + @patch("controllers.service_api.dataset.segment.SummaryIndexService") @patch("controllers.service_api.dataset.segment.marshal") @patch("controllers.service_api.dataset.segment.SegmentService") @patch("controllers.service_api.dataset.segment.DocumentService") @@ -1363,6 +1373,7 @@ class TestDatasetSegmentApiGetSingle: mock_doc_svc, mock_seg_svc, mock_marshal, + mock_summary_svc, app, mock_tenant, mock_dataset, @@ -1376,6 +1387,7 @@ class TestDatasetSegmentApiGetSingle: mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.get_segment_by_id.return_value = mock_segment mock_marshal.return_value = {"id": mock_segment.id} + mock_summary_svc.get_segment_summary.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}", @@ -1393,6 +1405,55 @@ class TestDatasetSegmentApiGetSingle: assert "data" in response assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX + @patch("controllers.service_api.dataset.segment.SummaryIndexService") + @patch("controllers.service_api.dataset.segment.marshal") + @patch("controllers.service_api.dataset.segment.SegmentService") + @patch("controllers.service_api.dataset.segment.DocumentService") + @patch("controllers.service_api.dataset.segment.DatasetService") + @patch("controllers.service_api.dataset.segment.current_account_with_tenant") + @patch("controllers.service_api.dataset.segment.db") + def test_get_single_segment_includes_summary( + self, + mock_db, + mock_account_fn, + mock_dataset_svc, + mock_doc_svc, + mock_seg_svc, + mock_marshal, + mock_summary_svc, + app, + mock_tenant, + mock_dataset, + mock_segment, + ): + """Test that single segment response includes summary content from SummaryIndexService.""" + mock_account_fn.return_value = (Mock(), mock_tenant.id) + mock_db.session.scalar.return_value = mock_dataset + mock_dataset_svc.check_dataset_model_setting.return_value = None + mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) + mock_doc_svc.get_document.return_value = mock_doc + mock_seg_svc.get_segment_by_id.return_value = mock_segment + mock_marshal.return_value = {"id": mock_segment.id, "summary": None} + + mock_summary_record = Mock() + mock_summary_record.summary_content = "This is the segment summary" + mock_summary_svc.get_segment_summary.return_value = mock_summary_record + + with app.test_request_context( + f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}", + method="GET", + ): + api = DatasetSegmentApi() + response, status = api.get( + tenant_id=mock_tenant.id, + dataset_id=mock_dataset.id, + document_id="doc-id", + segment_id=mock_segment.id, + ) + + assert status == 200 + assert response["data"]["summary"] == "This is the segment summary" + @patch("controllers.service_api.dataset.segment.current_account_with_tenant") @patch("controllers.service_api.dataset.segment.db") def test_get_single_segment_dataset_not_found(