From c2428361c442eb4bfcc658e2fd3c0328e40902d7 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Fri, 3 Apr 2026 17:52:01 -0500 Subject: [PATCH] refactor: select in dataset_service (DocumentService class) (#34528) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/dataset_service.py | 105 +++++++++--------- .../services/test_dataset_service_document.py | 84 ++++---------- 2 files changed, 78 insertions(+), 111 deletions(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 4e1fe3f6a1a..f7e22e0e89c 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1400,8 +1400,8 @@ class DocumentService: @staticmethod def get_document(dataset_id: str, document_id: str | None = None) -> Document | None: if document_id: - document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = db.session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) ) return document else: @@ -1630,7 +1630,7 @@ class DocumentService: @staticmethod def get_document_by_id(document_id: str) -> Document | None: - document = db.session.query(Document).where(Document.id == document_id).first() + document = db.session.get(Document, document_id) return document @@ -1695,7 +1695,7 @@ class DocumentService: @staticmethod def get_document_file_detail(file_id: str): - file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none() + file_detail = db.session.get(UploadFile, file_id) return file_detail @staticmethod @@ -1769,9 +1769,11 @@ class DocumentService: document.name = name db.session.add(document) if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict: - db.session.query(UploadFile).where( - UploadFile.id == document.data_source_info_dict["upload_file_id"] - ).update({UploadFile.name: name}) + db.session.execute( + update(UploadFile) + .where(UploadFile.id == document.data_source_info_dict["upload_file_id"]) + .values(name=name) + ) db.session.commit() @@ -1858,8 +1860,8 @@ class DocumentService: @staticmethod def get_documents_position(dataset_id): - document = ( - db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first() + document = db.session.scalar( + select(Document).where(Document.dataset_id == dataset_id).order_by(Document.position.desc()).limit(1) ) if document: return document.position + 1 @@ -2016,28 +2018,28 @@ class DocumentService: if not knowledge_config.data_source.info_list.file_info_list: raise ValueError("File source info is required") upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids - files = ( - db.session.query(UploadFile) - .where( - UploadFile.tenant_id == dataset.tenant_id, - UploadFile.id.in_(upload_file_list), - ) - .all() + files = list( + db.session.scalars( + select(UploadFile).where( + UploadFile.tenant_id == dataset.tenant_id, + UploadFile.id.in_(upload_file_list), + ) + ).all() ) if len(files) != len(set(upload_file_list)): raise FileNotExistsError("One or more files not found.") file_names = [file.name for file in files] - db_documents = ( - db.session.query(Document) - .where( - Document.dataset_id == dataset.id, - Document.tenant_id == current_user.current_tenant_id, - Document.data_source_type == DataSourceType.UPLOAD_FILE, - Document.enabled == True, - Document.name.in_(file_names), - ) - .all() + db_documents = list( + db.session.scalars( + select(Document).where( + Document.dataset_id == dataset.id, + Document.tenant_id == current_user.current_tenant_id, + Document.data_source_type == DataSourceType.UPLOAD_FILE, + Document.enabled == True, + Document.name.in_(file_names), + ) + ).all() ) documents_map = {document.name: document for document in db_documents} for file in files: @@ -2083,15 +2085,15 @@ class DocumentService: raise ValueError("No notion info list found.") exist_page_ids = [] exist_document = {} - documents = ( - db.session.query(Document) - .filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type=DataSourceType.NOTION_IMPORT, - enabled=True, - ) - .all() + documents = list( + db.session.scalars( + select(Document).where( + Document.dataset_id == dataset.id, + Document.tenant_id == current_user.current_tenant_id, + Document.data_source_type == DataSourceType.NOTION_IMPORT, + Document.enabled == True, + ) + ).all() ) if documents: for document in documents: @@ -2522,14 +2524,15 @@ class DocumentService: assert isinstance(current_user, Account) documents_count = ( - db.session.query(Document) - .where( - Document.completed_at.isnot(None), - Document.enabled == True, - Document.archived == False, - Document.tenant_id == current_user.current_tenant_id, + db.session.scalar( + select(func.count(Document.id)).where( + Document.completed_at.isnot(None), + Document.enabled == True, + Document.archived == False, + Document.tenant_id == current_user.current_tenant_id, + ) ) - .count() + or 0 ) return documents_count @@ -2579,10 +2582,10 @@ class DocumentService: raise ValueError("No file info list found.") upload_file_list = document_data.data_source.info_list.file_info_list.file_ids for file_id in upload_file_list: - file = ( - db.session.query(UploadFile) + file = db.session.scalar( + select(UploadFile) .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) - .first() + .limit(1) ) # raise error if file not found @@ -2599,8 +2602,8 @@ class DocumentService: notion_info_list = document_data.data_source.info_list.notion_info_list for notion_info in notion_info_list: workspace_id = notion_info.workspace_id - data_source_binding = ( - db.session.query(DataSourceOauthBinding) + data_source_binding = db.session.scalar( + select(DataSourceOauthBinding) .where( sa.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, @@ -2609,7 +2612,7 @@ class DocumentService: DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ) - .first() + .limit(1) ) if not data_source_binding: raise ValueError("Data source binding not found.") @@ -2654,8 +2657,10 @@ class DocumentService: db.session.commit() # update document segment - db.session.query(DocumentSegment).filter_by(document_id=document.id).update( - {DocumentSegment.status: SegmentStatus.RE_SEGMENT} + db.session.execute( + update(DocumentSegment) + .where(DocumentSegment.document_id == document.id) + .values(status=SegmentStatus.RE_SEGMENT) ) db.session.commit() # trigger async task diff --git a/api/tests/unit_tests/services/test_dataset_service_document.py b/api/tests/unit_tests/services/test_dataset_service_document.py index c8036487ab4..e5a2541da7c 100644 --- a/api/tests/unit_tests/services/test_dataset_service_document.py +++ b/api/tests/unit_tests/services/test_dataset_service_document.py @@ -90,13 +90,13 @@ class TestDocumentServiceQueryAndDownloadHelpers: result = DocumentService.get_document("dataset-1", None) assert result is None - mock_db.session.query.assert_not_called() + mock_db.session.scalar.assert_not_called() def test_get_document_queries_by_dataset_and_document_id(self): document = DatasetServiceUnitDataFactory.create_document_mock() with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = document + mock_db.session.scalar.return_value = document result = DocumentService.get_document("dataset-1", "doc-1") @@ -435,7 +435,7 @@ class TestDocumentServiceQueryAndDownloadHelpers: upload_file = DatasetServiceUnitDataFactory.create_upload_file_mock() with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.one_or_none.return_value = upload_file + mock_db.session.get.return_value = upload_file result = DocumentService.get_document_file_detail(upload_file.id) @@ -570,7 +570,7 @@ class TestDocumentServiceMutations: assert document.name == "New Name" assert document.doc_metadata[BuiltInField.document_name] == "New Name" mock_db.session.add.assert_called_once_with(document) - mock_db.session.query.return_value.where.return_value.update.assert_called_once() + mock_db.session.execute.assert_called() mock_db.session.commit.assert_called_once() def test_recover_document_raises_when_document_is_not_paused(self): @@ -624,9 +624,7 @@ class TestDocumentServiceMutations: document = DatasetServiceUnitDataFactory.create_document_mock(position=7) with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.first.return_value = ( - document - ) + mock_db.session.scalar.return_value = document result = DocumentService.get_documents_position("dataset-1") @@ -634,7 +632,7 @@ class TestDocumentServiceMutations: def test_get_documents_position_defaults_to_one_when_dataset_is_empty(self): with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.first.return_value = None + mock_db.session.scalar.return_value = None result = DocumentService.get_documents_position("dataset-1") @@ -869,11 +867,7 @@ class TestDocumentServiceUpdateDocumentWithDatasetId: patch("services.dataset_service.naive_utc_now", return_value="now"), patch("services.dataset_service.document_indexing_update_task") as update_task, ): - upload_query = MagicMock() - upload_query.where.return_value.first.return_value = SimpleNamespace(id="file-1", name="upload.txt") - segment_query = MagicMock() - segment_query.filter_by.return_value.update.return_value = 3 - mock_db.session.query.side_effect = [upload_query, segment_query] + mock_db.session.scalar.return_value = SimpleNamespace(id="file-1", name="upload.txt") result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) @@ -892,7 +886,7 @@ class TestDocumentServiceUpdateDocumentWithDatasetId: assert document.created_from == "web" assert document.doc_form == IndexStructureType.QA_INDEX assert mock_db.session.commit.call_count == 3 - segment_query.filter_by.return_value.update.assert_called_once() + mock_db.session.execute.assert_called() update_task.delay.assert_called_once_with(document.dataset_id, document.id) def test_update_document_with_dataset_id_notion_import_requires_binding(self, account_context): @@ -920,9 +914,7 @@ class TestDocumentServiceUpdateDocumentWithDatasetId: patch.object(DatasetService, "check_dataset_model_setting"), patch("services.dataset_service.db") as mock_db, ): - binding_query = MagicMock() - binding_query.where.return_value.first.return_value = None - mock_db.session.query.return_value = binding_query + mock_db.session.scalar.return_value = None with pytest.raises(ValueError, match="Data source binding not found"): DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) @@ -954,10 +946,6 @@ class TestDocumentServiceUpdateDocumentWithDatasetId: patch("services.dataset_service.naive_utc_now", return_value="now"), patch("services.dataset_service.document_indexing_update_task") as update_task, ): - segment_query = MagicMock() - segment_query.filter_by.return_value.update.return_value = 2 - mock_db.session.query.return_value = segment_query - result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) assert result is document @@ -968,7 +956,7 @@ class TestDocumentServiceUpdateDocumentWithDatasetId: ) assert document.name == "" assert document.doc_form == IndexStructureType.PARENT_CHILD_INDEX - segment_query.filter_by.return_value.update.assert_called_once() + mock_db.session.execute.assert_called() update_task.delay.assert_called_once_with("dataset-1", "doc-1") @@ -1218,11 +1206,10 @@ class TestDocumentServiceSaveDocumentWithDatasetId: patch("services.dataset_service.secrets.randbelow", return_value=23), ): mock_redis.lock.return_value = _make_lock_context() - upload_query = MagicMock() - upload_query.where.return_value.all.return_value = [upload_file_a, upload_file_b] - existing_documents_query = MagicMock() - existing_documents_query.where.return_value.all.return_value = [duplicate_document] - mock_db.session.query.side_effect = [upload_query, existing_documents_query] + mock_db.session.scalars.return_value.all.side_effect = [ + [upload_file_a, upload_file_b], + [duplicate_document], + ] documents, batch = DocumentService.save_document_with_dataset_id( dataset, @@ -1302,9 +1289,7 @@ class TestDocumentServiceSaveDocumentWithDatasetId: patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls, ): mock_redis.lock.return_value = _make_lock_context() - notion_documents_query = MagicMock() - notion_documents_query.filter_by.return_value.all.return_value = [existing_keep, existing_remove] - mock_db.session.query.return_value = notion_documents_query + mock_db.session.scalars.return_value.all.return_value = [existing_keep, existing_remove] documents, _ = DocumentService.save_document_with_dataset_id( dataset, @@ -1474,12 +1459,11 @@ class TestDocumentServiceTenantAndUpdateEdges: def test_get_tenant_documents_count_returns_query_count(self, account_context): with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.count.return_value = 12 + mock_db.session.scalar.return_value = 12 result = DocumentService.get_tenant_documents_count() assert result == 12 - mock_db.session.query.return_value.where.return_value.count.assert_called_once() def test_update_document_with_dataset_id_uses_automatic_process_rule_payload(self, account_context): dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") @@ -1514,11 +1498,7 @@ class TestDocumentServiceTenantAndUpdateEdges: ): process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES process_rule_cls.return_value = created_process_rule - upload_query = MagicMock() - upload_query.where.return_value.first.return_value = SimpleNamespace(id="file-1", name="upload.txt") - segment_query = MagicMock() - segment_query.filter_by.return_value.update.return_value = 1 - mock_db.session.query.side_effect = [upload_query, segment_query] + mock_db.session.scalar.return_value = SimpleNamespace(id="file-1", name="upload.txt") result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) @@ -1567,7 +1547,7 @@ class TestDocumentServiceTenantAndUpdateEdges: patch.object(DatasetService, "check_dataset_model_setting"), patch("services.dataset_service.db") as mock_db, ): - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with pytest.raises(FileNotExistsError): DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) @@ -1618,11 +1598,7 @@ class TestDocumentServiceTenantAndUpdateEdges: patch("services.dataset_service.naive_utc_now", return_value="now"), patch("services.dataset_service.document_indexing_update_task") as update_task, ): - binding_query = MagicMock() - binding_query.where.return_value.first.return_value = SimpleNamespace(id="binding-1") - segment_query = MagicMock() - segment_query.filter_by.return_value.update.return_value = 1 - mock_db.session.query.side_effect = [binding_query, segment_query] + mock_db.session.scalar.return_value = SimpleNamespace(id="binding-1") result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) @@ -1914,11 +1890,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches: ): mock_redis.lock.return_value = _make_lock_context() process_rule_cls.return_value = created_process_rule - upload_query = MagicMock() - upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] - existing_documents_query = MagicMock() - existing_documents_query.where.return_value.all.return_value = [] - mock_db.session.query.side_effect = [upload_query, existing_documents_query] + mock_db.session.scalars.return_value.all.side_effect = [[SimpleNamespace(id="file-1", name="file.txt")], []] documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) @@ -1958,11 +1930,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches: mock_redis.lock.return_value = _make_lock_context() process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES process_rule_cls.return_value = created_process_rule - upload_query = MagicMock() - upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] - existing_documents_query = MagicMock() - existing_documents_query.where.return_value.all.return_value = [] - mock_db.session.query.side_effect = [upload_query, existing_documents_query] + mock_db.session.scalars.return_value.all.side_effect = [[SimpleNamespace(id="file-1", name="file.txt")], []] DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) @@ -1996,11 +1964,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches: mock_redis.lock.return_value = _make_lock_context() process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES process_rule_cls.return_value = created_process_rule - upload_query = MagicMock() - upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] - existing_documents_query = MagicMock() - existing_documents_query.where.return_value.all.return_value = [] - mock_db.session.query.side_effect = [upload_query, existing_documents_query] + mock_db.session.scalars.return_value.all.side_effect = [[SimpleNamespace(id="file-1", name="file.txt")], []] DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) @@ -2024,9 +1988,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches: patch("services.dataset_service.secrets.randbelow", return_value=23), ): mock_redis.lock.return_value = _make_lock_context() - upload_query = MagicMock() - upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] - mock_db.session.query.return_value = upload_query + mock_db.session.scalars.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] with pytest.raises(FileNotExistsError, match="One or more files not found"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)