From 68e4d13f36aeee4b1527b180e7ebfb56dc2fbb41 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Fri, 3 Apr 2026 17:47:22 -0500 Subject: [PATCH] refactor: select in annotation_service (#34503) Co-authored-by: Asuka Minato Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/annotation_service.py | 163 +++---- .../services/test_annotation_service.py | 431 ++++-------------- 2 files changed, 151 insertions(+), 443 deletions(-) diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index b472a269505..ae5facbec09 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -6,7 +6,7 @@ import pandas as pd logger = logging.getLogger(__name__) from typing import TypedDict -from sqlalchemy import or_, select +from sqlalchemy import delete, or_, select, update from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound @@ -51,10 +51,8 @@ class AppAnnotationService: def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info current_user, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: @@ -66,7 +64,9 @@ class AppAnnotationService: if args.get("message_id"): message_id = str(args["message_id"]) - message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app.id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1) + ) if not message: raise NotFound("Message Not Exists.") @@ -95,7 +95,9 @@ class AppAnnotationService: db.session.add(annotation) db.session.commit() - annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) + ) assert current_tenant_id is not None if annotation_setting: add_annotation_to_index_task.delay( @@ -151,10 +153,8 @@ class AppAnnotationService: def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): # get app info _, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: @@ -193,20 +193,17 @@ class AppAnnotationService: """ # get app info _, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotations = ( - db.session.query(MessageAnnotation) + annotations = db.session.scalars( + select(MessageAnnotation) .where(MessageAnnotation.app_id == app_id) .order_by(MessageAnnotation.created_at.desc()) - .all() - ) + ).all() # Sanitize CSV-injectable fields to prevent formula injection for annotation in annotations: @@ -223,10 +220,8 @@ class AppAnnotationService: def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info current_user, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: @@ -242,7 +237,9 @@ class AppAnnotationService: db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) + ) if annotation_setting: add_annotation_to_index_task.delay( annotation.id, @@ -257,16 +254,14 @@ class AppAnnotationService: def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): # get app info _, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + annotation = db.session.get(MessageAnnotation, annotation_id) if not annotation: raise NotFound("Annotation not found") @@ -280,8 +275,8 @@ class AppAnnotationService: db.session.commit() # if annotation reply is enabled , add annotation to index - app_annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + app_annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) if app_annotation_setting: @@ -299,16 +294,14 @@ class AppAnnotationService: def delete_app_annotation(cls, app_id: str, annotation_id: str): # get app info _, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + annotation = db.session.get(MessageAnnotation, annotation_id) if not annotation: raise NotFound("Annotation not found") @@ -324,8 +317,8 @@ class AppAnnotationService: db.session.commit() # if annotation reply is enabled , delete annotation index - app_annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + app_annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) if app_annotation_setting: @@ -337,22 +330,19 @@ class AppAnnotationService: def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]): # get app info _, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") # Fetch annotations and their settings in a single query - annotations_to_delete = ( - db.session.query(MessageAnnotation, AppAnnotationSetting) + annotations_to_delete = db.session.execute( + select(MessageAnnotation, AppAnnotationSetting) .outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id) .where(MessageAnnotation.id.in_(annotation_ids)) - .all() - ) + ).all() if not annotations_to_delete: return {"deleted_count": 0} @@ -361,9 +351,9 @@ class AppAnnotationService: annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete] # Step 2: Bulk delete hit histories in a single query - db.session.query(AppAnnotationHitHistory).where( - AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete) - ).delete(synchronize_session=False) + db.session.execute( + delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete)) + ) # Step 3: Trigger async tasks for search index deletion for annotation, annotation_setting in annotations_to_delete: @@ -373,11 +363,10 @@ class AppAnnotationService: ) # Step 4: Bulk delete annotations in a single query - deleted_count = ( - db.session.query(MessageAnnotation) - .where(MessageAnnotation.id.in_(annotation_ids_to_delete)) - .delete(synchronize_session=False) + delete_result = db.session.execute( + delete(MessageAnnotation).where(MessageAnnotation.id.in_(annotation_ids_to_delete)) ) + deleted_count = getattr(delete_result, "rowcount", 0) db.session.commit() return {"deleted_count": deleted_count} @@ -398,10 +387,8 @@ class AppAnnotationService: # get app info current_user, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: @@ -522,16 +509,14 @@ class AppAnnotationService: def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): _, current_tenant_id = current_account_with_tenant() # get app info - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + annotation = db.session.get(MessageAnnotation, annotation_id) if not annotation: raise NotFound("Annotation not found") @@ -551,7 +536,7 @@ class AppAnnotationService: @classmethod def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None: - annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + annotation = db.session.get(MessageAnnotation, annotation_id) if not annotation: return None @@ -571,8 +556,10 @@ class AppAnnotationService: score: float, ): # add hit count to annotation - db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).update( - {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False + db.session.execute( + update(MessageAnnotation) + .where(MessageAnnotation.id == annotation_id) + .values(hit_count=MessageAnnotation.hit_count + 1) ) annotation_hit_history = AppAnnotationHitHistory( @@ -593,16 +580,16 @@ class AppAnnotationService: def get_app_annotation_setting_by_app_id(cls, app_id: str) -> AnnotationSettingDict | AnnotationSettingDisabledDict: _, current_tenant_id = current_account_with_tenant() # get app info - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) + ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail if collection_binding_detail: @@ -630,22 +617,20 @@ class AppAnnotationService: ) -> AnnotationSettingDict: current_user, current_tenant_id = current_account_with_tenant() # get app info - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation_setting = ( - db.session.query(AppAnnotationSetting) + annotation_setting = db.session.scalar( + select(AppAnnotationSetting) .where( AppAnnotationSetting.app_id == app_id, AppAnnotationSetting.id == annotation_setting_id, ) - .first() + .limit(1) ) if not annotation_setting: raise NotFound("App annotation not found") @@ -678,26 +663,26 @@ class AppAnnotationService: @classmethod def clear_all_annotations(cls, app_id: str): _, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") # if annotation reply is enabled, delete annotation index - app_annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + app_annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) - annotations_query = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id) - for annotation in annotations_query.yield_per(100): - annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).where( - AppAnnotationHitHistory.annotation_id == annotation.id - ) - for annotation_hit_history in annotation_hit_histories_query.yield_per(100): + annotations_iter = db.session.scalars( + select(MessageAnnotation).where(MessageAnnotation.app_id == app_id) + ).yield_per(100) + for annotation in annotations_iter: + hit_histories_iter = db.session.scalars( + select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation.id) + ).yield_per(100) + for annotation_hit_history in hit_histories_iter: db.session.delete(annotation_hit_history) # if annotation reply is enabled, delete annotation index diff --git a/api/tests/unit_tests/services/test_annotation_service.py b/api/tests/unit_tests/services/test_annotation_service.py index 0aacfc7f134..4295315f48d 100644 --- a/api/tests/unit_tests/services/test_annotation_service.py +++ b/api/tests/unit_tests/services/test_annotation_service.py @@ -79,10 +79,7 @@ class TestAppAnnotationServiceUpInsert: patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = None - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -100,10 +97,7 @@ class TestAppAnnotationServiceUpInsert: patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app # Act & Assert with pytest.raises(ValueError): @@ -121,15 +115,7 @@ class TestAppAnnotationServiceUpInsert: patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - message_query = MagicMock() - message_query.where.return_value = message_query - message_query.first.return_value = None - - mock_db.session.query.side_effect = [app_query, message_query] + mock_db.session.scalar.side_effect = [app, None] # Act & Assert with pytest.raises(NotFound): @@ -152,19 +138,7 @@ class TestAppAnnotationServiceUpInsert: patch("services.annotation_service.db") as mock_db, patch("services.annotation_service.add_annotation_to_index_task") as mock_task, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - message_query = MagicMock() - message_query.where.return_value = message_query - message_query.first.return_value = message - - setting_query = MagicMock() - setting_query.where.return_value = setting_query - setting_query.first.return_value = setting - - mock_db.session.query.side_effect = [app_query, message_query, setting_query] + mock_db.session.scalar.side_effect = [app, message, setting] # Act result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) @@ -202,19 +176,7 @@ class TestAppAnnotationServiceUpInsert: patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls, patch("services.annotation_service.add_annotation_to_index_task") as mock_task, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - message_query = MagicMock() - message_query.where.return_value = message_query - message_query.first.return_value = message - - setting_query = MagicMock() - setting_query.where.return_value = setting_query - setting_query.first.return_value = None - - mock_db.session.query.side_effect = [app_query, message_query, setting_query] + mock_db.session.scalar.side_effect = [app, message, None] # Act result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) @@ -245,10 +207,7 @@ class TestAppAnnotationServiceUpInsert: patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app # Act & Assert with pytest.raises(ValueError): @@ -270,15 +229,7 @@ class TestAppAnnotationServiceUpInsert: patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls, patch("services.annotation_service.add_annotation_to_index_task") as mock_task, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - setting_query = MagicMock() - setting_query.where.return_value = setting_query - setting_query.first.return_value = setting - - mock_db.session.query.side_effect = [app_query, setting_query] + mock_db.session.scalar.side_effect = [app, setting] # Act result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) @@ -406,10 +357,7 @@ class TestAppAnnotationServiceListAndExport: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = None - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -427,10 +375,7 @@ class TestAppAnnotationServiceListAndExport: patch("services.annotation_service.db") as mock_db, patch("libs.helper.escape_like_pattern", return_value="safe"), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app mock_db.paginate.return_value = pagination # Act @@ -451,10 +396,7 @@ class TestAppAnnotationServiceListAndExport: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app mock_db.paginate.return_value = pagination # Act @@ -481,16 +423,8 @@ class TestAppAnnotationServiceListAndExport: patch("services.annotation_service.db") as mock_db, patch("services.annotation_service.CSVSanitizer.sanitize_value", side_effect=lambda v: f"safe:{v}"), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - annotation_query = MagicMock() - annotation_query.where.return_value = annotation_query - annotation_query.order_by.return_value = annotation_query - annotation_query.all.return_value = [annotation1, annotation2] - - mock_db.session.query.side_effect = [app_query, annotation_query] + mock_db.session.scalar.return_value = app + mock_db.session.scalars.return_value.all.return_value = [annotation1, annotation2] # Act result = AppAnnotationService.export_annotation_list_by_app_id(app.id) @@ -511,10 +445,7 @@ class TestAppAnnotationServiceListAndExport: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = None - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -534,10 +465,7 @@ class TestAppAnnotationServiceDirectManipulation: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = None - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -554,10 +482,7 @@ class TestAppAnnotationServiceDirectManipulation: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app # Act & Assert with pytest.raises(ValueError): @@ -579,15 +504,7 @@ class TestAppAnnotationServiceDirectManipulation: patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls, patch("services.annotation_service.add_annotation_to_index_task") as mock_task, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - setting_query = MagicMock() - setting_query.where.return_value = setting_query - setting_query.first.return_value = setting - - mock_db.session.query.side_effect = [app_query, setting_query] + mock_db.session.scalar.side_effect = [app, setting] # Act result = AppAnnotationService.insert_app_annotation_directly(args, app.id) @@ -621,15 +538,8 @@ class TestAppAnnotationServiceDirectManipulation: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - annotation_query = MagicMock() - annotation_query.where.return_value = annotation_query - annotation_query.first.return_value = None - - mock_db.session.query.side_effect = [app_query, annotation_query] + mock_db.session.scalar.return_value = app + mock_db.session.get.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -645,10 +555,7 @@ class TestAppAnnotationServiceDirectManipulation: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = None - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -666,15 +573,8 @@ class TestAppAnnotationServiceDirectManipulation: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - annotation_query = MagicMock() - annotation_query.where.return_value = annotation_query - annotation_query.first.return_value = annotation - - mock_db.session.query.side_effect = [app_query, annotation_query] + mock_db.session.scalar.return_value = app + mock_db.session.get.return_value = annotation # Act & Assert with pytest.raises(ValueError): @@ -695,19 +595,8 @@ class TestAppAnnotationServiceDirectManipulation: patch("services.annotation_service.db") as mock_db, patch("services.annotation_service.update_annotation_to_index_task") as mock_task, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - annotation_query = MagicMock() - annotation_query.where.return_value = annotation_query - annotation_query.first.return_value = annotation - - setting_query = MagicMock() - setting_query.where.return_value = setting_query - setting_query.first.return_value = setting - - mock_db.session.query.side_effect = [app_query, annotation_query, setting_query] + mock_db.session.scalar.side_effect = [app, setting] + mock_db.session.get.return_value = annotation # Act result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id) @@ -740,22 +629,11 @@ class TestAppAnnotationServiceDirectManipulation: patch("services.annotation_service.db") as mock_db, patch("services.annotation_service.delete_annotation_index_task") as mock_task, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - annotation_query = MagicMock() - annotation_query.where.return_value = annotation_query - annotation_query.first.return_value = annotation - - setting_query = MagicMock() - setting_query.where.return_value = setting_query - setting_query.first.return_value = setting + mock_db.session.scalar.side_effect = [app, setting] + mock_db.session.get.return_value = annotation scalars_result = MagicMock() scalars_result.all.return_value = [history1, history2] - - mock_db.session.query.side_effect = [app_query, annotation_query, setting_query] mock_db.session.scalars.return_value = scalars_result # Act @@ -782,10 +660,7 @@ class TestAppAnnotationServiceDirectManipulation: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = None - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -801,15 +676,8 @@ class TestAppAnnotationServiceDirectManipulation: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - annotation_query = MagicMock() - annotation_query.where.return_value = annotation_query - annotation_query.first.return_value = None - - mock_db.session.query.side_effect = [app_query, annotation_query] + mock_db.session.scalar.return_value = app + mock_db.session.get.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -825,16 +693,8 @@ class TestAppAnnotationServiceDirectManipulation: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - annotations_query = MagicMock() - annotations_query.outerjoin.return_value = annotations_query - annotations_query.where.return_value = annotations_query - annotations_query.all.return_value = [] - - mock_db.session.query.side_effect = [app_query, annotations_query] + mock_db.session.scalar.return_value = app + mock_db.session.execute.return_value.all.return_value = [] # Act result = AppAnnotationService.delete_app_annotations_in_batch(app.id, ["ann-1"]) @@ -851,10 +711,7 @@ class TestAppAnnotationServiceDirectManipulation: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = None - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -874,24 +731,14 @@ class TestAppAnnotationServiceDirectManipulation: patch("services.annotation_service.db") as mock_db, patch("services.annotation_service.delete_annotation_index_task") as mock_task, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app + mock_db.session.scalar.return_value = app - annotations_query = MagicMock() - annotations_query.outerjoin.return_value = annotations_query - annotations_query.where.return_value = annotations_query - annotations_query.all.return_value = [(annotation1, setting), (annotation2, None)] - - hit_history_query = MagicMock() - hit_history_query.where.return_value = hit_history_query - hit_history_query.delete.return_value = None - - delete_query = MagicMock() - delete_query.where.return_value = delete_query - delete_query.delete.return_value = 2 - - mock_db.session.query.side_effect = [app_query, annotations_query, hit_history_query, delete_query] + # First execute().all() for multi-column query, subsequent execute() calls for deletes + execute_result_multi = MagicMock() + execute_result_multi.all.return_value = [(annotation1, setting), (annotation2, None)] + execute_result_delete = MagicMock() + execute_result_delete.rowcount = 2 + mock_db.session.execute.side_effect = [execute_result_multi, MagicMock(), execute_result_delete] # Act result = AppAnnotationService.delete_app_annotations_in_batch(app.id, ["ann-1", "ann-2"]) @@ -915,10 +762,7 @@ class TestAppAnnotationServiceBatchImport: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = None - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -941,10 +785,7 @@ class TestAppAnnotationServiceBatchImport: new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), ), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app # Act result = AppAnnotationService.batch_import_app_annotations(app.id, file) @@ -968,10 +809,7 @@ class TestAppAnnotationServiceBatchImport: new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), ), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app # Act result = AppAnnotationService.batch_import_app_annotations(app.id, file) @@ -999,10 +837,7 @@ class TestAppAnnotationServiceBatchImport: new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=2), ), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app # Act result = AppAnnotationService.batch_import_app_annotations(app.id, file) @@ -1028,10 +863,7 @@ class TestAppAnnotationServiceBatchImport: new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=1, ANNOTATION_IMPORT_MIN_RECORDS=1), ), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app # Act result = AppAnnotationService.batch_import_app_annotations(app.id, file) @@ -1061,10 +893,7 @@ class TestAppAnnotationServiceBatchImport: new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), ), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app # Act result = AppAnnotationService.batch_import_app_annotations(app.id, file) @@ -1090,10 +919,7 @@ class TestAppAnnotationServiceBatchImport: new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), ), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app # Act result = AppAnnotationService.batch_import_app_annotations(app.id, file) @@ -1119,10 +945,7 @@ class TestAppAnnotationServiceBatchImport: new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), ), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app # Act result = AppAnnotationService.batch_import_app_annotations(app.id, file) @@ -1148,10 +971,7 @@ class TestAppAnnotationServiceBatchImport: new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), ), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app # Act result = AppAnnotationService.batch_import_app_annotations(app.id, file) @@ -1182,10 +1002,7 @@ class TestAppAnnotationServiceBatchImport: new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), ), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app # Act result = AppAnnotationService.batch_import_app_annotations(app.id, file) @@ -1218,10 +1035,7 @@ class TestAppAnnotationServiceBatchImport: new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), ), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app # Act result = AppAnnotationService.batch_import_app_annotations(app.id, file) @@ -1257,10 +1071,7 @@ class TestAppAnnotationServiceBatchImport: new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), ), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = app mock_redis.zadd.side_effect = RuntimeError("boom") mock_redis.zrem.side_effect = RuntimeError("cleanup-failed") @@ -1285,10 +1096,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = None - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -1306,15 +1114,8 @@ class TestAppAnnotationServiceHitHistoryAndSettings: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - annotation_query = MagicMock() - annotation_query.where.return_value = annotation_query - annotation_query.first.return_value = annotation - - mock_db.session.query.side_effect = [app_query, annotation_query] + mock_db.session.scalar.return_value = app + mock_db.session.get.return_value = annotation mock_db.paginate.return_value = pagination # Act @@ -1334,15 +1135,8 @@ class TestAppAnnotationServiceHitHistoryAndSettings: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - annotation_query = MagicMock() - annotation_query.where.return_value = annotation_query - annotation_query.first.return_value = None - - mock_db.session.query.side_effect = [app_query, annotation_query] + mock_db.session.scalar.return_value = app + mock_db.session.get.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -1352,10 +1146,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings: """Test get_annotation_by_id returns None when not found.""" # Arrange with patch("services.annotation_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query + mock_db.session.get.return_value = None # Act result = AppAnnotationService.get_annotation_by_id("ann-1") @@ -1368,10 +1159,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings: # Arrange annotation = _make_annotation("ann-1") with patch("services.annotation_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = annotation - mock_db.session.query.return_value = query + mock_db.session.get.return_value = annotation # Act result = AppAnnotationService.get_annotation_by_id("ann-1") @@ -1386,10 +1174,6 @@ class TestAppAnnotationServiceHitHistoryAndSettings: patch("services.annotation_service.db") as mock_db, patch("services.annotation_service.AppAnnotationHitHistory") as mock_history_cls, ): - query = MagicMock() - query.where.return_value = query - mock_db.session.query.return_value = query - # Act AppAnnotationService.add_annotation_history( annotation_id="ann-1", @@ -1404,7 +1188,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings: ) # Assert - query.update.assert_called_once() + mock_db.session.execute.assert_called_once() mock_history_cls.assert_called_once() mock_db.session.add.assert_called_once() mock_db.session.commit.assert_called_once() @@ -1420,15 +1204,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - setting_query = MagicMock() - setting_query.where.return_value = setting_query - setting_query.first.return_value = setting - - mock_db.session.query.side_effect = [app_query, setting_query] + mock_db.session.scalar.side_effect = [app, setting] # Act result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) @@ -1448,10 +1224,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = None - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -1468,15 +1241,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - setting_query = MagicMock() - setting_query.where.return_value = setting_query - setting_query.first.return_value = setting - - mock_db.session.query.side_effect = [app_query, setting_query] + mock_db.session.scalar.side_effect = [app, setting] # Act result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) @@ -1495,15 +1260,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - setting_query = MagicMock() - setting_query.where.return_value = setting_query - setting_query.first.return_value = None - - mock_db.session.query.side_effect = [app_query, setting_query] + mock_db.session.scalar.side_effect = [app, None] # Act result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) @@ -1525,15 +1282,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings: patch("services.annotation_service.db") as mock_db, patch("services.annotation_service.naive_utc_now", return_value="now"), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - setting_query = MagicMock() - setting_query.where.return_value = setting_query - setting_query.first.return_value = setting - - mock_db.session.query.side_effect = [app_query, setting_query] + mock_db.session.scalar.side_effect = [app, setting] # Act result = AppAnnotationService.update_app_annotation_setting(app.id, setting.id, args) @@ -1560,15 +1309,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings: patch("services.annotation_service.db") as mock_db, patch("services.annotation_service.naive_utc_now", return_value="now"), ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - setting_query = MagicMock() - setting_query.where.return_value = setting_query - setting_query.first.return_value = setting - - mock_db.session.query.side_effect = [app_query, setting_query] + mock_db.session.scalar.side_effect = [app, setting] # Act result = AppAnnotationService.update_app_annotation_setting(app.id, setting.id, args) @@ -1587,10 +1328,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = None - mock_db.session.query.return_value = app_query + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(NotFound): @@ -1606,15 +1344,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - app_query = MagicMock() - app_query.where.return_value = app_query - app_query.first.return_value = app - - setting_query = MagicMock() - setting_query.where.return_value = setting_query - setting_query.first.return_value = None - - mock_db.session.query.side_effect = [app_query, setting_query] + mock_db.session.scalar.side_effect = [app, None] # Act & Assert with pytest.raises(NotFound): @@ -1634,25 +1364,21 @@ class TestAppAnnotationServiceClearAll: annotation2 = _make_annotation("ann-2") history = MagicMock(spec=AppAnnotationHitHistory) - def query_side_effect(*args: object, **kwargs: object) -> MagicMock: - query = MagicMock() - query.where.return_value = query - if App in args: - query.first.return_value = app - elif AppAnnotationSetting in args: - query.first.return_value = setting - elif MessageAnnotation in args: - query.yield_per.return_value = [annotation1, annotation2] - elif AppAnnotationHitHistory in args: - query.yield_per.return_value = [history] - return query - with ( patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, patch("services.annotation_service.delete_annotation_index_task") as mock_task, ): - mock_db.session.query.side_effect = query_side_effect + # scalar calls: app lookup, annotation_setting lookup + mock_db.session.scalar.side_effect = [app, setting] + # scalars calls: first for annotations iteration, then for each annotation's hit histories + annotations_scalars = MagicMock() + annotations_scalars.yield_per.return_value = [annotation1, annotation2] + histories_scalars_1 = MagicMock() + histories_scalars_1.yield_per.return_value = [history] + histories_scalars_2 = MagicMock() + histories_scalars_2.yield_per.return_value = [] + mock_db.session.scalars.side_effect = [annotations_scalars, histories_scalars_1, histories_scalars_2] # Act result = AppAnnotationService.clear_all_annotations(app.id) @@ -1675,10 +1401,7 @@ class TestAppAnnotationServiceClearAll: patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), patch("services.annotation_service.db") as mock_db, ): - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query + mock_db.session.scalar.return_value = None # Act & Assert with pytest.raises(NotFound):