refactor: select in service API wraps, file_preview, and site controllers (#34086)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo
2026-03-25 15:01:05 +01:00
committed by GitHub
parent 52e7492cbc
commit 22dd0aa20c
9 changed files with 96 additions and 92 deletions

View File

@@ -4,6 +4,7 @@ from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from controllers.common.file_response import enforce_download_for_html
from controllers.common.schema import register_schema_model
@@ -102,27 +103,27 @@ class FilePreviewApi(Resource):
raise FileAccessDeniedError("Invalid file or app identifier")
# First, find the MessageFile that references this upload file
message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first()
message_file = db.session.scalar(select(MessageFile).where(MessageFile.upload_file_id == file_id).limit(1))
if not message_file:
raise FileNotFoundError("File not found in message context")
# Get the message and verify it belongs to the requesting app
message = (
db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first()
message = db.session.scalar(
select(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).limit(1)
)
if not message:
raise FileAccessDeniedError("File access denied: not owned by requesting app")
# Get the actual upload file record
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = db.session.get(UploadFile, file_id)
if not upload_file:
raise FileNotFoundError("Upload file record not found")
# Additional security: verify tenant isolation
app = db.session.query(App).where(App.id == app_id).first()
app = db.session.get(App, app_id)
if app and upload_file.tenant_id != app.tenant_id:
raise FileAccessDeniedError("File access denied: tenant mismatch")

View File

@@ -1,4 +1,5 @@
from flask_restx import Resource
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from controllers.common.fields import Site as SiteResponse
@@ -28,7 +29,7 @@ class AppSiteApi(Resource):
Returns the site configuration for the application including theme, icons, and text.
"""
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise Forbidden()

View File

@@ -9,6 +9,7 @@ from flask import current_app, request
from flask_login import user_logged_in
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from enums.cloud_plan import CloudPlan
@@ -62,7 +63,7 @@ def validate_app_token(
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
api_token = validate_and_get_api_token("app")
app_model = db.session.query(App).where(App.id == api_token.app_id).first()
app_model = db.session.get(App, api_token.app_id)
if not app_model:
raise Forbidden("The app no longer exists.")
@@ -72,7 +73,7 @@ def validate_app_token(
if not app_model.enable_api:
raise Forbidden("The app's API service has been disabled.")
tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first()
tenant = db.session.get(Tenant, app_model.tenant_id)
if tenant is None:
raise ValueError("Tenant does not exist.")
if tenant.status == TenantStatus.ARCHIVE:
@@ -106,8 +107,8 @@ def validate_app_token(
else:
# For service API without end-user context, ensure an Account is logged in
# so services relying on current_account_with_tenant() work correctly.
tenant_owner_info = (
db.session.query(Tenant, Account)
tenant_owner_info = db.session.execute(
select(Tenant, Account)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.join(Account, TenantAccountJoin.account_id == Account.id)
.where(
@@ -115,8 +116,7 @@ def validate_app_token(
TenantAccountJoin.role == "owner",
Tenant.status == TenantStatus.NORMAL,
)
.one_or_none()
)
).one_or_none()
if tenant_owner_info:
tenant_model, account = tenant_owner_info
@@ -277,29 +277,28 @@ def validate_dataset_token(
# Validate dataset if dataset_id is provided
if dataset_id:
dataset_id = str(dataset_id)
dataset = (
db.session.query(Dataset)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
)
.first()
.limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
tenant_account_join = db.session.execute(
select(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
).one_or_none() # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).where(Account.id == ta.account_id).first()
account = db.session.get(Account, ta.account_id)
# Login admin
if account:
account.current_tenant = tenant
@@ -360,7 +359,9 @@ class DatasetApiResource(Resource):
method_decorators = [validate_dataset_token]
def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()
dataset = db.session.scalar(
select(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")

View File

@@ -123,27 +123,26 @@ def _configure_session_factory(_unit_test_engine):
def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account):
"""
Helper to set up the mock DB query chain for tenant/account authentication.
Helper to set up the mock DB execute chain for tenant/account authentication.
This configures the mock to return (tenant, account) for the join query used
by validate_app_token and validate_dataset_token decorators.
This configures the mock to return (tenant, account) for the
db.session.execute(select(...).join().join().where()).one_or_none()
query used by validate_app_token decorator.
Args:
mock_db: The mocked db object
mock_tenant: Mock tenant object to return
mock_account: Mock account object to return
"""
query = mock_db.session.query.return_value
join_chain = query.join.return_value.join.return_value
where_chain = join_chain.where.return_value
where_chain.one_or_none.return_value = (mock_tenant, mock_account)
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account)
def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
"""
Helper to set up the mock DB query chain for dataset tenant authentication.
Helper to set up the mock DB execute chain for dataset tenant authentication.
This configures the mock to return (tenant, tenant_account) for the where chain
This configures the mock to return (tenant, tenant_account) for the
db.session.execute(select(...).where().where().where().where()).one_or_none()
query used by validate_dataset_token decorator.
Args:
@@ -151,6 +150,4 @@ def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
mock_tenant: Mock tenant object to return
mock_ta: Mock tenant account object to return
"""
query = mock_db.session.query.return_value
where_chain = query.where.return_value.where.return_value.where.return_value.where.return_value
where_chain.one_or_none.return_value = (mock_tenant, mock_ta)
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)

View File

@@ -65,7 +65,7 @@ class TestAppParameterApi:
mock_tenant.status = "normal"
# Mock DB queries for app and tenant
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
@@ -112,7 +112,7 @@ class TestAppParameterApi:
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
@@ -153,7 +153,7 @@ class TestAppParameterApi:
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
@@ -192,7 +192,7 @@ class TestAppParameterApi:
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
@@ -255,7 +255,7 @@ class TestAppMetaApi:
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
@@ -323,7 +323,7 @@ class TestAppInfoApi:
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
@@ -380,7 +380,7 @@ class TestAppInfoApi:
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_db.session.get.side_effect = [
mock_app,
mock_tenant,
]
@@ -426,7 +426,7 @@ class TestAppInfoApi:
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_db.session.get.side_effect = [
mock_app,
mock_tenant,
]
@@ -478,7 +478,7 @@ class TestAppInfoApi:
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_db.session.get.side_effect = [
mock_app,
mock_tenant,
]

View File

@@ -79,10 +79,13 @@ class TestFilePreviewApi:
mock_message_file.message_id = mock_message.id
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock database queries
mock_db.session.query.return_value.where.return_value.first.side_effect = [
# Mock scalar() for MessageFile and Message queries
mock_db.session.scalar.side_effect = [
mock_message_file, # MessageFile query
mock_message, # Message query
]
# Mock get() for UploadFile and App PK lookups
mock_db.session.get.side_effect = [
mock_upload_file, # UploadFile query
mock_app, # App query for tenant validation
]
@@ -100,8 +103,8 @@ class TestFilePreviewApi:
app_id = str(uuid.uuid4())
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock MessageFile not found
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Mock MessageFile not found via scalar()
mock_db.session.scalar.return_value = None
# Execute and assert exception
with pytest.raises(FileNotFoundError) as exc_info:
@@ -115,8 +118,8 @@ class TestFilePreviewApi:
app_id = str(uuid.uuid4())
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock MessageFile found but Message not owned by app
mock_db.session.query.return_value.where.return_value.first.side_effect = [
# Mock MessageFile found but Message not owned by app via scalar()
mock_db.session.scalar.side_effect = [
mock_message_file, # MessageFile query - found
None, # Message query - not found (access denied)
]
@@ -133,12 +136,13 @@ class TestFilePreviewApi:
app_id = str(uuid.uuid4())
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock MessageFile and Message found but UploadFile not found
mock_db.session.query.return_value.where.return_value.first.side_effect = [
# Mock scalar() for MessageFile and Message
mock_db.session.scalar.side_effect = [
mock_message_file, # MessageFile query - found
mock_message, # Message query - found
None, # UploadFile query - not found
]
# Mock get() for UploadFile - not found
mock_db.session.get.return_value = None
# Execute and assert exception
with pytest.raises(FileNotFoundError) as exc_info:
@@ -161,10 +165,13 @@ class TestFilePreviewApi:
mock_message_file.message_id = mock_message.id
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock database queries
mock_db.session.query.return_value.where.return_value.first.side_effect = [
# Mock scalar() for MessageFile and Message queries
mock_db.session.scalar.side_effect = [
mock_message_file, # MessageFile query
mock_message, # Message query
]
# Mock get() for UploadFile and App PK lookups
mock_db.session.get.side_effect = [
mock_upload_file, # UploadFile query
mock_app, # App query for tenant validation
]
@@ -262,10 +269,13 @@ class TestFilePreviewApi:
mock_storage.load.return_value = mock_generator
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock database queries
mock_db.session.query.return_value.where.return_value.first.side_effect = [
# Mock scalar() for MessageFile and Message queries
mock_db.session.scalar.side_effect = [
mock_message_file, # MessageFile query
mock_message, # Message query
]
# Mock get() for UploadFile and App PK lookups
mock_db.session.get.side_effect = [
mock_upload_file, # UploadFile query
mock_app, # App query for tenant validation
]
@@ -301,10 +311,13 @@ class TestFilePreviewApi:
mock_storage.load.side_effect = Exception("Storage error")
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock database queries for validation
mock_db.session.query.return_value.where.return_value.first.side_effect = [
# Mock scalar() for MessageFile and Message queries
mock_db.session.scalar.side_effect = [
mock_message_file, # MessageFile query
mock_message, # Message query
]
# Mock get() for UploadFile and App PK lookups
mock_db.session.get.side_effect = [
mock_upload_file, # UploadFile query
mock_app, # App query for tenant validation
]
@@ -327,8 +340,8 @@ class TestFilePreviewApi:
app_id = str(uuid.uuid4())
with patch("controllers.service_api.app.file_preview.db") as mock_db:
# Mock database query to raise unexpected exception
mock_db.session.query.side_effect = Exception("Unexpected database error")
# Mock database scalar to raise unexpected exception
mock_db.session.scalar.side_effect = Exception("Unexpected database error")
# Execute and assert exception
with pytest.raises(FileAccessDeniedError) as exc_info:

View File

@@ -119,11 +119,8 @@ class AuthenticationMocker:
@staticmethod
def setup_db_queries(mock_db, mock_app, mock_tenant, mock_account=None):
"""Configure mock_db to return app and tenant in sequence."""
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app,
mock_tenant,
]
"""Configure mock_db to return app and tenant via session.get()."""
mock_db.session.get.side_effect = [mock_app, mock_tenant]
if mock_account:
mock_ta = Mock()
@@ -136,11 +133,9 @@ class AuthenticationMocker:
mock_ta = Mock()
mock_ta.account_id = mock_account.id
mock_query = mock_db.session.query.return_value
target_mock = mock_query.where.return_value.where.return_value.where.return_value.where.return_value
target_mock.one_or_none.return_value = (mock_tenant, mock_ta)
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
mock_db.session.query.return_value.where.return_value.first.return_value = mock_account
mock_db.session.get.return_value = mock_account
@pytest.fixture

View File

@@ -88,7 +88,7 @@ class TestAppSiteApi:
mock_app_model.tenant = mock_tenant
# Mock wraps.db for authentication
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
mock_wraps_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
@@ -98,7 +98,7 @@ class TestAppSiteApi:
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
# Mock site.db for site query
mock_db.session.query.return_value.where.return_value.first.return_value = mock_site
mock_db.session.scalar.return_value = mock_site
# Act
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -109,7 +109,7 @@ class TestAppSiteApi:
assert response["title"] == "Test Site"
assert response["icon"] == "icon-url"
assert response["description"] == "Site description"
mock_db.session.query.assert_called_once_with(Site)
mock_db.session.scalar.assert_called_once()
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.app.site.db")
@@ -140,7 +140,7 @@ class TestAppSiteApi:
mock_tenant.status = TenantStatus.NORMAL
mock_app_model.tenant = mock_tenant
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
mock_wraps_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
@@ -150,7 +150,7 @@ class TestAppSiteApi:
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
# Mock site query to return None
mock_db.session.query.return_value.where.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -187,7 +187,7 @@ class TestAppSiteApi:
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
mock_wraps_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
@@ -197,7 +197,7 @@ class TestAppSiteApi:
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
# Mock site query
mock_db.session.query.return_value.where.return_value.first.return_value = mock_site
mock_db.session.scalar.return_value = mock_site
# Set tenant status to archived AFTER authentication
mock_app_model.tenant.status = TenantStatus.ARCHIVE
@@ -230,7 +230,7 @@ class TestAppSiteApi:
mock_tenant.status = TenantStatus.NORMAL
mock_app_model.tenant = mock_tenant
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
mock_wraps_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
@@ -258,7 +258,7 @@ class TestAppSiteApi:
mock_site.icon_type = "image"
mock_site.created_at = "2024-01-01T00:00:00"
mock_site.updated_at = "2024-01-01T00:00:00"
mock_db.session.query.return_value.where.return_value.first.return_value = mock_site
mock_db.session.scalar.return_value = mock_site
# Act
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
@@ -267,4 +267,4 @@ class TestAppSiteApi:
# Assert
# The query was executed successfully (site returned), which validates the correct query was made
mock_db.session.query.assert_called_once_with(Site)
mock_db.session.scalar.assert_called_once()

View File

@@ -144,14 +144,10 @@ class TestValidateAppToken:
mock_ta = Mock()
mock_ta.account_id = mock_account.id
# Use side_effect to return app first, then tenant
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app,
mock_tenant,
mock_account,
]
# Use side_effect to return app first, then tenant via session.get()
mock_db.session.get.side_effect = [mock_app, mock_tenant]
# Mock the tenant owner query
# Mock the tenant owner query (execute(select(...)).one_or_none())
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
@validate_app_token
@@ -175,7 +171,7 @@ class TestValidateAppToken:
mock_api_token.app_id = str(uuid.uuid4())
mock_validate_token.return_value = mock_api_token
mock_db.session.query.return_value.where.return_value.first.return_value = None
mock_db.session.get.return_value = None
@validate_app_token
def protected_view(**kwargs):
@@ -198,7 +194,7 @@ class TestValidateAppToken:
mock_app = Mock()
mock_app.status = "abnormal"
mock_db.session.query.return_value.where.return_value.first.return_value = mock_app
mock_db.session.get.return_value = mock_app
@validate_app_token
def protected_view(**kwargs):
@@ -222,7 +218,7 @@ class TestValidateAppToken:
mock_app = Mock()
mock_app.status = "normal"
mock_app.enable_api = False
mock_db.session.query.return_value.where.return_value.first.return_value = mock_app
mock_db.session.get.return_value = mock_app
@validate_app_token
def protected_view(**kwargs):
@@ -474,11 +470,11 @@ class TestValidateDatasetToken:
mock_account.id = mock_ta.account_id
mock_account.current_tenant = mock_tenant
# Mock the tenant account join query
# Mock the tenant account join query (execute(select(...)).one_or_none())
setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta)
# Mock the account query
mock_db.session.query.return_value.where.return_value.first.return_value = mock_account
# Mock the account lookup via session.get()
mock_db.session.get.return_value = mock_account
@validate_dataset_token
def protected_view(tenant_id):
@@ -501,7 +497,7 @@ class TestValidateDatasetToken:
mock_api_token.tenant_id = str(uuid.uuid4())
mock_validate_token.return_value = mock_api_token
mock_db.session.query.return_value.where.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
@validate_dataset_token
def protected_view(dataset_id=None, **kwargs):