mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 20:22:39 +08:00
refactor: select in service API wraps, file_preview, and site controllers (#34086)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -4,6 +4,7 @@ from urllib.parse import quote
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.common.file_response import enforce_download_for_html
|
||||
from controllers.common.schema import register_schema_model
|
||||
@@ -102,27 +103,27 @@ class FilePreviewApi(Resource):
|
||||
raise FileAccessDeniedError("Invalid file or app identifier")
|
||||
|
||||
# First, find the MessageFile that references this upload file
|
||||
message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first()
|
||||
message_file = db.session.scalar(select(MessageFile).where(MessageFile.upload_file_id == file_id).limit(1))
|
||||
|
||||
if not message_file:
|
||||
raise FileNotFoundError("File not found in message context")
|
||||
|
||||
# Get the message and verify it belongs to the requesting app
|
||||
message = (
|
||||
db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first()
|
||||
message = db.session.scalar(
|
||||
select(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).limit(1)
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise FileAccessDeniedError("File access denied: not owned by requesting app")
|
||||
|
||||
# Get the actual upload file record
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file = db.session.get(UploadFile, file_id)
|
||||
|
||||
if not upload_file:
|
||||
raise FileNotFoundError("Upload file record not found")
|
||||
|
||||
# Additional security: verify tenant isolation
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
app = db.session.get(App, app_id)
|
||||
if app and upload_file.tenant_id != app.tenant_id:
|
||||
raise FileAccessDeniedError("File access denied: tenant mismatch")
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
@@ -28,7 +29,7 @@ class AppSiteApi(Resource):
|
||||
|
||||
Returns the site configuration for the application including theme, icons, and text.
|
||||
"""
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
@@ -9,6 +9,7 @@ from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
@@ -62,7 +63,7 @@ def validate_app_token(
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
api_token = validate_and_get_api_token("app")
|
||||
|
||||
app_model = db.session.query(App).where(App.id == api_token.app_id).first()
|
||||
app_model = db.session.get(App, api_token.app_id)
|
||||
if not app_model:
|
||||
raise Forbidden("The app no longer exists.")
|
||||
|
||||
@@ -72,7 +73,7 @@ def validate_app_token(
|
||||
if not app_model.enable_api:
|
||||
raise Forbidden("The app's API service has been disabled.")
|
||||
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first()
|
||||
tenant = db.session.get(Tenant, app_model.tenant_id)
|
||||
if tenant is None:
|
||||
raise ValueError("Tenant does not exist.")
|
||||
if tenant.status == TenantStatus.ARCHIVE:
|
||||
@@ -106,8 +107,8 @@ def validate_app_token(
|
||||
else:
|
||||
# For service API without end-user context, ensure an Account is logged in
|
||||
# so services relying on current_account_with_tenant() work correctly.
|
||||
tenant_owner_info = (
|
||||
db.session.query(Tenant, Account)
|
||||
tenant_owner_info = db.session.execute(
|
||||
select(Tenant, Account)
|
||||
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
|
||||
.join(Account, TenantAccountJoin.account_id == Account.id)
|
||||
.where(
|
||||
@@ -115,8 +116,7 @@ def validate_app_token(
|
||||
TenantAccountJoin.role == "owner",
|
||||
Tenant.status == TenantStatus.NORMAL,
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
).one_or_none()
|
||||
|
||||
if tenant_owner_info:
|
||||
tenant_model, account = tenant_owner_info
|
||||
@@ -277,29 +277,28 @@ def validate_dataset_token(
|
||||
# Validate dataset if dataset_id is provided
|
||||
if dataset_id:
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == api_token.tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if not dataset.enable_api:
|
||||
raise Forbidden("Dataset api access is not enabled.")
|
||||
tenant_account_join = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
tenant_account_join = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
||||
.where(Tenant.status == TenantStatus.NORMAL)
|
||||
.one_or_none()
|
||||
) # TODO: only owner information is required, so only one is returned.
|
||||
).one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.query(Account).where(Account.id == ta.account_id).first()
|
||||
account = db.session.get(Account, ta.account_id)
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
@@ -360,7 +359,9 @@ class DatasetApiResource(Resource):
|
||||
method_decorators = [validate_dataset_token]
|
||||
|
||||
def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@@ -123,27 +123,26 @@ def _configure_session_factory(_unit_test_engine):
|
||||
|
||||
def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account):
|
||||
"""
|
||||
Helper to set up the mock DB query chain for tenant/account authentication.
|
||||
Helper to set up the mock DB execute chain for tenant/account authentication.
|
||||
|
||||
This configures the mock to return (tenant, account) for the join query used
|
||||
by validate_app_token and validate_dataset_token decorators.
|
||||
This configures the mock to return (tenant, account) for the
|
||||
db.session.execute(select(...).join().join().where()).one_or_none()
|
||||
query used by validate_app_token decorator.
|
||||
|
||||
Args:
|
||||
mock_db: The mocked db object
|
||||
mock_tenant: Mock tenant object to return
|
||||
mock_account: Mock account object to return
|
||||
"""
|
||||
query = mock_db.session.query.return_value
|
||||
join_chain = query.join.return_value.join.return_value
|
||||
where_chain = join_chain.where.return_value
|
||||
where_chain.one_or_none.return_value = (mock_tenant, mock_account)
|
||||
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account)
|
||||
|
||||
|
||||
def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
|
||||
"""
|
||||
Helper to set up the mock DB query chain for dataset tenant authentication.
|
||||
Helper to set up the mock DB execute chain for dataset tenant authentication.
|
||||
|
||||
This configures the mock to return (tenant, tenant_account) for the where chain
|
||||
This configures the mock to return (tenant, tenant_account) for the
|
||||
db.session.execute(select(...).where().where().where().where()).one_or_none()
|
||||
query used by validate_dataset_token decorator.
|
||||
|
||||
Args:
|
||||
@@ -151,6 +150,4 @@ def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
|
||||
mock_tenant: Mock tenant object to return
|
||||
mock_ta: Mock tenant account object to return
|
||||
"""
|
||||
query = mock_db.session.query.return_value
|
||||
where_chain = query.where.return_value.where.return_value.where.return_value.where.return_value
|
||||
where_chain.one_or_none.return_value = (mock_tenant, mock_ta)
|
||||
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
|
||||
|
||||
@@ -65,7 +65,7 @@ class TestAppParameterApi:
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
# Mock DB queries for app and tenant
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
@@ -112,7 +112,7 @@ class TestAppParameterApi:
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
@@ -153,7 +153,7 @@ class TestAppParameterApi:
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
@@ -192,7 +192,7 @@ class TestAppParameterApi:
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
@@ -255,7 +255,7 @@ class TestAppMetaApi:
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
@@ -323,7 +323,7 @@ class TestAppInfoApi:
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
@@ -380,7 +380,7 @@ class TestAppInfoApi:
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_db.session.get.side_effect = [
|
||||
mock_app,
|
||||
mock_tenant,
|
||||
]
|
||||
@@ -426,7 +426,7 @@ class TestAppInfoApi:
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_db.session.get.side_effect = [
|
||||
mock_app,
|
||||
mock_tenant,
|
||||
]
|
||||
@@ -478,7 +478,7 @@ class TestAppInfoApi:
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_db.session.get.side_effect = [
|
||||
mock_app,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
@@ -79,10 +79,13 @@ class TestFilePreviewApi:
|
||||
mock_message_file.message_id = mock_message.id
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock database queries
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
# Mock scalar() for MessageFile and Message queries
|
||||
mock_db.session.scalar.side_effect = [
|
||||
mock_message_file, # MessageFile query
|
||||
mock_message, # Message query
|
||||
]
|
||||
# Mock get() for UploadFile and App PK lookups
|
||||
mock_db.session.get.side_effect = [
|
||||
mock_upload_file, # UploadFile query
|
||||
mock_app, # App query for tenant validation
|
||||
]
|
||||
@@ -100,8 +103,8 @@ class TestFilePreviewApi:
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock MessageFile not found
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
# Mock MessageFile not found via scalar()
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Execute and assert exception
|
||||
with pytest.raises(FileNotFoundError) as exc_info:
|
||||
@@ -115,8 +118,8 @@ class TestFilePreviewApi:
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock MessageFile found but Message not owned by app
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
# Mock MessageFile found but Message not owned by app via scalar()
|
||||
mock_db.session.scalar.side_effect = [
|
||||
mock_message_file, # MessageFile query - found
|
||||
None, # Message query - not found (access denied)
|
||||
]
|
||||
@@ -133,12 +136,13 @@ class TestFilePreviewApi:
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock MessageFile and Message found but UploadFile not found
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
# Mock scalar() for MessageFile and Message
|
||||
mock_db.session.scalar.side_effect = [
|
||||
mock_message_file, # MessageFile query - found
|
||||
mock_message, # Message query - found
|
||||
None, # UploadFile query - not found
|
||||
]
|
||||
# Mock get() for UploadFile - not found
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
# Execute and assert exception
|
||||
with pytest.raises(FileNotFoundError) as exc_info:
|
||||
@@ -161,10 +165,13 @@ class TestFilePreviewApi:
|
||||
mock_message_file.message_id = mock_message.id
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock database queries
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
# Mock scalar() for MessageFile and Message queries
|
||||
mock_db.session.scalar.side_effect = [
|
||||
mock_message_file, # MessageFile query
|
||||
mock_message, # Message query
|
||||
]
|
||||
# Mock get() for UploadFile and App PK lookups
|
||||
mock_db.session.get.side_effect = [
|
||||
mock_upload_file, # UploadFile query
|
||||
mock_app, # App query for tenant validation
|
||||
]
|
||||
@@ -262,10 +269,13 @@ class TestFilePreviewApi:
|
||||
mock_storage.load.return_value = mock_generator
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock database queries
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
# Mock scalar() for MessageFile and Message queries
|
||||
mock_db.session.scalar.side_effect = [
|
||||
mock_message_file, # MessageFile query
|
||||
mock_message, # Message query
|
||||
]
|
||||
# Mock get() for UploadFile and App PK lookups
|
||||
mock_db.session.get.side_effect = [
|
||||
mock_upload_file, # UploadFile query
|
||||
mock_app, # App query for tenant validation
|
||||
]
|
||||
@@ -301,10 +311,13 @@ class TestFilePreviewApi:
|
||||
mock_storage.load.side_effect = Exception("Storage error")
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock database queries for validation
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
# Mock scalar() for MessageFile and Message queries
|
||||
mock_db.session.scalar.side_effect = [
|
||||
mock_message_file, # MessageFile query
|
||||
mock_message, # Message query
|
||||
]
|
||||
# Mock get() for UploadFile and App PK lookups
|
||||
mock_db.session.get.side_effect = [
|
||||
mock_upload_file, # UploadFile query
|
||||
mock_app, # App query for tenant validation
|
||||
]
|
||||
@@ -327,8 +340,8 @@ class TestFilePreviewApi:
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
with patch("controllers.service_api.app.file_preview.db") as mock_db:
|
||||
# Mock database query to raise unexpected exception
|
||||
mock_db.session.query.side_effect = Exception("Unexpected database error")
|
||||
# Mock database scalar to raise unexpected exception
|
||||
mock_db.session.scalar.side_effect = Exception("Unexpected database error")
|
||||
|
||||
# Execute and assert exception
|
||||
with pytest.raises(FileAccessDeniedError) as exc_info:
|
||||
|
||||
@@ -119,11 +119,8 @@ class AuthenticationMocker:
|
||||
|
||||
@staticmethod
|
||||
def setup_db_queries(mock_db, mock_app, mock_tenant, mock_account=None):
|
||||
"""Configure mock_db to return app and tenant in sequence."""
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app,
|
||||
mock_tenant,
|
||||
]
|
||||
"""Configure mock_db to return app and tenant via session.get()."""
|
||||
mock_db.session.get.side_effect = [mock_app, mock_tenant]
|
||||
|
||||
if mock_account:
|
||||
mock_ta = Mock()
|
||||
@@ -136,11 +133,9 @@ class AuthenticationMocker:
|
||||
mock_ta = Mock()
|
||||
mock_ta.account_id = mock_account.id
|
||||
|
||||
mock_query = mock_db.session.query.return_value
|
||||
target_mock = mock_query.where.return_value.where.return_value.where.return_value.where.return_value
|
||||
target_mock.one_or_none.return_value = (mock_tenant, mock_ta)
|
||||
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_account
|
||||
mock_db.session.get.return_value = mock_account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -88,7 +88,7 @@ class TestAppSiteApi:
|
||||
mock_app_model.tenant = mock_tenant
|
||||
|
||||
# Mock wraps.db for authentication
|
||||
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_wraps_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
@@ -98,7 +98,7 @@ class TestAppSiteApi:
|
||||
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
|
||||
|
||||
# Mock site.db for site query
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_site
|
||||
mock_db.session.scalar.return_value = mock_site
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -109,7 +109,7 @@ class TestAppSiteApi:
|
||||
assert response["title"] == "Test Site"
|
||||
assert response["icon"] == "icon-url"
|
||||
assert response["description"] == "Site description"
|
||||
mock_db.session.query.assert_called_once_with(Site)
|
||||
mock_db.session.scalar.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.app.site.db")
|
||||
@@ -140,7 +140,7 @@ class TestAppSiteApi:
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
mock_app_model.tenant = mock_tenant
|
||||
|
||||
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_wraps_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
@@ -150,7 +150,7 @@ class TestAppSiteApi:
|
||||
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
|
||||
|
||||
# Mock site query to return None
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -187,7 +187,7 @@ class TestAppSiteApi:
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
|
||||
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_wraps_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
@@ -197,7 +197,7 @@ class TestAppSiteApi:
|
||||
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
|
||||
|
||||
# Mock site query
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_site
|
||||
mock_db.session.scalar.return_value = mock_site
|
||||
|
||||
# Set tenant status to archived AFTER authentication
|
||||
mock_app_model.tenant.status = TenantStatus.ARCHIVE
|
||||
@@ -230,7 +230,7 @@ class TestAppSiteApi:
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
mock_app_model.tenant = mock_tenant
|
||||
|
||||
mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_wraps_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
@@ -258,7 +258,7 @@ class TestAppSiteApi:
|
||||
mock_site.icon_type = "image"
|
||||
mock_site.created_at = "2024-01-01T00:00:00"
|
||||
mock_site.updated_at = "2024-01-01T00:00:00"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_site
|
||||
mock_db.session.scalar.return_value = mock_site
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
@@ -267,4 +267,4 @@ class TestAppSiteApi:
|
||||
|
||||
# Assert
|
||||
# The query was executed successfully (site returned), which validates the correct query was made
|
||||
mock_db.session.query.assert_called_once_with(Site)
|
||||
mock_db.session.scalar.assert_called_once()
|
||||
|
||||
@@ -144,14 +144,10 @@ class TestValidateAppToken:
|
||||
mock_ta = Mock()
|
||||
mock_ta.account_id = mock_account.id
|
||||
|
||||
# Use side_effect to return app first, then tenant
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app,
|
||||
mock_tenant,
|
||||
mock_account,
|
||||
]
|
||||
# Use side_effect to return app first, then tenant via session.get()
|
||||
mock_db.session.get.side_effect = [mock_app, mock_tenant]
|
||||
|
||||
# Mock the tenant owner query
|
||||
# Mock the tenant owner query (execute(select(...)).one_or_none())
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
|
||||
|
||||
@validate_app_token
|
||||
@@ -175,7 +171,7 @@ class TestValidateAppToken:
|
||||
mock_api_token.app_id = str(uuid.uuid4())
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
@validate_app_token
|
||||
def protected_view(**kwargs):
|
||||
@@ -198,7 +194,7 @@ class TestValidateAppToken:
|
||||
|
||||
mock_app = Mock()
|
||||
mock_app.status = "abnormal"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
mock_db.session.get.return_value = mock_app
|
||||
|
||||
@validate_app_token
|
||||
def protected_view(**kwargs):
|
||||
@@ -222,7 +218,7 @@ class TestValidateAppToken:
|
||||
mock_app = Mock()
|
||||
mock_app.status = "normal"
|
||||
mock_app.enable_api = False
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_app
|
||||
mock_db.session.get.return_value = mock_app
|
||||
|
||||
@validate_app_token
|
||||
def protected_view(**kwargs):
|
||||
@@ -474,11 +470,11 @@ class TestValidateDatasetToken:
|
||||
mock_account.id = mock_ta.account_id
|
||||
mock_account.current_tenant = mock_tenant
|
||||
|
||||
# Mock the tenant account join query
|
||||
# Mock the tenant account join query (execute(select(...)).one_or_none())
|
||||
setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta)
|
||||
|
||||
# Mock the account query
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_account
|
||||
# Mock the account lookup via session.get()
|
||||
mock_db.session.get.return_value = mock_account
|
||||
|
||||
@validate_dataset_token
|
||||
def protected_view(tenant_id):
|
||||
@@ -501,7 +497,7 @@ class TestValidateDatasetToken:
|
||||
mock_api_token.tenant_id = str(uuid.uuid4())
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
@validate_dataset_token
|
||||
def protected_view(dataset_id=None, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user