diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py index f853a124efa..5e7847d784f 100644 --- a/api/controllers/service_api/app/file_preview.py +++ b/api/controllers/service_api/app/file_preview.py @@ -4,6 +4,7 @@ from urllib.parse import quote from flask import Response, request from flask_restx import Resource from pydantic import BaseModel, Field +from sqlalchemy import select from controllers.common.file_response import enforce_download_for_html from controllers.common.schema import register_schema_model @@ -102,27 +103,27 @@ class FilePreviewApi(Resource): raise FileAccessDeniedError("Invalid file or app identifier") # First, find the MessageFile that references this upload file - message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first() + message_file = db.session.scalar(select(MessageFile).where(MessageFile.upload_file_id == file_id).limit(1)) if not message_file: raise FileNotFoundError("File not found in message context") # Get the message and verify it belongs to the requesting app - message = ( - db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).limit(1) ) if not message: raise FileAccessDeniedError("File access denied: not owned by requesting app") # Get the actual upload file record - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = db.session.get(UploadFile, file_id) if not upload_file: raise FileNotFoundError("Upload file record not found") # Additional security: verify tenant isolation - app = db.session.query(App).where(App.id == app_id).first() + app = db.session.get(App, app_id) if app and upload_file.tenant_id != app.tenant_id: raise FileAccessDeniedError("File access denied: tenant mismatch") diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index 8b47a887bbe..bc06e8f386d 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -1,4 +1,5 @@ from flask_restx import Resource +from sqlalchemy import select from werkzeug.exceptions import Forbidden from controllers.common.fields import Site as SiteResponse @@ -28,7 +29,7 @@ class AppSiteApi(Resource): Returns the site configuration for the application including theme, icons, and text. """ - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise Forbidden() diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 7aa5b2f0925..1d52b8a737a 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -9,6 +9,7 @@ from flask import current_app, request from flask_login import user_logged_in from flask_restx import Resource from pydantic import BaseModel +from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from enums.cloud_plan import CloudPlan @@ -62,7 +63,7 @@ def validate_app_token( def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: api_token = validate_and_get_api_token("app") - app_model = db.session.query(App).where(App.id == api_token.app_id).first() + app_model = db.session.get(App, api_token.app_id) if not app_model: raise Forbidden("The app no longer exists.") @@ -72,7 +73,7 @@ def validate_app_token( if not app_model.enable_api: raise Forbidden("The app's API service has been disabled.") - tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first() + tenant = db.session.get(Tenant, app_model.tenant_id) if tenant is None: raise ValueError("Tenant does not exist.") if tenant.status == TenantStatus.ARCHIVE: @@ -106,8 +107,8 @@ def validate_app_token( else: # For service API without end-user context, ensure an Account is logged in # so services relying on current_account_with_tenant() work correctly. - tenant_owner_info = ( - db.session.query(Tenant, Account) + tenant_owner_info = db.session.execute( + select(Tenant, Account) .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) .join(Account, TenantAccountJoin.account_id == Account.id) .where( @@ -115,8 +116,7 @@ def validate_app_token( TenantAccountJoin.role == "owner", Tenant.status == TenantStatus.NORMAL, ) - .one_or_none() - ) + ).one_or_none() if tenant_owner_info: tenant_model, account = tenant_owner_info @@ -277,29 +277,28 @@ def validate_dataset_token( # Validate dataset if dataset_id is provided if dataset_id: dataset_id = str(dataset_id) - dataset = ( - db.session.query(Dataset) + dataset = db.session.scalar( + select(Dataset) .where( Dataset.id == dataset_id, Dataset.tenant_id == api_token.tenant_id, ) - .first() + .limit(1) ) if not dataset: raise NotFound("Dataset not found.") if not dataset.enable_api: raise Forbidden("Dataset api access is not enabled.") - tenant_account_join = ( - db.session.query(Tenant, TenantAccountJoin) + tenant_account_join = db.session.execute( + select(Tenant, TenantAccountJoin) .where(Tenant.id == api_token.tenant_id) .where(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.role.in_(["owner"])) .where(Tenant.status == TenantStatus.NORMAL) - .one_or_none() - ) # TODO: only owner information is required, so only one is returned. + ).one_or_none() # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join - account = db.session.query(Account).where(Account.id == ta.account_id).first() + account = db.session.get(Account, ta.account_id) # Login admin if account: account.current_tenant = tenant @@ -360,7 +359,9 @@ class DatasetApiResource(Resource): method_decorators = [validate_dataset_token] def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index 3f75fd2851f..55873b06a8a 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -123,27 +123,26 @@ def _configure_session_factory(_unit_test_engine): def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account): """ - Helper to set up the mock DB query chain for tenant/account authentication. + Helper to set up the mock DB execute chain for tenant/account authentication. - This configures the mock to return (tenant, account) for the join query used - by validate_app_token and validate_dataset_token decorators. + This configures the mock to return (tenant, account) for the + db.session.execute(select(...).join().join().where()).one_or_none() + query used by validate_app_token decorator. Args: mock_db: The mocked db object mock_tenant: Mock tenant object to return mock_account: Mock account object to return """ - query = mock_db.session.query.return_value - join_chain = query.join.return_value.join.return_value - where_chain = join_chain.where.return_value - where_chain.one_or_none.return_value = (mock_tenant, mock_account) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account) def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta): """ - Helper to set up the mock DB query chain for dataset tenant authentication. + Helper to set up the mock DB execute chain for dataset tenant authentication. - This configures the mock to return (tenant, tenant_account) for the where chain + This configures the mock to return (tenant, tenant_account) for the + db.session.execute(select(...).where().where().where().where()).one_or_none() query used by validate_dataset_token decorator. Args: @@ -151,6 +150,4 @@ def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta): mock_tenant: Mock tenant object to return mock_ta: Mock tenant account object to return """ - query = mock_db.session.query.return_value - where_chain = query.where.return_value.where.return_value.where.return_value.where.return_value - where_chain.one_or_none.return_value = (mock_tenant, mock_ta) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_app.py b/api/tests/unit_tests/controllers/service_api/app/test_app.py index f8e9cf9b801..1507bf7a5fc 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_app.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_app.py @@ -65,7 +65,7 @@ class TestAppParameterApi: mock_tenant.status = "normal" # Mock DB queries for app and tenant - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -112,7 +112,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -153,7 +153,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -192,7 +192,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -255,7 +255,7 @@ class TestAppMetaApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -323,7 +323,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -380,7 +380,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] @@ -426,7 +426,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] @@ -478,7 +478,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py index 1bdcd0f1a31..d83c22f2cf6 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py @@ -79,10 +79,13 @@ class TestFilePreviewApi: mock_message_file.message_id = mock_message.id with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -100,8 +103,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile not found - mock_db.session.query.return_value.where.return_value.first.return_value = None + # Mock MessageFile not found via scalar() + mock_db.session.scalar.return_value = None # Execute and assert exception with pytest.raises(FileNotFoundError) as exc_info: @@ -115,8 +118,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile found but Message not owned by app - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock MessageFile found but Message not owned by app via scalar() + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query - found None, # Message query - not found (access denied) ] @@ -133,12 +136,13 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile and Message found but UploadFile not found - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query - found mock_message, # Message query - found - None, # UploadFile query - not found ] + # Mock get() for UploadFile - not found + mock_db.session.get.return_value = None # Execute and assert exception with pytest.raises(FileNotFoundError) as exc_info: @@ -161,10 +165,13 @@ class TestFilePreviewApi: mock_message_file.message_id = mock_message.id with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -262,10 +269,13 @@ class TestFilePreviewApi: mock_storage.load.return_value = mock_generator with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -301,10 +311,13 @@ class TestFilePreviewApi: mock_storage.load.side_effect = Exception("Storage error") with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries for validation - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -327,8 +340,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database query to raise unexpected exception - mock_db.session.query.side_effect = Exception("Unexpected database error") + # Mock database scalar to raise unexpected exception + mock_db.session.scalar.side_effect = Exception("Unexpected database error") # Execute and assert exception with pytest.raises(FileAccessDeniedError) as exc_info: diff --git a/api/tests/unit_tests/controllers/service_api/conftest.py b/api/tests/unit_tests/controllers/service_api/conftest.py index 01d2d1e7c02..eddba5a5170 100644 --- a/api/tests/unit_tests/controllers/service_api/conftest.py +++ b/api/tests/unit_tests/controllers/service_api/conftest.py @@ -119,11 +119,8 @@ class AuthenticationMocker: @staticmethod def setup_db_queries(mock_db, mock_app, mock_tenant, mock_account=None): - """Configure mock_db to return app and tenant in sequence.""" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ - mock_app, - mock_tenant, - ] + """Configure mock_db to return app and tenant via session.get().""" + mock_db.session.get.side_effect = [mock_app, mock_tenant] if mock_account: mock_ta = Mock() @@ -136,11 +133,9 @@ class AuthenticationMocker: mock_ta = Mock() mock_ta.account_id = mock_account.id - mock_query = mock_db.session.query.return_value - target_mock = mock_query.where.return_value.where.return_value.where.return_value.where.return_value - target_mock.one_or_none.return_value = (mock_tenant, mock_ta) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_account + mock_db.session.get.return_value = mock_account @pytest.fixture diff --git a/api/tests/unit_tests/controllers/service_api/test_site.py b/api/tests/unit_tests/controllers/service_api/test_site.py index b58caf3be18..c0b40d070a5 100644 --- a/api/tests/unit_tests/controllers/service_api/test_site.py +++ b/api/tests/unit_tests/controllers/service_api/test_site.py @@ -88,7 +88,7 @@ class TestAppSiteApi: mock_app_model.tenant = mock_tenant # Mock wraps.db for authentication - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -98,7 +98,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site.db for site query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Act with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -109,7 +109,7 @@ class TestAppSiteApi: assert response["title"] == "Test Site" assert response["icon"] == "icon-url" assert response["description"] == "Site description" - mock_db.session.query.assert_called_once_with(Site) + mock_db.session.scalar.assert_called_once() @patch("controllers.service_api.wraps.user_logged_in") @patch("controllers.service_api.app.site.db") @@ -140,7 +140,7 @@ class TestAppSiteApi: mock_tenant.status = TenantStatus.NORMAL mock_app_model.tenant = mock_tenant - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -150,7 +150,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site query to return None - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -187,7 +187,7 @@ class TestAppSiteApi: mock_tenant = Mock() mock_tenant.status = TenantStatus.NORMAL - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -197,7 +197,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Set tenant status to archived AFTER authentication mock_app_model.tenant.status = TenantStatus.ARCHIVE @@ -230,7 +230,7 @@ class TestAppSiteApi: mock_tenant.status = TenantStatus.NORMAL mock_app_model.tenant = mock_tenant - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -258,7 +258,7 @@ class TestAppSiteApi: mock_site.icon_type = "image" mock_site.created_at = "2024-01-01T00:00:00" mock_site.updated_at = "2024-01-01T00:00:00" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Act with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -267,4 +267,4 @@ class TestAppSiteApi: # Assert # The query was executed successfully (site returned), which validates the correct query was made - mock_db.session.query.assert_called_once_with(Site) + mock_db.session.scalar.assert_called_once() diff --git a/api/tests/unit_tests/controllers/service_api/test_wraps.py b/api/tests/unit_tests/controllers/service_api/test_wraps.py index 9c2d075f417..a2008e024b9 100644 --- a/api/tests/unit_tests/controllers/service_api/test_wraps.py +++ b/api/tests/unit_tests/controllers/service_api/test_wraps.py @@ -144,14 +144,10 @@ class TestValidateAppToken: mock_ta = Mock() mock_ta.account_id = mock_account.id - # Use side_effect to return app first, then tenant - mock_db.session.query.return_value.where.return_value.first.side_effect = [ - mock_app, - mock_tenant, - mock_account, - ] + # Use side_effect to return app first, then tenant via session.get() + mock_db.session.get.side_effect = [mock_app, mock_tenant] - # Mock the tenant owner query + # Mock the tenant owner query (execute(select(...)).one_or_none()) setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta) @validate_app_token @@ -175,7 +171,7 @@ class TestValidateAppToken: mock_api_token.app_id = str(uuid.uuid4()) mock_validate_token.return_value = mock_api_token - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None @validate_app_token def protected_view(**kwargs): @@ -198,7 +194,7 @@ class TestValidateAppToken: mock_app = Mock() mock_app.status = "abnormal" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app @validate_app_token def protected_view(**kwargs): @@ -222,7 +218,7 @@ class TestValidateAppToken: mock_app = Mock() mock_app.status = "normal" mock_app.enable_api = False - mock_db.session.query.return_value.where.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app @validate_app_token def protected_view(**kwargs): @@ -474,11 +470,11 @@ class TestValidateDatasetToken: mock_account.id = mock_ta.account_id mock_account.current_tenant = mock_tenant - # Mock the tenant account join query + # Mock the tenant account join query (execute(select(...)).one_or_none()) setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta) - # Mock the account query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_account + # Mock the account lookup via session.get() + mock_db.session.get.return_value = mock_account @validate_dataset_token def protected_view(tenant_id): @@ -501,7 +497,7 @@ class TestValidateDatasetToken: mock_api_token.tenant_id = str(uuid.uuid4()) mock_validate_token.return_value = mock_api_token - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None @validate_dataset_token def protected_view(dataset_id=None, **kwargs):