From a69b8c1e96a04f0c41ed358247a2838c756f3f5a Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Thu, 26 Mar 2026 05:24:54 +0100 Subject: [PATCH] refactor: select in service API dataset document and segment controllers (#34101) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../service_api/dataset/document.py | 53 ++++++++++++------ .../service_api/dataset/segment.py | 37 +++++++++--- .../dataset/test_dataset_segment.py | 56 +++++++++---------- .../service_api/dataset/test_document.py | 46 ++++++++------- .../services/dataset_service_test_helpers.py | 2 +- 5 files changed, 114 insertions(+), 80 deletions(-) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index d34b4124aeb..2c094aa3e6e 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -6,7 +6,7 @@ from uuid import UUID from flask import request, send_file from flask_restx import marshal from pydantic import BaseModel, Field, field_validator, model_validator -from sqlalchemy import desc, select +from sqlalchemy import desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -155,7 +155,9 @@ class DocumentAddByTextApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -238,7 +240,9 @@ class DocumentUpdateByTextApi(DatasetApiResource): def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): """Update document by text.""" payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {}) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1) + ) args = payload.model_dump(exclude_none=True) if not dataset: raise ValueError("Dataset does not exist.") @@ -315,7 +319,9 @@ class DocumentAddByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -425,7 +431,9 @@ class DocumentUpdateByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -515,7 +523,9 @@ class DocumentListApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) query_params = DocumentListQuery.model_validate(request.args.to_dict()) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -609,7 +619,9 @@ class DocumentIndexingStatusApi(DatasetApiResource): batch = str(batch) tenant_id = str(tenant_id) # get dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # get documents @@ -619,20 +631,23 @@ class DocumentIndexingStatusApi(DatasetApiResource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -822,7 +837,9 @@ class DocumentApi(DatasetApiResource): tenant_id = str(tenant_id) # get dataset info - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index f8c6b251eb3..28fa9151179 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -3,6 +3,7 @@ from typing import Any from flask import request from flask_restx import marshal from pydantic import BaseModel, Field +from sqlalchemy import select from werkzeug.exceptions import NotFound from configs import dify_config @@ -92,7 +93,9 @@ class SegmentApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Create single segment.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check document @@ -150,7 +153,9 @@ class SegmentApi(DatasetApiResource): # check dataset page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check document @@ -220,7 +225,9 @@ class DatasetSegmentApi(DatasetApiResource): def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -254,7 +261,9 @@ class DatasetSegmentApi(DatasetApiResource): def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -301,7 +310,9 @@ class DatasetSegmentApi(DatasetApiResource): def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -344,7 +355,9 @@ class ChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Create child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -402,7 +415,9 @@ class ChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Get child chunks.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -468,7 +483,9 @@ class DatasetChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Delete child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -527,7 +544,9 @@ class DatasetChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Update child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index 73a87761d52..7f5d6b08390 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -788,7 +788,7 @@ class TestSegmentApiGet: """Test successful segment list retrieval.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_seg_svc.get_segments.return_value = ([mock_segment], 1) mock_marshal.return_value = [{"id": mock_segment.id}] @@ -813,7 +813,7 @@ class TestSegmentApiGet: """Test 404 when dataset not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -833,7 +833,7 @@ class TestSegmentApiGet: """Test 404 when document not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None # Act & Assert @@ -899,7 +899,7 @@ class TestSegmentApiPost: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" @@ -950,7 +950,7 @@ class TestSegmentApiPost: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" @@ -992,7 +992,7 @@ class TestSegmentApiPost: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "indexing" # Not completed @@ -1043,7 +1043,7 @@ class TestDatasetSegmentApiDelete: """Test successful segment deletion.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc = Mock() @@ -1087,7 +1087,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when segment not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" @@ -1129,7 +1129,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when dataset not found for delete.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -1163,7 +1163,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when document not found for delete.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = None @@ -1233,7 +1233,7 @@ class TestDatasetSegmentApiUpdate: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = mock_segment @@ -1280,7 +1280,7 @@ class TestDatasetSegmentApiUpdate: """Test 404 when dataset not found for update.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id", @@ -1321,7 +1321,7 @@ class TestDatasetSegmentApiUpdate: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1370,7 +1370,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test successful single segment retrieval.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_doc_svc.get_document.return_value = mock_doc @@ -1405,7 +1405,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when dataset not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id", @@ -1436,7 +1436,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when document not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = None @@ -1471,7 +1471,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when segment not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1515,7 +1515,7 @@ class TestChildChunkApiGet: ): """Test successful child chunk list retrieval.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = Mock() @@ -1554,7 +1554,7 @@ class TestChildChunkApiGet: ): """Test 404 when dataset not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks", @@ -1583,7 +1583,7 @@ class TestChildChunkApiGet: ): """Test 404 when document not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None with app.test_request_context( @@ -1615,7 +1615,7 @@ class TestChildChunkApiGet: ): """Test 404 when segment not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1676,7 +1676,7 @@ class TestChildChunkApiPost: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = Mock() mock_child = Mock() @@ -1717,7 +1717,7 @@ class TestChildChunkApiPost: """Test 404 when dataset not found.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks", @@ -1755,7 +1755,7 @@ class TestChildChunkApiPost: """Test 404 when segment not found.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1808,7 +1808,7 @@ class TestDatasetChildChunkApiDelete: ): """Test successful child chunk deletion.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc_svc.get_document.return_value = mock_doc @@ -1858,7 +1858,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when child chunk not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) @@ -1899,7 +1899,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when segment does not belong to the document.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) @@ -1939,7 +1939,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when child chunk does not belong to the segment.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index 7f77e61ee4c..12d5e7345d2 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -717,7 +717,7 @@ class TestDocumentApiDelete: dataset_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = mock_document mock_doc_svc.check_archived.return_value = False @@ -746,7 +746,7 @@ class TestDocumentApiDelete: document_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None @@ -767,7 +767,7 @@ class TestDocumentApiDelete: dataset_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = mock_document mock_doc_svc.check_archived.return_value = True @@ -788,7 +788,7 @@ class TestDocumentApiDelete: # Arrange dataset_id = str(uuid.uuid4()) document_id = str(uuid.uuid4()) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -809,7 +809,7 @@ class TestDocumentListApi: def test_list_documents_success(self, mock_db, mock_doc_svc, mock_marshal, app, mock_tenant, mock_dataset): """Test successful document list retrieval.""" # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_pagination = Mock() mock_pagination.items = [Mock(), Mock()] @@ -838,7 +838,7 @@ class TestDocumentListApi: def test_list_documents_dataset_not_found(self, mock_db, app, mock_tenant, mock_dataset): """Test 404 when dataset not found.""" # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -860,8 +860,6 @@ class TestDocumentIndexingStatusApi: """Test successful indexing status retrieval.""" # Arrange batch_id = "batch_123" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset - mock_doc = Mock() mock_doc.id = str(uuid.uuid4()) mock_doc.is_paused = False @@ -877,8 +875,8 @@ class TestDocumentIndexingStatusApi: mock_doc_svc.get_batch_documents.return_value = [mock_doc] - # Mock segment count queries - mock_db.session.query.return_value.where.return_value.where.return_value.count.return_value = 5 + # scalar() called 3 times: dataset lookup, completed_segments count, total_segments count + mock_db.session.scalar.side_effect = [mock_dataset, 5, 5] mock_marshal.return_value = {"id": mock_doc.id, "indexing_status": "completed"} # Act @@ -898,7 +896,7 @@ class TestDocumentIndexingStatusApi: """Test 404 when dataset not found.""" # Arrange batch_id = "batch_123" - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -915,7 +913,7 @@ class TestDocumentIndexingStatusApi: """Test 404 when no documents found for batch.""" # Arrange batch_id = "batch_empty" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_batch_documents.return_value = [] # Act & Assert @@ -986,7 +984,7 @@ class TestDocumentAddByTextApi: # Arrange — neutralise billing decorators self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset.indexing_technique = "economy" mock_current_user.id = str(uuid.uuid4()) @@ -1035,7 +1033,7 @@ class TestDocumentAddByTextApi: # Arrange — neutralise billing decorators self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -1064,7 +1062,7 @@ class TestDocumentAddByTextApi: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.indexing_technique = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset # Act & Assert with app.test_request_context( @@ -1150,7 +1148,7 @@ class TestDocumentUpdateByTextApiPost: _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.indexing_technique = "economy" mock_dataset.latest_process_rule = Mock() - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_current_user.id = "user-1" mock_upload = Mock() @@ -1193,7 +1191,7 @@ class TestDocumentUpdateByTextApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None doc_id = str(uuid.uuid4()) with app.test_request_context( @@ -1232,7 +1230,7 @@ class TestDocumentAddByFileApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None from io import BytesIO @@ -1263,7 +1261,7 @@ class TestDocumentAddByFileApiPost: """Test ValueError when dataset is external.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.provider = "external" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1298,7 +1296,7 @@ class TestDocumentAddByFileApiPost: mock_dataset.provider = "vendor" mock_dataset.indexing_technique = "economy" mock_dataset.chunk_structure = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset with app.test_request_context( f"/datasets/{mock_dataset.id}/document/create_by_file", @@ -1328,7 +1326,7 @@ class TestDocumentAddByFileApiPost: mock_dataset.provider = "vendor" mock_dataset.indexing_technique = None mock_dataset.chunk_structure = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1366,7 +1364,7 @@ class TestDocumentUpdateByFileApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None from io import BytesIO @@ -1402,7 +1400,7 @@ class TestDocumentUpdateByFileApiPost: """Test ValueError when dataset is external.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.provider = "external" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1450,7 +1448,7 @@ class TestDocumentUpdateByFileApiPost: mock_dataset.chunk_structure = None mock_dataset.latest_process_rule = Mock() mock_dataset.created_by_account = Mock() - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_current_user.id = "user-1" mock_upload = Mock() diff --git a/api/tests/unit_tests/services/dataset_service_test_helpers.py b/api/tests/unit_tests/services/dataset_service_test_helpers.py index 602488d6eac..542179e2a39 100644 --- a/api/tests/unit_tests/services/dataset_service_test_helpers.py +++ b/api/tests/unit_tests/services/dataset_service_test_helpers.py @@ -10,13 +10,13 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, create_autospec, patch import pytest +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType from werkzeug.exceptions import Forbidden, NotFound from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType from enums.cloud_plan import CloudPlan from models import Account, TenantAccountRole from models.dataset import (