From 4e1d0604391e2df11c6df7b3864b9121e9304fe8 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Wed, 1 Apr 2026 18:37:27 +0200 Subject: [PATCH] refactor: select in message_service and ops_service (#34414) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/message_service.py | 57 ++++--- api/services/ops_service.py | 32 ++-- .../services/test_message_service.py | 147 +++--------------- .../unit_tests/services/test_ops_service.py | 53 ++++--- 4 files changed, 94 insertions(+), 195 deletions(-) diff --git a/api/services/message_service.py b/api/services/message_service.py index a04f9cbe012..5c2978db21b 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -3,6 +3,7 @@ from typing import Union from graphon.model_runtime.entities.model_entities import ModelType from pydantic import TypeAdapter +from sqlalchemy import select from sqlalchemy.orm import sessionmaker from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -75,17 +76,15 @@ class MessageService: fetch_limit = limit + 1 if first_id: - first_message = ( - db.session.query(Message) - .where(Message.conversation_id == conversation.id, Message.id == first_id) - .first() + first_message = db.session.scalar( + select(Message).where(Message.conversation_id == conversation.id, Message.id == first_id).limit(1) ) if not first_message: raise FirstMessageNotExistsError() - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where( Message.conversation_id == conversation.id, Message.created_at < first_message.created_at, @@ -93,16 +92,14 @@ class MessageService: ) .order_by(Message.created_at.desc()) .limit(fetch_limit) - .all() - ) + ).all() else: - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where(Message.conversation_id == conversation.id) .order_by(Message.created_at.desc()) .limit(fetch_limit) - .all() - ) + ).all() has_more = False if len(history_messages) > limit: @@ -129,7 +126,7 @@ class MessageService: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) - base_query = db.session.query(Message) + stmt = select(Message) fetch_limit = limit + 1 @@ -138,28 +135,27 @@ class MessageService: app_model=app_model, user=user, conversation_id=conversation_id ) - base_query = base_query.where(Message.conversation_id == conversation.id) + stmt = stmt.where(Message.conversation_id == conversation.id) # Check if include_ids is not None and not empty to avoid WHERE false condition if include_ids is not None: if len(include_ids) == 0: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) - base_query = base_query.where(Message.id.in_(include_ids)) + stmt = stmt.where(Message.id.in_(include_ids)) if last_id: - last_message = base_query.where(Message.id == last_id).first() + last_message = db.session.scalar(stmt.where(Message.id == last_id).limit(1)) if not last_message: raise LastMessageNotExistsError() - history_messages = ( - base_query.where(Message.created_at < last_message.created_at, Message.id != last_message.id) + history_messages = db.session.scalars( + stmt.where(Message.created_at < last_message.created_at, Message.id != last_message.id) .order_by(Message.created_at.desc()) .limit(fetch_limit) - .all() - ) + ).all() else: - history_messages = base_query.order_by(Message.created_at.desc()).limit(fetch_limit).all() + history_messages = db.session.scalars(stmt.order_by(Message.created_at.desc()).limit(fetch_limit)).all() has_more = False if len(history_messages) > limit: @@ -214,21 +210,20 @@ class MessageService: def get_all_messages_feedbacks(cls, app_model: App, page: int, limit: int): """Get all feedbacks of an app""" offset = (page - 1) * limit - feedbacks = ( - db.session.query(MessageFeedback) + feedbacks = db.session.scalars( + select(MessageFeedback) .where(MessageFeedback.app_id == app_model.id) .order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc()) .limit(limit) .offset(offset) - .all() - ) + ).all() return [record.to_dict() for record in feedbacks] @classmethod def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): - message = ( - db.session.query(Message) + message = db.session.scalar( + select(Message) .where( Message.id == message_id, Message.app_id == app_model.id, @@ -236,7 +231,7 @@ class MessageService: Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), Message.from_account_id == (user.id if isinstance(user, Account) else None), ) - .first() + .limit(1) ) if not message: @@ -282,10 +277,10 @@ class MessageService: ) else: if not conversation.override_model_configs: - app_model_config = ( - db.session.query(AppModelConfig) + app_model_config = db.session.scalar( + select(AppModelConfig) .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) - .first() + .limit(1) ) else: conversation_override_model_configs = _app_model_config_adapter.validate_json( diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 50ea832085a..2a64088dd66 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,5 +1,7 @@ from typing import Any +from sqlalchemy import select + from core.ops.entities.config_entity import BaseTracingConfig from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map from extensions.ext_database import db @@ -15,17 +17,17 @@ class OpsService: :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig | None = ( - db.session.query(TraceAppConfig) + trace_config_data: TraceAppConfig | None = db.session.scalar( + select(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) - .first() + .limit(1) ) if not trace_config_data: return None # decrypt_token and obfuscated_token - app = db.session.query(App).where(App.id == app_id).first() + app = db.session.get(App, app_id) if not app: return None tenant_id = app.tenant_id @@ -182,17 +184,17 @@ class OpsService: project_url = None # check if trace config already exists - trace_config_data: TraceAppConfig | None = ( - db.session.query(TraceAppConfig) + trace_config_data: TraceAppConfig | None = db.session.scalar( + select(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) - .first() + .limit(1) ) if trace_config_data: return None # get tenant id - app = db.session.query(App).where(App.id == app_id).first() + app = db.session.get(App, app_id) if not app: return None tenant_id = app.tenant_id @@ -224,17 +226,17 @@ class OpsService: raise ValueError(f"Invalid tracing provider: {tracing_provider}") # check if trace config already exists - current_trace_config = ( - db.session.query(TraceAppConfig) + current_trace_config = db.session.scalar( + select(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) - .first() + .limit(1) ) if not current_trace_config: return None # get tenant id - app = db.session.query(App).where(App.id == app_id).first() + app = db.session.get(App, app_id) if not app: return None tenant_id = app.tenant_id @@ -261,10 +263,10 @@ class OpsService: :param tracing_provider: tracing provider :return: """ - trace_config = ( - db.session.query(TraceAppConfig) + trace_config = db.session.scalar( + select(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) - .first() + .limit(1) ) if not trace_config: diff --git a/api/tests/unit_tests/services/test_message_service.py b/api/tests/unit_tests/services/test_message_service.py index 101b9bff24d..b6e990ebe0f 100644 --- a/api/tests/unit_tests/services/test_message_service.py +++ b/api/tests/unit_tests/services/test_message_service.py @@ -151,12 +151,7 @@ class TestMessageServicePaginationByFirstId: for i in range(5) ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_first_id( @@ -196,12 +191,7 @@ class TestMessageServicePaginationByFirstId: for i in range(5) ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_first_id( @@ -246,31 +236,8 @@ class TestMessageServicePaginationByFirstId: for i in range(5) ] - # Setup query mocks - mock_query_first = MagicMock() - mock_query_history = MagicMock() - - query_calls = [] - - def query_side_effect(*args): - if args[0] == Message: - query_calls.append(args) - if len(query_calls) == 1: - return mock_query_first - else: - return mock_query_history - - mock_db.session.query.side_effect = [mock_query_first, mock_query_history] - - # Setup first message query - mock_query_first.where.return_value = mock_query_first - mock_query_first.first.return_value = first_message - - # Setup history messages query - mock_query_history.where.return_value = mock_query_history - mock_query_history.order_by.return_value = mock_query_history - mock_query_history.limit.return_value = mock_query_history - mock_query_history.all.return_value = history_messages + mock_db.session.scalar.return_value = first_message + mock_db.session.scalars.return_value.all.return_value = history_messages # Act result = MessageService.pagination_by_first_id( @@ -285,8 +252,6 @@ class TestMessageServicePaginationByFirstId: # Assert assert len(result.data) == 5 assert result.has_more is False - mock_query_first.where.assert_called_once() - mock_query_history.where.assert_called_once() # Test 06: First message not found @patch("services.message_service.db") @@ -300,10 +265,7 @@ class TestMessageServicePaginationByFirstId: mock_conversation_service.get_conversation.return_value = conversation - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # Message not found + mock_db.session.scalar.return_value = None # Message not found # Act & Assert with pytest.raises(FirstMessageNotExistsError): @@ -336,12 +298,7 @@ class TestMessageServicePaginationByFirstId: for i in range(11) ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_first_id( @@ -369,12 +326,7 @@ class TestMessageServicePaginationByFirstId: mock_conversation_service.get_conversation.return_value = conversation - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] # Act result = MessageService.pagination_by_first_id( @@ -443,12 +395,7 @@ class TestMessageServicePaginationByLastId: for i in range(5) ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_last_id( @@ -485,22 +432,8 @@ class TestMessageServicePaginationByLastId: for i in range(6, 10) ] - # Setup base query mock that returns itself for chaining - mock_base_query = MagicMock() - mock_db.session.query.return_value = mock_base_query - - # First where() call for last_id lookup - mock_query_last = MagicMock() - mock_query_last.first.return_value = last_message - - # Second where() call for history messages - mock_query_history = MagicMock() - mock_query_history.order_by.return_value = mock_query_history - mock_query_history.limit.return_value = mock_query_history - mock_query_history.all.return_value = new_messages - - # Setup where() to return different mocks on consecutive calls - mock_base_query.where.side_effect = [mock_query_last, mock_query_history] + mock_db.session.scalar.return_value = last_message + mock_db.session.scalars.return_value.all.return_value = new_messages # Act result = MessageService.pagination_by_last_id( @@ -522,10 +455,7 @@ class TestMessageServicePaginationByLastId: app = factory.create_app_mock() user = factory.create_end_user_mock() - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # Message not found + mock_db.session.scalar.return_value = None # Message not found # Act & Assert with pytest.raises(LastMessageNotExistsError): @@ -557,12 +487,7 @@ class TestMessageServicePaginationByLastId: for i in range(5) ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_last_id( @@ -576,8 +501,6 @@ class TestMessageServicePaginationByLastId: # Assert assert len(result.data) == 5 assert result.has_more is False - # Verify conversation_id was used in query - mock_query.where.assert_called() mock_conversation_service.get_conversation.assert_called_once() # Test 14: Pagination with include_ids filter @@ -594,12 +517,7 @@ class TestMessageServicePaginationByLastId: factory.create_message_mock(message_id="msg-003"), ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_last_id( @@ -632,12 +550,7 @@ class TestMessageServicePaginationByLastId: for i in range(11) ] - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.all.return_value = messages + mock_db.session.scalars.return_value.all.return_value = messages # Act result = MessageService.pagination_by_last_id( @@ -743,17 +656,13 @@ class TestMessageServiceGetMessage: user = factory.create_end_user_mock(user_id="end-user-123") message = factory.create_message_mock() - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = message + mock_db.session.scalar.return_value = message # Act result = MessageService.get_message(app_model=app, user=user, message_id="msg-123") # Assert assert result == message - mock_query.where.assert_called_once() # Test 21: get_message success for Account (Admin) @patch("services.message_service.db") @@ -767,10 +676,7 @@ class TestMessageServiceGetMessage: user.id = "account-123" message = factory.create_message_mock() - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = message + mock_db.session.scalar.return_value = message # Act result = MessageService.get_message(app_model=app, user=user, message_id="msg-123") @@ -786,10 +692,7 @@ class TestMessageServiceGetMessage: app = factory.create_app_mock() user = factory.create_end_user_mock() - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(MessageNotExistsError): @@ -899,21 +802,13 @@ class TestMessageServiceFeedback: feedback = MagicMock() feedback.to_dict.return_value = {"id": "fb-1"} - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.limit.return_value = mock_query - mock_query.offset.return_value = mock_query - mock_query.all.return_value = [feedback] + mock_db.session.scalars.return_value.all.return_value = [feedback] # Act result = MessageService.get_all_messages_feedbacks(app_model=app, page=1, limit=10) # Assert assert result == [{"id": "fb-1"}] - mock_query.limit.assert_called_with(10) - mock_query.offset.assert_called_with(0) class TestMessageServiceSuggestedQuestions: @@ -1015,10 +910,7 @@ class TestMessageServiceSuggestedQuestions: app_model_config.suggested_questions_after_answer_dict = {"enabled": True} app_model_config.model_dict = {"provider": "openai", "name": "gpt-4"} - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = app_model_config + mock_db.session.scalar.return_value = app_model_config mock_llm_gen.generate_suggested_questions_after_answer.return_value = ["Q1?"] @@ -1029,7 +921,6 @@ class TestMessageServiceSuggestedQuestions: # Assert assert result == ["Q1?"] - mock_query.first.assert_called_once() mock_llm_gen.generate_suggested_questions_after_answer.assert_called_once() # Test 30: get_suggested_questions_after_answer - Disabled Error diff --git a/api/tests/unit_tests/services/test_ops_service.py b/api/tests/unit_tests/services/test_ops_service.py index ab7b473790a..7067e3b3dd4 100644 --- a/api/tests/unit_tests/services/test_ops_service.py +++ b/api/tests/unit_tests/services/test_ops_service.py @@ -12,28 +12,27 @@ class TestOpsService: @patch("services.ops_service.OpsTraceManager") def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act result = OpsService.get_tracing_app_config("app_id", "arize") # Assert assert result is None - mock_db.session.query.assert_called_with(TraceAppConfig) @patch("services.ops_service.db") @patch("services.ops_service.OpsTraceManager") def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): # Arrange trace_config = MagicMock(spec=TraceAppConfig) - mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, None] + mock_db.session.scalar.return_value = trace_config + mock_db.session.get.return_value = None # Act result = OpsService.get_tracing_app_config("app_id", "arize") # Assert assert result is None - assert mock_db.session.query.call_count == 2 @patch("services.ops_service.db") @patch("services.ops_service.OpsTraceManager") @@ -43,7 +42,8 @@ class TestOpsService: trace_config.tracing_config = None app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + mock_db.session.scalar.return_value = trace_config + mock_db.session.get.return_value = app # Act & Assert with pytest.raises(ValueError, match="Tracing config cannot be None."): @@ -72,7 +72,8 @@ class TestOpsService: trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}} app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + mock_db.session.scalar.return_value = trace_config + mock_db.session.get.return_value = app mock_ops_trace_manager.decrypt_tracing_config.return_value = {} mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} @@ -97,7 +98,8 @@ class TestOpsService: trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}} app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + mock_db.session.scalar.return_value = trace_config + mock_db.session.get.return_value = app mock_ops_trace_manager.decrypt_tracing_config.return_value = {} mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} @@ -118,7 +120,8 @@ class TestOpsService: trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}} app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + mock_db.session.scalar.return_value = trace_config + mock_db.session.get.return_value = app mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} @@ -139,7 +142,8 @@ class TestOpsService: trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}} app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + mock_db.session.scalar.return_value = trace_config + mock_db.session.get.return_value = app mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} @@ -189,7 +193,7 @@ class TestOpsService: mock_ops_trace_manager.check_trace_config_is_effective.return_value = True mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig) + mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig) # Act result = OpsService.create_tracing_app_config("app_id", provider, config) @@ -206,7 +210,8 @@ class TestOpsService: mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_db.session.scalar.return_value = None + mock_db.session.get.return_value = app mock_ops_trace_manager.encrypt_tracing_config.return_value = {} # Act @@ -223,7 +228,7 @@ class TestOpsService: # Arrange provider = TracingProviderEnum.ARIZE mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig) + mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig) # Act result = OpsService.create_tracing_app_config("app_id", provider, {}) @@ -237,7 +242,8 @@ class TestOpsService: # Arrange provider = TracingProviderEnum.ARIZE mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_db.session.query.return_value.where.return_value.first.side_effect = [None, None] + mock_db.session.scalar.return_value = None + mock_db.session.get.return_value = None # Act result = OpsService.create_tracing_app_config("app_id", provider, {}) @@ -253,7 +259,8 @@ class TestOpsService: mock_ops_trace_manager.check_trace_config_is_effective.return_value = True app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_db.session.scalar.return_value = None + mock_db.session.get.return_value = app mock_ops_trace_manager.encrypt_tracing_config.return_value = {} # Act @@ -274,7 +281,8 @@ class TestOpsService: mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url" app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_db.session.scalar.return_value = None + mock_db.session.get.return_value = app mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"} # Act @@ -297,7 +305,7 @@ class TestOpsService: def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): # Arrange provider = TracingProviderEnum.ARIZE - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act result = OpsService.update_tracing_app_config("app_id", provider, {}) @@ -311,7 +319,8 @@ class TestOpsService: # Arrange provider = TracingProviderEnum.ARIZE current_config = MagicMock(spec=TraceAppConfig) - mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, None] + mock_db.session.scalar.return_value = current_config + mock_db.session.get.return_value = None # Act result = OpsService.update_tracing_app_config("app_id", provider, {}) @@ -327,7 +336,8 @@ class TestOpsService: current_config = MagicMock(spec=TraceAppConfig) app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app] + mock_db.session.scalar.return_value = current_config + mock_db.session.get.return_value = app mock_ops_trace_manager.decrypt_tracing_config.return_value = {} mock_ops_trace_manager.check_trace_config_is_effective.return_value = False @@ -344,7 +354,8 @@ class TestOpsService: current_config.to_dict.return_value = {"some": "data"} app = MagicMock(spec=App) app.tenant_id = "tenant_id" - mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app] + mock_db.session.scalar.return_value = current_config + mock_db.session.get.return_value = app mock_ops_trace_manager.decrypt_tracing_config.return_value = {} mock_ops_trace_manager.check_trace_config_is_effective.return_value = True @@ -358,7 +369,7 @@ class TestOpsService: @patch("services.ops_service.db") def test_delete_tracing_app_config_no_config(self, mock_db): # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act result = OpsService.delete_tracing_app_config("app_id", "arize") @@ -370,7 +381,7 @@ class TestOpsService: def test_delete_tracing_app_config_success(self, mock_db): # Arrange trace_config = MagicMock(spec=TraceAppConfig) - mock_db.session.query.return_value.where.return_value.first.return_value = trace_config + mock_db.session.scalar.return_value = trace_config # Act result = OpsService.delete_tracing_app_config("app_id", "arize")