diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 2b8a3ee5949..d8f4e11e758 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -2,6 +2,7 @@ import threading from typing import Any import pytz +from sqlalchemy import select import contexts from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager @@ -23,25 +24,25 @@ class AgentService: contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - conversation: Conversation | None = ( - db.session.query(Conversation) + conversation: Conversation | None = db.session.scalar( + select(Conversation) .where( Conversation.id == conversation_id, Conversation.app_id == app_model.id, ) - .first() + .limit(1) ) if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Message | None = ( - db.session.query(Message) + message: Message | None = db.session.scalar( + select(Message) .where( Message.id == message_id, Message.conversation_id == conversation_id, ) - .first() + .limit(1) ) if not message: @@ -51,16 +52,11 @@ class AgentService: if conversation.from_end_user_id: # only select name field - executor = ( - db.session.query(EndUser, EndUser.name).where(EndUser.id == conversation.from_end_user_id).first() - ) + executor_name = db.session.scalar(select(EndUser.name).where(EndUser.id == conversation.from_end_user_id)) else: - executor = db.session.query(Account, Account.name).where(Account.id == conversation.from_account_id).first() + executor_name = db.session.scalar(select(Account.name).where(Account.id == conversation.from_account_id)) - if executor: - executor = executor.name - else: - executor = "Unknown" + executor = executor_name or "Unknown" assert isinstance(current_user, Account) assert current_user.timezone is not None timezone = pytz.timezone(current_user.timezone) diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 3a0ed41be04..fdb377694bb 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor from core.helper.encrypter import decrypt_token, encrypt_token from extensions.ext_database import db @@ -7,11 +9,12 @@ from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint class APIBasedExtensionService: @staticmethod def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: - extension_list = ( - db.session.query(APIBasedExtension) - .filter_by(tenant_id=tenant_id) - .order_by(APIBasedExtension.created_at.desc()) - .all() + extension_list = list( + db.session.scalars( + select(APIBasedExtension) + .where(APIBasedExtension.tenant_id == tenant_id) + .order_by(APIBasedExtension.created_at.desc()) + ).all() ) for extension in extension_list: @@ -36,11 +39,10 @@ class APIBasedExtensionService: @staticmethod def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = ( - db.session.query(APIBasedExtension) - .filter_by(tenant_id=tenant_id) - .filter_by(id=api_based_extension_id) - .first() + extension = db.session.scalar( + select(APIBasedExtension) + .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .limit(1) ) if not extension: @@ -58,23 +60,27 @@ class APIBasedExtensionService: if not extension_data.id: # case one: check new data, name must be unique - is_name_existed = ( - db.session.query(APIBasedExtension) - .filter_by(tenant_id=extension_data.tenant_id) - .filter_by(name=extension_data.name) - .first() + is_name_existed = db.session.scalar( + select(APIBasedExtension) + .where( + APIBasedExtension.tenant_id == extension_data.tenant_id, + APIBasedExtension.name == extension_data.name, + ) + .limit(1) ) if is_name_existed: raise ValueError("name must be unique, it is already existed") else: # case two: check existing data, name must be unique - is_name_existed = ( - db.session.query(APIBasedExtension) - .filter_by(tenant_id=extension_data.tenant_id) - .filter_by(name=extension_data.name) - .where(APIBasedExtension.id != extension_data.id) - .first() + is_name_existed = db.session.scalar( + select(APIBasedExtension) + .where( + APIBasedExtension.tenant_id == extension_data.tenant_id, + APIBasedExtension.name == extension_data.name, + APIBasedExtension.id != extension_data.id, + ) + .limit(1) ) if is_name_existed: diff --git a/api/services/app_service.py b/api/services/app_service.py index e9aeb6c43d0..87d52a3159c 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -6,6 +6,7 @@ import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from sqlalchemy import select from configs import dify_config from constants.model_template import default_app_templates @@ -433,9 +434,7 @@ class AppService: meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" elif provider_type == "api": try: - provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first() - ) + provider: ApiToolProvider | None = db.session.get(ApiToolProvider, provider_id) if provider is None: raise ValueError(f"provider not found for tool {tool_name}") meta["tool_icons"][tool_name] = json.loads(provider.icon) @@ -451,7 +450,7 @@ class AppService: :param app_id: app id :return: app code """ - site = db.session.query(Site).where(Site.app_id == app_id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_id).limit(1)) if not site: raise ValueError(f"App with id {app_id} not found") return str(site.code) @@ -463,7 +462,7 @@ class AppService: :param app_code: app code :return: app id """ - site = db.session.query(Site).where(Site.code == app_code).first() + site = db.session.scalar(select(Site).where(Site.code == app_code).limit(1)) if not site: raise ValueError(f"App with code {app_code} not found") return str(site.app_id) diff --git a/api/services/feedback_service.py b/api/services/feedback_service.py index e7473d371b9..d6c338a830d 100644 --- a/api/services/feedback_service.py +++ b/api/services/feedback_service.py @@ -4,7 +4,7 @@ import json from datetime import datetime from flask import Response -from sqlalchemy import or_ +from sqlalchemy import or_, select from extensions.ext_database import db from models.enums import FeedbackRating @@ -41,8 +41,8 @@ class FeedbackService: raise ValueError(f"Unsupported format: {format_type}") # Build base query - query = ( - db.session.query(MessageFeedback, Message, Conversation, App, Account) + stmt = ( + select(MessageFeedback, Message, Conversation, App, Account) .join(Message, MessageFeedback.message_id == Message.id) .join(Conversation, MessageFeedback.conversation_id == Conversation.id) .join(App, MessageFeedback.app_id == App.id) @@ -52,36 +52,36 @@ class FeedbackService: # Apply filters if from_source: - query = query.filter(MessageFeedback.from_source == from_source) + stmt = stmt.where(MessageFeedback.from_source == from_source) if rating: - query = query.filter(MessageFeedback.rating == rating) + stmt = stmt.where(MessageFeedback.rating == rating) if has_comment is not None: if has_comment: - query = query.filter(MessageFeedback.content.isnot(None), MessageFeedback.content != "") + stmt = stmt.where(MessageFeedback.content.isnot(None), MessageFeedback.content != "") else: - query = query.filter(or_(MessageFeedback.content.is_(None), MessageFeedback.content == "")) + stmt = stmt.where(or_(MessageFeedback.content.is_(None), MessageFeedback.content == "")) if start_date: try: start_dt = datetime.strptime(start_date, "%Y-%m-%d") - query = query.filter(MessageFeedback.created_at >= start_dt) + stmt = stmt.where(MessageFeedback.created_at >= start_dt) except ValueError: raise ValueError(f"Invalid start_date format: {start_date}. Use YYYY-MM-DD") if end_date: try: end_dt = datetime.strptime(end_date, "%Y-%m-%d") - query = query.filter(MessageFeedback.created_at <= end_dt) + stmt = stmt.where(MessageFeedback.created_at <= end_dt) except ValueError: raise ValueError(f"Invalid end_date format: {end_date}. Use YYYY-MM-DD") # Order by creation date (newest first) - query = query.order_by(MessageFeedback.created_at.desc()) + stmt = stmt.order_by(MessageFeedback.created_at.desc()) # Execute query - results = query.all() + results = db.session.execute(stmt).all() # Prepare data for export export_data = [] diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 215a8c85285..c3b00fe1094 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -6,6 +6,7 @@ from uuid import uuid4 import yaml from flask_login import current_user +from sqlalchemy import select from constants import DOCUMENT_EXTENSIONS from core.plugin.impl.plugin import PluginInstaller @@ -26,7 +27,7 @@ logger = logging.getLogger(__name__) class RagPipelineTransformService: def transform_dataset(self, dataset_id: str): - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = db.session.get(Dataset, dataset_id) if not dataset: raise ValueError("Dataset not found") if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE: @@ -306,7 +307,7 @@ class RagPipelineTransformService: jina_node_id = "1752491761974" firecrawl_node_id = "1752565402678" - documents = db.session.query(Document).where(Document.dataset_id == dataset.id).all() + documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset.id)).all() for document in documents: data_source_info_dict = document.data_source_info_dict @@ -316,7 +317,7 @@ class RagPipelineTransformService: document.data_source_type = DataSourceType.LOCAL_FILE file_id = data_source_info_dict.get("upload_file_id") if file_id: - file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + file = db.session.get(UploadFile, file_id) if file: data_source_info = json.dumps( { diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 6b211a5632b..9819822103b 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from configs import dify_config from extensions.ext_database import db from models.model import AccountTrialAppRecord, TrialApp @@ -27,7 +29,7 @@ class RecommendedAppService: apps = result["recommended_apps"] for app in apps: app_id = app["app_id"] - trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first() + trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1)) if trial_app_model: app["can_trial"] = True else: @@ -46,7 +48,7 @@ class RecommendedAppService: result: dict = retrieval_instance.get_recommend_app_detail(app_id) if FeatureService.get_system_features().enable_trial_app: app_id = result["id"] - trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first() + trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1)) if trial_app_model: result["can_trial"] = True else: @@ -60,10 +62,10 @@ class RecommendedAppService: :param app_id: app id :return: """ - account_trial_app_record = ( - db.session.query(AccountTrialAppRecord) + account_trial_app_record = db.session.scalar( + select(AccountTrialAppRecord) .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id) - .first() + .limit(1) ) if account_trial_app_record: account_trial_app_record.count += 1 diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index d0f4f279683..77d1767c4be 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -1,5 +1,7 @@ from typing import Union +from sqlalchemy import select + from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account @@ -16,16 +18,15 @@ class SavedMessageService: ) -> InfiniteScrollPagination: if not user: raise ValueError("User is required") - saved_messages = ( - db.session.query(SavedMessage) + saved_messages = db.session.scalars( + select(SavedMessage) .where( SavedMessage.app_id == app_model.id, SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), SavedMessage.created_by == user.id, ) .order_by(SavedMessage.created_at.desc()) - .all() - ) + ).all() message_ids = [sm.message_id for sm in saved_messages] return MessageService.pagination_by_last_id( @@ -36,15 +37,15 @@ class SavedMessageService: def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): if not user: return - saved_message = ( - db.session.query(SavedMessage) + saved_message = db.session.scalar( + select(SavedMessage) .where( SavedMessage.app_id == app_model.id, SavedMessage.message_id == message_id, SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), SavedMessage.created_by == user.id, ) - .first() + .limit(1) ) if saved_message: @@ -66,15 +67,15 @@ class SavedMessageService: def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): if not user: return - saved_message = ( - db.session.query(SavedMessage) + saved_message = db.session.scalar( + select(SavedMessage) .where( SavedMessage.app_id == app_model.id, SavedMessage.message_id == message_id, SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), SavedMessage.created_by == user.id, ) - .first() + .limit(1) ) if not saved_message: diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 8e3c36e0998..f7447d3c104 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -332,12 +332,11 @@ class BuiltinToolManageService: get builtin tool provider credentials """ with db.session.no_autoflush: - providers = ( - db.session.query(BuiltinToolProvider) - .filter_by(tenant_id=tenant_id, provider=provider_name) + providers = db.session.scalars( + select(BuiltinToolProvider) + .where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .all() - ) + ).all() if len(providers) == 0: return [] diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 3f78b823a63..e7266cb8e94 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,6 +1,7 @@ import logging from graphon.model_runtime.entities.model_entities import ModelType +from sqlalchemy import delete, select from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword @@ -29,7 +30,7 @@ class VectorService: for segment in segments: if doc_form == IndexStructureType.PARENT_CHILD_INDEX: - dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first() + dataset_document = db.session.get(DatasetDocument, segment.document_id) if not dataset_document: logger.warning( "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s", @@ -38,11 +39,7 @@ class VectorService: ) continue # get the process rule - processing_rule = ( - db.session.query(DatasetProcessRule) - .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) - .first() - ) + processing_rule = db.session.get(DatasetProcessRule, dataset_document.dataset_process_rule_id) if not processing_rule: raise ValueError("No processing rule found.") # get embedding model instance @@ -271,8 +268,8 @@ class VectorService: vector.delete_by_ids(old_attachment_ids) # Delete existing segment attachment bindings in one operation - db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete( - synchronize_session=False + db.session.execute( + delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id) ) if not attachment_ids: @@ -280,7 +277,7 @@ class VectorService: return # Bulk fetch upload files - only fetch needed fields - upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all() + upload_file_list = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all() if not upload_file_list: db.session.commit() diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 3b3ee6dd92e..8f365c7c51f 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -138,14 +138,14 @@ class WorkflowService: if workflow_id: return self.get_published_workflow_by_id(app_model, workflow_id) # fetch draft workflow by app_model - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == Workflow.VERSION_DRAFT, ) - .first() + .limit(1) ) # return draft workflow @@ -155,14 +155,14 @@ class WorkflowService: """ fetch published workflow by workflow_id """ - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id, ) - .first() + .limit(1) ) if not workflow: return None @@ -182,14 +182,14 @@ class WorkflowService: return None # fetch published workflow by workflow_id - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == app_model.workflow_id, ) - .first() + .limit(1) ) return workflow @@ -544,14 +544,14 @@ class WorkflowService: # Use the same fallback logic as runtime: get the first available credential # ordered by is_default DESC, created_at ASC (same as tool_manager.py) - default_provider = ( - db.session.query(BuiltinToolProvider) + default_provider = db.session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider, ) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .first() + .limit(1) ) if not default_provider: diff --git a/api/tests/test_containers_integration_tests/services/test_feedback_service.py b/api/tests/test_containers_integration_tests/services/test_feedback_service.py index 771f4067753..d82933ccb90 100644 --- a/api/tests/test_containers_integration_tests/services/test_feedback_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feedback_service.py @@ -99,7 +99,7 @@ class TestFeedbackService: ) ] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query # Test CSV export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") @@ -138,7 +138,7 @@ class TestFeedbackService: ) ] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query # Test JSON export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") @@ -175,7 +175,7 @@ class TestFeedbackService: ) ] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query # Test with filters result = FeedbackService.export_feedbacks( @@ -188,11 +188,8 @@ class TestFeedbackService: format_type="csv", ) - # Verify filters were applied - assert mock_query.filter.called - filter_calls = mock_query.filter.call_args_list - # At least three filter invocations are expected (source, rating, comment) - assert len(filter_calls) >= 3 + # Verify query was executed (filters are baked into the select statement) + assert mock_db_session.execute.called def test_export_feedbacks_no_data(self, mock_db_session, sample_data): """Test exporting feedback when no data exists.""" @@ -206,7 +203,7 @@ class TestFeedbackService: mock_query.order_by.return_value = mock_query mock_query.all.return_value = [] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") @@ -271,7 +268,7 @@ class TestFeedbackService: ) ] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query # Test export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") @@ -329,7 +326,7 @@ class TestFeedbackService: ) ] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query # Test export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv") @@ -367,7 +364,7 @@ class TestFeedbackService: ), ] - mock_db_session.query.return_value = mock_query + mock_db_session.execute.return_value = mock_query # Test export result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json") diff --git a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py index 598ff3fc3a4..a78a033f4d3 100644 --- a/api/tests/unit_tests/services/test_vector_service.py +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -77,22 +77,12 @@ def _make_segment( def _mock_db_session_for_update_multimodel(*, upload_files: list[_UploadFileStub] | None) -> MagicMock: session = MagicMock(name="session") - binding_query = MagicMock(name="binding_query") - binding_query.where.return_value = binding_query - binding_query.delete.return_value = 1 + # db.session.execute() is used for delete(SegmentAttachmentBinding).where(...) + session.execute = MagicMock(name="execute") - upload_query = MagicMock(name="upload_query") - upload_query.where.return_value = upload_query - upload_query.all.return_value = upload_files or [] + # db.session.scalars(select(UploadFile).where(...)).all() returns upload files + session.scalars.return_value.all.return_value = upload_files or [] - def query_side_effect(model: object) -> MagicMock: - if model is vector_service_module.SegmentAttachmentBinding: - return binding_query - if model is vector_service_module.UploadFile: - return upload_query - return MagicMock(name=f"query({model})") - - session.query.side_effect = query_side_effect db_mock = MagicMock(name="db") db_mock.session = session return db_mock @@ -165,22 +155,15 @@ def _mock_parent_child_queries( ) -> MagicMock: session = MagicMock(name="session") - doc_query = MagicMock(name="doc_query") - doc_query.filter_by.return_value = doc_query - doc_query.first.return_value = dataset_document + get_dispatch: dict[object, object | None] = { + vector_service_module.DatasetDocument: dataset_document, + vector_service_module.DatasetProcessRule: processing_rule, + } - rule_query = MagicMock(name="rule_query") - rule_query.where.return_value = rule_query - rule_query.first.return_value = processing_rule + def get_side_effect(model: object, pk: object) -> object | None: + return get_dispatch.get(model) - def query_side_effect(model: object) -> MagicMock: - if model is vector_service_module.DatasetDocument: - return doc_query - if model is vector_service_module.DatasetProcessRule: - return rule_query - return MagicMock(name=f"query({model})") - - session.query.side_effect = query_side_effect + session.get.side_effect = get_side_effect db_mock = MagicMock(name="db") db_mock.session = session return db_mock @@ -609,7 +592,7 @@ def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( vector_cls.assert_called_once_with(dataset=dataset) vector_instance.delete_by_ids.assert_called_once_with(["old-1", "old-2"]) - db_mock.session.query.assert_called_once_with(vector_service_module.SegmentAttachmentBinding) + db_mock.session.execute.assert_called_once() db_mock.session.commit.assert_called_once() db_mock.session.add_all.assert_not_called() vector_instance.add_texts.assert_not_called() @@ -644,6 +627,8 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up binding_ctor = MagicMock(side_effect=lambda **kwargs: kwargs) monkeypatch.setattr(vector_service_module, "SegmentAttachmentBinding", binding_ctor) + monkeypatch.setattr(vector_service_module, "delete", MagicMock()) + monkeypatch.setattr(vector_service_module, "select", MagicMock()) logger_mock = MagicMock() monkeypatch.setattr(vector_service_module, "logger", logger_mock) @@ -677,6 +662,8 @@ def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops monkeypatch.setattr( vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs) ) + monkeypatch.setattr(vector_service_module, "delete", MagicMock()) + monkeypatch.setattr(vector_service_module, "select", MagicMock()) VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset) @@ -698,6 +685,8 @@ def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: monkeypatch.setattr( vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs) ) + monkeypatch.setattr(vector_service_module, "delete", MagicMock()) + monkeypatch.setattr(vector_service_module, "select", MagicMock()) logger_mock = MagicMock() monkeypatch.setattr(vector_service_module, "logger", logger_mock) diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index cd71981bcf1..1b253eb2f1f 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -268,7 +268,7 @@ class TestWorkflowService: Provides mock implementations of: - session.add(): Adding new records - session.commit(): Committing transactions - - session.query(): Querying database + - session.scalar(): Scalar queries - session.execute(): Executing SQL statements """ with patch("services.workflow_service.db") as mock_db: @@ -276,7 +276,7 @@ class TestWorkflowService: mock_db.session = mock_session mock_session.add = MagicMock() mock_session.commit = MagicMock() - mock_session.query = MagicMock() + mock_session.scalar = MagicMock() mock_session.execute = MagicMock() yield mock_db @@ -338,10 +338,8 @@ class TestWorkflowService: app = TestWorkflowAssociatedDataFactory.create_app_mock() mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock() - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + # Mock db.session.scalar() used by get_draft_workflow + mock_db_session.session.scalar.return_value = mock_workflow result = workflow_service.get_draft_workflow(app) @@ -351,10 +349,8 @@ class TestWorkflowService: """Test get_draft_workflow returns None when no draft exists.""" app = TestWorkflowAssociatedDataFactory.create_app_mock() - # Mock database query to return None - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = None + # Mock db.session.scalar() to return None + mock_db_session.session.scalar.return_value = None result = workflow_service.get_draft_workflow(app) @@ -366,10 +362,8 @@ class TestWorkflowService: workflow_id = "workflow-123" mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1") - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + # Mock db.session.scalar() used by get_published_workflow_by_id + mock_db_session.session.scalar.return_value = mock_workflow result = workflow_service.get_draft_workflow(app, workflow_id=workflow_id) @@ -384,10 +378,8 @@ class TestWorkflowService: workflow_id = "workflow-123" mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1") - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + # Mock db.session.scalar() used by get_published_workflow_by_id + mock_db_session.session.scalar.return_value = mock_workflow result = workflow_service.get_published_workflow_by_id(app, workflow_id) @@ -406,10 +398,8 @@ class TestWorkflowService: workflow_id=workflow_id, version=Workflow.VERSION_DRAFT ) - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + # Mock db.session.scalar() used by get_published_workflow_by_id + mock_db_session.session.scalar.return_value = mock_workflow with pytest.raises(IsDraftWorkflowError): workflow_service.get_published_workflow_by_id(app, workflow_id) @@ -419,10 +409,8 @@ class TestWorkflowService: app = TestWorkflowAssociatedDataFactory.create_app_mock() workflow_id = "nonexistent-workflow" - # Mock database query to return None - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = None + # Mock db.session.scalar() to return None + mock_db_session.session.scalar.return_value = None result = workflow_service.get_published_workflow_by_id(app, workflow_id) @@ -434,10 +422,8 @@ class TestWorkflowService: app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=workflow_id) mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1") - # Mock database query - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + # Mock db.session.scalar() used by get_published_workflow + mock_db_session.session.scalar.return_value = mock_workflow result = workflow_service.get_published_workflow(app) @@ -466,11 +452,9 @@ class TestWorkflowService: graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph() features = {"file_upload": {"enabled": False}} - # Mock get_draft_workflow to return None (no existing draft) + # Mock db.session.scalar() to return None (no existing draft) # This simulates the first time a workflow is created for an app - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = None + mock_db_session.session.scalar.return_value = None with ( patch.object(workflow_service, "validate_features_structure"), @@ -504,12 +488,10 @@ class TestWorkflowService: features = {"file_upload": {"enabled": False}} unique_hash = "test-hash-123" - # Mock existing draft workflow + # Mock existing draft workflow via db.session.scalar() mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash=unique_hash) - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + mock_db_session.session.scalar.return_value = mock_workflow with ( patch.object(workflow_service, "validate_features_structure"), @@ -545,12 +527,10 @@ class TestWorkflowService: graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph() features = {} - # Mock existing draft workflow with different hash + # Mock existing draft workflow with different hash via db.session.scalar() mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash="old-hash") - mock_query = MagicMock() - mock_db_session.session.query.return_value = mock_query - mock_query.where.return_value.first.return_value = mock_workflow + mock_db_session.session.scalar.return_value = mock_workflow with pytest.raises(WorkflowHashNotEqualError): workflow_service.sync_draft_workflow( diff --git a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py index 439d203c58d..175900071b1 100644 --- a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py +++ b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py @@ -347,7 +347,7 @@ class TestGetBuiltinToolProviderCredentials: def test_returns_empty_when_no_providers(self, mock_db): mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None) mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False) - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google") @@ -362,7 +362,7 @@ class TestGetBuiltinToolProviderCredentials: mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False) provider = MagicMock(provider="google", is_default=False) - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [provider] + mock_db.session.scalars.return_value.all.return_value = [provider] mock_encrypter = MagicMock() mock_encrypter.decrypt.return_value = {"key": "decrypted"}