From 4bd388669aedc342ccc76ae7529785d615b10323 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Tue, 31 Mar 2026 21:20:56 -0500 Subject: [PATCH] refactor: core/app pipeline, core/datasource, and core/indexing_runner (#34359) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../app/apps/pipeline/pipeline_generator.py | 2 +- api/core/app/apps/pipeline/pipeline_runner.py | 17 ++- .../datasource/datasource_file_manager.py | 8 +- api/core/indexing_runner.py | 105 ++++++++++-------- .../apps/pipeline/test_pipeline_generator.py | 2 +- .../app/apps/pipeline/test_pipeline_runner.py | 25 +---- .../test_datasource_file_manager.py | 50 +++------ .../core/rag/indexing/test_indexing_runner.py | 73 ++++++------ 8 files changed, 131 insertions(+), 151 deletions(-) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index fa242003a25..9cc1a197d5f 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -302,7 +302,7 @@ class PipelineGenerator(BaseAppGenerator): """ with preserve_flask_contexts(flask_app, context_vars=context): # init queue manager - workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first() + workflow = db.session.get(Workflow, workflow_id) if not workflow: raise ValueError(f"Workflow not found: {workflow_id}") queue_manager = PipelineQueueManager( diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 4c188dac68d..b4d2310da85 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -9,6 +9,7 @@ from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variable_loader import VariableLoader from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput +from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig @@ -84,13 +85,13 @@ class PipelineRunner(WorkflowBasedAppRunner): user_id = None if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: - end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first() + end_user = db.session.get(EndUser, self.application_generate_entity.user_id) if end_user: user_id = end_user.session_id else: user_id = self.application_generate_entity.user_id - pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first() + pipeline = db.session.get(Pipeline, app_config.app_id) if not pipeline: raise ValueError("Pipeline not found") @@ -213,10 +214,10 @@ class PipelineRunner(WorkflowBasedAppRunner): Get workflow """ # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id) - .first() + .limit(1) ) # return workflow @@ -297,10 +298,8 @@ class PipelineRunner(WorkflowBasedAppRunner): """ if isinstance(event, GraphRunFailedEvent): if document_id and dataset_id: - document = ( - db.session.query(Document) - .where(Document.id == document_id, Document.dataset_id == dataset_id) - .first() + document = db.session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) ) if document: document.indexing_status = "error" diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index fe40d8f0e58..492b507aa99 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -153,7 +153,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ - upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == id).first() + upload_file: UploadFile | None = db.session.get(UploadFile, id) if not upload_file: return None @@ -171,7 +171,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ - message_file: MessageFile | None = db.session.query(MessageFile).where(MessageFile.id == id).first() + message_file: MessageFile | None = db.session.get(MessageFile, id) # Check if message_file is not None if message_file is not None: @@ -185,7 +185,7 @@ class DatasourceFileManager: else: tool_file_id = None - tool_file: ToolFile | None = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first() + tool_file: ToolFile | None = db.session.get(ToolFile, tool_file_id) if not tool_file: return None @@ -203,7 +203,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ - upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + upload_file: UploadFile | None = db.session.get(UploadFile, upload_file_id) if not upload_file: return None, None diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 3ec17bc9864..b8d5ca2f50f 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -10,7 +10,7 @@ from typing import Any from flask import Flask, current_app from graphon.model_runtime.entities.model_entities import ModelType -from sqlalchemy import select +from sqlalchemy import delete, func, select, update from sqlalchemy.orm.exc import ObjectDeletedError from configs import dify_config @@ -78,7 +78,7 @@ class IndexingRunner: continue # get dataset - dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() + dataset = db.session.get(Dataset, requeried_document.dataset_id) if not dataset: raise ValueError("no dataset found") @@ -95,7 +95,7 @@ class IndexingRunner: text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform - current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first() + current_user = db.session.get(Account, requeried_document.created_by) if not current_user: raise ValueError("no current user found") current_user.set_tenant_id(dataset.tenant_id) @@ -137,23 +137,24 @@ class IndexingRunner: return # get dataset - dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() + dataset = db.session.get(Dataset, requeried_document.dataset_id) if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete - document_segments = ( - db.session.query(DocumentSegment) - .filter_by(dataset_id=dataset.id, document_id=requeried_document.id) - .all() - ) + document_segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == requeried_document.id, + ) + ).all() for document_segment in document_segments: db.session.delete(document_segment) if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: # delete child chunks - db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() + db.session.execute(delete(ChildChunk).where(ChildChunk.segment_id == document_segment.id)) db.session.commit() # get the process rule stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id) @@ -167,7 +168,7 @@ class IndexingRunner: text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform - current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first() + current_user = db.session.get(Account, requeried_document.created_by) if not current_user: raise ValueError("no current user found") current_user.set_tenant_id(dataset.tenant_id) @@ -207,17 +208,18 @@ class IndexingRunner: return # get dataset - dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() + dataset = db.session.get(Dataset, requeried_document.dataset_id) if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete - document_segments = ( - db.session.query(DocumentSegment) - .filter_by(dataset_id=dataset.id, document_id=requeried_document.id) - .all() - ) + document_segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == requeried_document.id, + ) + ).all() documents = [] if document_segments: @@ -289,7 +291,7 @@ class IndexingRunner: embedding_model_instance = None if dataset_id: - dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() + dataset = db.session.get(Dataset, dataset_id) if not dataset: raise ValueError("Dataset not found.") if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}: @@ -652,24 +654,26 @@ class IndexingRunner: @staticmethod def _process_keyword_index(flask_app, dataset_id, document_id, documents): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() + dataset = db.session.get(Dataset, dataset_id) if not dataset: raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: document_ids = [document.metadata["doc_id"] for document in documents] - db.session.query(DocumentSegment).where( - DocumentSegment.document_id == document_id, - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == SegmentStatus.INDEXING, - ).update( - { - DocumentSegment.status: SegmentStatus.COMPLETED, - DocumentSegment.enabled: True, - DocumentSegment.completed_at: naive_utc_now(), - } + db.session.execute( + update(DocumentSegment) + .where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.index_node_id.in_(document_ids), + DocumentSegment.status == SegmentStatus.INDEXING, + ) + .values( + status=SegmentStatus.COMPLETED, + enabled=True, + completed_at=naive_utc_now(), + ) ) db.session.commit() @@ -703,17 +707,19 @@ class IndexingRunner: ) document_ids = [document.metadata["doc_id"] for document in chunk_documents] - db.session.query(DocumentSegment).where( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == SegmentStatus.INDEXING, - ).update( - { - DocumentSegment.status: SegmentStatus.COMPLETED, - DocumentSegment.enabled: True, - DocumentSegment.completed_at: naive_utc_now(), - } + db.session.execute( + update(DocumentSegment) + .where( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(document_ids), + DocumentSegment.status == SegmentStatus.INDEXING, + ) + .values( + status=SegmentStatus.COMPLETED, + enabled=True, + completed_at=naive_utc_now(), + ) ) db.session.commit() @@ -734,10 +740,17 @@ class IndexingRunner: """ Update the document indexing status. """ - count = db.session.query(DatasetDocument).filter_by(id=document_id, is_paused=True).count() + count = ( + db.session.scalar( + select(func.count()) + .select_from(DatasetDocument) + .where(DatasetDocument.id == document_id, DatasetDocument.is_paused == True) + ) + or 0 + ) if count > 0: raise DocumentIsPausedError() - document = db.session.query(DatasetDocument).filter_by(id=document_id).first() + document = db.session.get(DatasetDocument, document_id) if not document: raise DocumentIsDeletedPausedError() @@ -745,7 +758,7 @@ class IndexingRunner: if extra_update_params: update_params.update(extra_update_params) - db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) # type: ignore + db.session.execute(update(DatasetDocument).where(DatasetDocument.id == document_id).values(update_params)) # type: ignore db.session.commit() @staticmethod @@ -753,7 +766,9 @@ class IndexingRunner: """ Update the document segment by document id. """ - db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params) + db.session.execute( + update(DocumentSegment).where(DocumentSegment.document_id == dataset_document_id).values(update_params) + ) db.session.commit() def _transform( diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py index 06face41fe7..0047f6659d5 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py @@ -345,7 +345,7 @@ def test_generate_raises_when_workflow_not_found(generator, mocker): mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) session = MagicMock() - session.query.return_value.where.return_value.first.return_value = None + session.get.return_value = None mocker.patch.object(module.db, "session", session) with pytest.raises(ValueError): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index ab70996f0aa..c8ae288e6fe 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -80,9 +80,7 @@ def test_get_workflow_returns_workflow(mocker, runner): pipeline = MagicMock(tenant_id="tenant", id="pipe") workflow = MagicMock(id="wf") - query = MagicMock() - query.where.return_value.first.return_value = workflow - mocker.patch.object(module.db, "session", MagicMock(query=MagicMock(return_value=query))) + mocker.patch.object(module.db, "session", MagicMock(scalar=MagicMock(return_value=workflow))) result = runner.get_workflow(pipeline=pipeline, workflow_id="wf") @@ -115,11 +113,8 @@ def test_init_rag_pipeline_graph_not_found(mocker, runner): def test_update_document_status_on_failure(mocker, runner): document = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = document - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = document mocker.patch.object(module.db, "session", session) event = GraphRunFailedEvent(error="boom") @@ -189,14 +184,10 @@ def test_run_single_iteration_path(mocker): app_generate_entity.single_iteration_run = MagicMock() pipeline = MagicMock(id="pipe") - query_pipeline = MagicMock() - query_pipeline.where.return_value.first.return_value = pipeline - - query_end_user = MagicMock() - query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess") + end_user = MagicMock(session_id="sess") session = MagicMock() - session.query.side_effect = [query_end_user, query_pipeline] + session.get.side_effect = [end_user, pipeline] mocker.patch.object(module.db, "session", session) runner = PipelineRunner( @@ -241,14 +232,10 @@ def test_run_normal_path_builds_graph(mocker): app_generate_entity = _build_app_generate_entity() pipeline = MagicMock(id="pipe") - query_pipeline = MagicMock() - query_pipeline.where.return_value.first.return_value = pipeline - - query_end_user = MagicMock() - query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess") + end_user = MagicMock(session_id="sess") session = MagicMock() - session.query.side_effect = [query_end_user, query_pipeline] + session.get.side_effect = [end_user, pipeline] mocker.patch.object(module.db, "session", session) workflow = MagicMock( diff --git a/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py index 7cd1fdf06b2..4f39d38831d 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py @@ -287,9 +287,7 @@ class TestDatasourceFileManager: mock_upload_file.key = "some_key" mock_upload_file.mime_type = "image/png" - mock_query = mock_db.session.query.return_value - mock_where = mock_query.where.return_value - mock_where.first.return_value = mock_upload_file + mock_db.session.get.return_value = mock_upload_file mock_storage.load_once.return_value = b"file content" @@ -300,7 +298,7 @@ class TestDatasourceFileManager: assert result == (b"file content", "image/png") # Case: Not found - mock_where.first.return_value = None + mock_db.session.get.return_value = None assert DatasourceFileManager.get_file_binary("unknown") is None @patch("core.datasource.datasource_file_manager.db") @@ -314,16 +312,14 @@ class TestDatasourceFileManager: mock_tool_file.file_key = "tool_key" mock_tool_file.mimetype = "image/png" - # Mock query sequence - def mock_query(model): - m = MagicMock() + def mock_get(model, id): if model == MessageFile: - m.where.return_value.first.return_value = mock_message_file + return mock_message_file elif model == ToolFile: - m.where.return_value.first.return_value = mock_tool_file - return m + return mock_tool_file + return None - mock_db.session.query.side_effect = mock_query + mock_db.session.get.side_effect = mock_get mock_storage.load_once.return_value = b"tool content" # Execute @@ -344,15 +340,12 @@ class TestDatasourceFileManager: mock_tool_file.file_key = "tk" mock_tool_file.mimetype = "image/png" - def mock_query(model): - m = MagicMock() + def mock_get(model, id): if model == MessageFile: - m.where.return_value.first.return_value = mock_message_file - else: - m.where.return_value.first.return_value = mock_tool_file - return m + return mock_message_file + return mock_tool_file - mock_db.session.query.side_effect = mock_query + mock_db.session.get.side_effect = mock_get mock_storage.load_once.return_value = b"bits" result = DatasourceFileManager.get_file_binary_by_message_file_id("m") @@ -361,27 +354,20 @@ class TestDatasourceFileManager: @patch("core.datasource.datasource_file_manager.db") @patch("core.datasource.datasource_file_manager.storage") def test_get_file_binary_by_message_file_id_failures(self, mock_storage, mock_db): - # Setup common mock - mock_query_obj = MagicMock() - mock_db.session.query.return_value = mock_query_obj - mock_query_obj.where.return_value.first.return_value = None - # Case 1: Message file not found + mock_db.session.get.return_value = None assert DatasourceFileManager.get_file_binary_by_message_file_id("none") is None # Case 2: Message file found but tool file not found mock_message_file = MagicMock(spec=MessageFile) mock_message_file.url = None - def mock_query_v2(model): - m = MagicMock() + def mock_get_v2(model, id): if model == MessageFile: - m.where.return_value.first.return_value = mock_message_file - else: - m.where.return_value.first.return_value = None - return m + return mock_message_file + return None - mock_db.session.query.side_effect = mock_query_v2 + mock_db.session.get.side_effect = mock_get_v2 assert DatasourceFileManager.get_file_binary_by_message_file_id("msg_id") is None @patch("core.datasource.datasource_file_manager.db") @@ -392,7 +378,7 @@ class TestDatasourceFileManager: mock_upload_file.key = "upload_key" mock_upload_file.mime_type = "text/plain" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_upload_file + mock_db.session.get.return_value = mock_upload_file mock_storage.load_stream.return_value = iter([b"chunk1", b"chunk2"]) @@ -404,7 +390,7 @@ class TestDatasourceFileManager: assert list(stream) == [b"chunk1", b"chunk2"] # Case: Not found - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("none") assert stream is None assert mimetype is None diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index 450e7166360..641c5d9ba0f 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -795,33 +795,21 @@ class TestIndexingRunnerRun: doc = sample_dataset_documents[0] # Mock database queries - mock_dependencies["db"].session.get.return_value = doc - mock_dataset = Mock(spec=Dataset) mock_dataset.id = doc.dataset_id mock_dataset.tenant_id = doc.tenant_id mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY - mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset + + mock_current_user = MagicMock() + mock_current_user.set_tenant_id = MagicMock() + + get_dispatch = {"Document": doc, "Dataset": mock_dataset, "Account": mock_current_user} + mock_dependencies["db"].session.get.side_effect = lambda model, id: get_dispatch.get(model.__name__) mock_process_rule = Mock(spec=DatasetProcessRule) mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}} mock_dependencies["db"].session.scalar.return_value = mock_process_rule - # Mock current_user (Account) for _transform - mock_current_user = MagicMock() - mock_current_user.set_tenant_id = MagicMock() - - # Setup db.session.query to return different results based on the model - def mock_query_side_effect(model): - mock_query_result = MagicMock() - if model.__name__ == "Dataset": - mock_query_result.filter_by.return_value.first.return_value = mock_dataset - elif model.__name__ == "Account": - mock_query_result.filter_by.return_value.first.return_value = mock_current_user - return mock_query_result - - mock_dependencies["db"].session.query.side_effect = mock_query_side_effect - # Mock processor mock_processor = MagicMock() mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor @@ -891,10 +879,11 @@ class TestIndexingRunnerRun: doc = sample_dataset_documents[0] # Mock database - mock_dependencies["db"].session.get.return_value = doc - mock_dataset = Mock(spec=Dataset) - mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset + mock_dataset.tenant_id = doc.tenant_id + + get_dispatch = {"Document": doc, "Dataset": mock_dataset} + mock_dependencies["db"].session.get.side_effect = lambda model, id: get_dispatch.get(model.__name__) mock_process_rule = Mock(spec=DatasetProcessRule) mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}} @@ -917,11 +906,12 @@ class TestIndexingRunnerRun: runner = IndexingRunner() doc = sample_dataset_documents[0] - # Mock database to raise ObjectDeletedError - mock_dependencies["db"].session.get.return_value = doc - + # Mock database mock_dataset = Mock(spec=Dataset) - mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset + mock_dataset.tenant_id = doc.tenant_id + + get_dispatch = {"Document": doc, "Dataset": mock_dataset} + mock_dependencies["db"].session.get.side_effect = lambda model, id: get_dispatch.get(model.__name__) mock_process_rule = Mock(spec=DatasetProcessRule) mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}} @@ -945,17 +935,21 @@ class TestIndexingRunnerRun: docs = sample_dataset_documents # Mock database - def get_side_effect(model_class, doc_id): - for doc in docs: - if doc.id == doc_id: - return doc - return None - - mock_dependencies["db"].session.get.side_effect = get_side_effect - mock_dataset = Mock(spec=Dataset) mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY - mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset + mock_current_user = MagicMock() + mock_current_user.set_tenant_id = MagicMock() + + doc_map = {doc.id: doc for doc in docs} + model_dispatch = {"Dataset": mock_dataset, "Account": mock_current_user} + + def get_side_effect(model_class, id): + name = model_class.__name__ + if name == "Document": + return doc_map.get(id) + return model_dispatch.get(name) + + mock_dependencies["db"].session.get.side_effect = get_side_effect mock_process_rule = Mock(spec=DatasetProcessRule) mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}} @@ -1035,9 +1029,8 @@ class TestIndexingRunnerRetryLogic: mock_document = Mock(spec=DatasetDocument) mock_document.id = document_id - mock_dependencies["db"].session.query.return_value.filter_by.return_value.count.return_value = 0 - mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_document - mock_dependencies["db"].session.query.return_value.filter_by.return_value.update.return_value = None + mock_dependencies["db"].session.scalar.return_value = 0 + mock_dependencies["db"].session.get.return_value = mock_document # Act IndexingRunner._update_document_index_status( @@ -1053,7 +1046,7 @@ class TestIndexingRunnerRetryLogic: """Test document status update when document is paused.""" # Arrange document_id = str(uuid.uuid4()) - mock_dependencies["db"].session.query.return_value.filter_by.return_value.count.return_value = 1 + mock_dependencies["db"].session.scalar.return_value = 1 # Act & Assert with pytest.raises(DocumentIsPausedError): @@ -1063,8 +1056,8 @@ class TestIndexingRunnerRetryLogic: """Test document status update when document is deleted.""" # Arrange document_id = str(uuid.uuid4()) - mock_dependencies["db"].session.query.return_value.filter_by.return_value.count.return_value = 0 - mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = None + mock_dependencies["db"].session.scalar.return_value = 0 + mock_dependencies["db"].session.get.return_value = None # Act & Assert with pytest.raises(DocumentIsDeletedPausedError):