mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 16:39:26 +08:00
refactor(api): tighten login and wrapper typing (#34447)
This commit is contained in:
@@ -193,7 +193,7 @@ workflow_draft_variable_list_model = console_ns.model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
|
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||||
"""Common prerequisites for all draft workflow variable APIs.
|
"""Common prerequisites for all draft workflow variable APIs.
|
||||||
|
|
||||||
It ensures the following conditions are satisfied:
|
It ensures the following conditions are satisfied:
|
||||||
@@ -210,7 +210,7 @@ def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
|
|||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def wrapper(*args: Any, **kwargs: Any):
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any
|
from typing import overload
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
@@ -23,14 +23,30 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
|
|||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
||||||
def get_app_model(
|
@overload
|
||||||
view: Callable[..., Any] | None = None,
|
def get_app_model[**P, R](
|
||||||
|
view: Callable[P, R],
|
||||||
*,
|
*,
|
||||||
mode: AppMode | list[AppMode] | None = None,
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
|
) -> Callable[P, R]: ...
|
||||||
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_app_model[**P, R](
|
||||||
|
view: None = None,
|
||||||
|
*,
|
||||||
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
|
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def get_app_model[**P, R](
|
||||||
|
view: Callable[P, R] | None = None,
|
||||||
|
*,
|
||||||
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
|
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
|
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: Any, **kwargs: Any):
|
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
if not kwargs.get("app_id"):
|
if not kwargs.get("app_id"):
|
||||||
raise ValueError("missing app_id in path parameters")
|
raise ValueError("missing app_id in path parameters")
|
||||||
|
|
||||||
@@ -68,14 +84,30 @@ def get_app_model(
|
|||||||
return decorator(view)
|
return decorator(view)
|
||||||
|
|
||||||
|
|
||||||
def get_app_model_with_trial(
|
@overload
|
||||||
view: Callable[..., Any] | None = None,
|
def get_app_model_with_trial[**P, R](
|
||||||
|
view: Callable[P, R],
|
||||||
*,
|
*,
|
||||||
mode: AppMode | list[AppMode] | None = None,
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
|
) -> Callable[P, R]: ...
|
||||||
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_app_model_with_trial[**P, R](
|
||||||
|
view: None = None,
|
||||||
|
*,
|
||||||
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
|
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def get_app_model_with_trial[**P, R](
|
||||||
|
view: Callable[P, R] | None = None,
|
||||||
|
*,
|
||||||
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
|
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
|
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: Any, **kwargs: Any):
|
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
if not kwargs.get("app_id"):
|
if not kwargs.get("app_id"):
|
||||||
raise ValueError("missing app_id in path parameters")
|
raise ValueError("missing app_id in path parameters")
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import Any, NoReturn
|
from typing import Any, NoReturn
|
||||||
|
|
||||||
from flask import Response, request
|
from flask import Response, request
|
||||||
@@ -55,7 +56,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
|
|||||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||||
|
|
||||||
|
|
||||||
def _api_prerequisite(f):
|
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||||
"""Common prerequisites for all draft workflow variable APIs.
|
"""Common prerequisites for all draft workflow variable APIs.
|
||||||
|
|
||||||
It ensures the following conditions are satisfied:
|
It ensures the following conditions are satisfied:
|
||||||
@@ -70,7 +71,7 @@ def _api_prerequisite(f):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, cast, overload
|
from typing import cast, overload
|
||||||
|
|
||||||
from flask import current_app, request
|
from flask import current_app, request
|
||||||
from flask_login import user_logged_in
|
from flask_login import user_logged_in
|
||||||
@@ -230,94 +231,73 @@ def cloud_edition_billing_rate_limit_check[**P, R](
|
|||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
def validate_dataset_token(
|
def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]:
|
||||||
view: Callable[..., Any] | None = None,
|
positional_parameters = [
|
||||||
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
|
parameter
|
||||||
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
|
for parameter in inspect.signature(view).parameters.values()
|
||||||
@wraps(view_func)
|
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
||||||
def decorated(*args: Any, **kwargs: Any) -> Any:
|
]
|
||||||
api_token = validate_and_get_api_token("dataset")
|
expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"})
|
||||||
|
|
||||||
# get url path dataset_id from positional args or kwargs
|
@wraps(view)
|
||||||
# Flask passes URL path parameters as positional arguments
|
def decorated(*args: object, **kwargs: object) -> R:
|
||||||
dataset_id = None
|
api_token = validate_and_get_api_token("dataset")
|
||||||
|
|
||||||
# First try to get from kwargs (explicit parameter)
|
# Flask may pass URL path parameters positionally, so inspect both kwargs and args.
|
||||||
dataset_id = kwargs.get("dataset_id")
|
dataset_id = kwargs.get("dataset_id")
|
||||||
|
|
||||||
# If not in kwargs, try to extract from positional args
|
if not dataset_id and args:
|
||||||
if not dataset_id and args:
|
potential_id = args[0]
|
||||||
# For class methods: args[0] is self, args[1] is dataset_id (if exists)
|
try:
|
||||||
# Check if first arg is likely a class instance (has __dict__ or __class__)
|
str_id = str(potential_id)
|
||||||
if len(args) > 1 and hasattr(args[0], "__dict__"):
|
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||||
# This is a class method, dataset_id should be in args[1]
|
dataset_id = str_id
|
||||||
potential_id = args[1]
|
except Exception:
|
||||||
# Validate it's a string-like UUID, not another object
|
logger.exception("Failed to parse dataset_id from positional args")
|
||||||
try:
|
|
||||||
# Try to convert to string and check if it's a valid UUID format
|
|
||||||
str_id = str(potential_id)
|
|
||||||
# Basic check: UUIDs are 36 chars with hyphens
|
|
||||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
|
||||||
dataset_id = str_id
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to parse dataset_id from class method args")
|
|
||||||
elif len(args) > 0:
|
|
||||||
# Not a class method, check if args[0] looks like a UUID
|
|
||||||
potential_id = args[0]
|
|
||||||
try:
|
|
||||||
str_id = str(potential_id)
|
|
||||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
|
||||||
dataset_id = str_id
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to parse dataset_id from positional args")
|
|
||||||
|
|
||||||
# Validate dataset if dataset_id is provided
|
if dataset_id:
|
||||||
if dataset_id:
|
dataset_id = str(dataset_id)
|
||||||
dataset_id = str(dataset_id)
|
dataset = db.session.scalar(
|
||||||
dataset = db.session.scalar(
|
select(Dataset)
|
||||||
select(Dataset)
|
.where(
|
||||||
.where(
|
Dataset.id == dataset_id,
|
||||||
Dataset.id == dataset_id,
|
Dataset.tenant_id == api_token.tenant_id,
|
||||||
Dataset.tenant_id == api_token.tenant_id,
|
|
||||||
)
|
|
||||||
.limit(1)
|
|
||||||
)
|
)
|
||||||
if not dataset:
|
.limit(1)
|
||||||
raise NotFound("Dataset not found.")
|
)
|
||||||
if not dataset.enable_api:
|
if not dataset:
|
||||||
raise Forbidden("Dataset api access is not enabled.")
|
raise NotFound("Dataset not found.")
|
||||||
tenant_account_join = db.session.execute(
|
if not dataset.enable_api:
|
||||||
select(Tenant, TenantAccountJoin)
|
raise Forbidden("Dataset api access is not enabled.")
|
||||||
.where(Tenant.id == api_token.tenant_id)
|
|
||||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
tenant_account_join = db.session.execute(
|
||||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
select(Tenant, TenantAccountJoin)
|
||||||
.where(Tenant.status == TenantStatus.NORMAL)
|
.where(Tenant.id == api_token.tenant_id)
|
||||||
).one_or_none() # TODO: only owner information is required, so only one is returned.
|
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||||
if tenant_account_join:
|
.where(TenantAccountJoin.role.in_(["owner"]))
|
||||||
tenant, ta = tenant_account_join
|
.where(Tenant.status == TenantStatus.NORMAL)
|
||||||
account = db.session.get(Account, ta.account_id)
|
).one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||||
# Login admin
|
if tenant_account_join:
|
||||||
if account:
|
tenant, ta = tenant_account_join
|
||||||
account.current_tenant = tenant
|
account = db.session.get(Account, ta.account_id)
|
||||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
# Login admin
|
||||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
if account:
|
||||||
else:
|
account.current_tenant = tenant
|
||||||
raise Unauthorized("Tenant owner account does not exist.")
|
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||||
|
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||||
else:
|
else:
|
||||||
raise Unauthorized("Tenant does not exist.")
|
raise Unauthorized("Tenant owner account does not exist.")
|
||||||
if args and isinstance(args[0], Resource):
|
else:
|
||||||
return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs)
|
raise Unauthorized("Tenant does not exist.")
|
||||||
|
|
||||||
return view_func(api_token.tenant_id, *args, **kwargs)
|
if expects_bound_instance:
|
||||||
|
if not args:
|
||||||
|
raise TypeError("validate_dataset_token expected a bound resource instance.")
|
||||||
|
return view(args[0], api_token.tenant_id, *args[1:], **kwargs)
|
||||||
|
|
||||||
return decorated
|
return view(api_token.tenant_id, *args, **kwargs)
|
||||||
|
|
||||||
if view:
|
return decorated
|
||||||
return decorator(view)
|
|
||||||
|
|
||||||
# if view is None, it means that the decorator is used without parentheses
|
|
||||||
# use the decorator as a function for method_decorators
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def validate_and_get_api_token(scope: str | None = None):
|
def validate_and_get_api_token(scope: str | None = None):
|
||||||
|
|||||||
@@ -1,5 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from extensions.ext_login import DifyLoginManager
|
||||||
|
|
||||||
|
|
||||||
class DifyApp(Flask):
|
class DifyApp(Flask):
|
||||||
pass
|
"""Flask application type with Dify-specific extension attributes."""
|
||||||
|
|
||||||
|
login_manager: DifyLoginManager
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import flask_login
|
import flask_login
|
||||||
from flask import Response, request
|
from flask import Request, Response, request
|
||||||
from flask_login import user_loaded_from_request, user_logged_in
|
from flask_login import user_loaded_from_request, user_logged_in
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
@@ -16,13 +17,35 @@ from models import Account, Tenant, TenantAccountJoin
|
|||||||
from models.model import AppMCPServer, EndUser
|
from models.model import AppMCPServer, EndUser
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
|
|
||||||
login_manager = flask_login.LoginManager()
|
type LoginUser = Account | EndUser
|
||||||
|
|
||||||
|
|
||||||
|
class DifyLoginManager(flask_login.LoginManager):
|
||||||
|
"""Project-specific Flask-Login manager with a stable unauthorized contract.
|
||||||
|
|
||||||
|
Dify registers `unauthorized_handler` below to always return a JSON `Response`.
|
||||||
|
Overriding this method lets callers rely on that narrower return type instead of
|
||||||
|
Flask-Login's broader callback contract.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def unauthorized(self) -> Response:
|
||||||
|
"""Return the registered unauthorized handler result as a Flask `Response`."""
|
||||||
|
return cast(Response, super().unauthorized())
|
||||||
|
|
||||||
|
def load_user_from_request_context(self) -> None:
|
||||||
|
"""Populate Flask-Login's request-local user cache for the current request."""
|
||||||
|
self._load_user()
|
||||||
|
|
||||||
|
|
||||||
|
login_manager = DifyLoginManager()
|
||||||
|
|
||||||
|
|
||||||
# Flask-Login configuration
|
# Flask-Login configuration
|
||||||
@login_manager.request_loader
|
@login_manager.request_loader
|
||||||
def load_user_from_request(request_from_flask_login):
|
def load_user_from_request(request_from_flask_login: Request) -> LoginUser | None:
|
||||||
"""Load user based on the request."""
|
"""Load user based on the request."""
|
||||||
|
del request_from_flask_login
|
||||||
|
|
||||||
# Skip authentication for documentation endpoints
|
# Skip authentication for documentation endpoints
|
||||||
if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
|
if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
|
||||||
return None
|
return None
|
||||||
@@ -100,10 +123,12 @@ def load_user_from_request(request_from_flask_login):
|
|||||||
raise NotFound("End user not found.")
|
raise NotFound("End user not found.")
|
||||||
return end_user
|
return end_user
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@user_logged_in.connect
|
@user_logged_in.connect
|
||||||
@user_loaded_from_request.connect
|
@user_loaded_from_request.connect
|
||||||
def on_user_logged_in(_sender, user):
|
def on_user_logged_in(_sender: object, user: LoginUser) -> None:
|
||||||
"""Called when a user logged in.
|
"""Called when a user logged in.
|
||||||
|
|
||||||
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
|
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
|
||||||
@@ -114,8 +139,10 @@ def on_user_logged_in(_sender, user):
|
|||||||
|
|
||||||
|
|
||||||
@login_manager.unauthorized_handler
|
@login_manager.unauthorized_handler
|
||||||
def unauthorized_handler():
|
def unauthorized_handler() -> Response:
|
||||||
"""Handle unauthorized requests."""
|
"""Handle unauthorized requests."""
|
||||||
|
# Keep this as a concrete `Response`; `DifyLoginManager.unauthorized()` narrows
|
||||||
|
# Flask-Login's callback contract based on this override.
|
||||||
return Response(
|
return Response(
|
||||||
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
|
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
|
||||||
status=401,
|
status=401,
|
||||||
@@ -123,5 +150,5 @@ def unauthorized_handler():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def init_app(app: DifyApp):
|
def init_app(app: DifyApp) -> None:
|
||||||
login_manager.init_app(app)
|
login_manager.init_app(app)
|
||||||
|
|||||||
@@ -2,19 +2,19 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
from flask import current_app, g, has_request_context, request
|
from flask import Response, current_app, g, has_request_context, request
|
||||||
from flask_login.config import EXEMPT_METHODS
|
from flask_login.config import EXEMPT_METHODS
|
||||||
from werkzeug.local import LocalProxy
|
from werkzeug.local import LocalProxy
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from dify_app import DifyApp
|
||||||
|
from extensions.ext_login import DifyLoginManager
|
||||||
from libs.token import check_csrf_token
|
from libs.token import check_csrf_token
|
||||||
from models import Account
|
from models import Account
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from flask.typing import ResponseReturnValue
|
|
||||||
|
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
||||||
|
|
||||||
@@ -29,7 +29,13 @@ def _resolve_current_user() -> EndUser | Account | None:
|
|||||||
return get_current_object() if callable(get_current_object) else user_proxy # type: ignore
|
return get_current_object() if callable(get_current_object) else user_proxy # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def current_account_with_tenant():
|
def _get_login_manager() -> DifyLoginManager:
|
||||||
|
"""Return the project login manager with Dify's narrowed unauthorized contract."""
|
||||||
|
app = cast(DifyApp, current_app)
|
||||||
|
return app.login_manager
|
||||||
|
|
||||||
|
|
||||||
|
def current_account_with_tenant() -> tuple[Account, str]:
|
||||||
"""
|
"""
|
||||||
Resolve the underlying account for the current user proxy and ensure tenant context exists.
|
Resolve the underlying account for the current user proxy and ensure tenant context exists.
|
||||||
Allows tests to supply plain Account mocks without the LocalProxy helper.
|
Allows tests to supply plain Account mocks without the LocalProxy helper.
|
||||||
@@ -42,7 +48,7 @@ def current_account_with_tenant():
|
|||||||
return user, user.current_tenant_id
|
return user, user.current_tenant_id
|
||||||
|
|
||||||
|
|
||||||
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
|
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]:
|
||||||
"""
|
"""
|
||||||
If you decorate a view with this, it will ensure that the current user is
|
If you decorate a view with this, it will ensure that the current user is
|
||||||
logged in and authenticated before calling the actual view. (If they are
|
logged in and authenticated before calling the actual view. (If they are
|
||||||
@@ -77,13 +83,16 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
|
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||||
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
||||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||||
|
|
||||||
user = _resolve_current_user()
|
user = _resolve_current_user()
|
||||||
if user is None or not user.is_authenticated:
|
if user is None or not user.is_authenticated:
|
||||||
return current_app.login_manager.unauthorized() # type: ignore
|
# `DifyLoginManager` guarantees that the registered unauthorized handler
|
||||||
|
# is surfaced here as a concrete Flask `Response`.
|
||||||
|
unauthorized_response: Response = _get_login_manager().unauthorized()
|
||||||
|
return unauthorized_response
|
||||||
g._login_user = user
|
g._login_user = user
|
||||||
# we put csrf validation here for less conflicts
|
# we put csrf validation here for less conflicts
|
||||||
# TODO: maybe find a better place for it.
|
# TODO: maybe find a better place for it.
|
||||||
@@ -96,7 +105,7 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu
|
|||||||
def _get_user() -> EndUser | Account | None:
|
def _get_user() -> EndUser | Account | None:
|
||||||
if has_request_context():
|
if has_request_context():
|
||||||
if "_login_user" not in g:
|
if "_login_user" not in g:
|
||||||
current_app.login_manager._load_user() # type: ignore
|
_get_login_manager().load_user_from_request_context()
|
||||||
|
|
||||||
return g._login_user
|
return g._login_user
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ def app():
|
|||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.config["TESTING"] = True
|
app.config["TESTING"] = True
|
||||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||||
app.login_manager = SimpleNamespace(_load_user=lambda: None)
|
app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from models.account import Account, TenantAccountRole
|
|||||||
def app():
|
def app():
|
||||||
flask_app = Flask(__name__)
|
flask_app = Flask(__name__)
|
||||||
flask_app.config["TESTING"] = True
|
flask_app.config["TESTING"] = True
|
||||||
flask_app.login_manager = SimpleNamespace(_load_user=lambda: None)
|
flask_app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
|
||||||
return flask_app
|
return flask_app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
17
api/tests/unit_tests/extensions/test_ext_login.py
Normal file
17
api/tests/unit_tests/extensions/test_ext_login.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from flask import Response
|
||||||
|
|
||||||
|
from extensions.ext_login import unauthorized_handler
|
||||||
|
|
||||||
|
|
||||||
|
def test_unauthorized_handler_returns_json_response() -> None:
|
||||||
|
response = unauthorized_handler()
|
||||||
|
|
||||||
|
assert isinstance(response, Response)
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.content_type == "application/json"
|
||||||
|
assert json.loads(response.get_data(as_text=True)) == {
|
||||||
|
"code": "unauthorized",
|
||||||
|
"message": "Unauthorized.",
|
||||||
|
}
|
||||||
@@ -2,11 +2,12 @@ from types import SimpleNamespace
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from flask import Flask, g
|
from flask import Flask, Response, g
|
||||||
from flask_login import LoginManager, UserMixin
|
from flask_login import UserMixin
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
import libs.login as login_module
|
import libs.login as login_module
|
||||||
|
from extensions.ext_login import DifyLoginManager
|
||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
|
|
||||||
@@ -39,9 +40,12 @@ def login_app(mocker: MockerFixture) -> Flask:
|
|||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.config["TESTING"] = True
|
app.config["TESTING"] = True
|
||||||
|
|
||||||
login_manager = LoginManager()
|
login_manager = DifyLoginManager()
|
||||||
login_manager.init_app(app)
|
login_manager.init_app(app)
|
||||||
login_manager.unauthorized = mocker.Mock(name="unauthorized", return_value="Unauthorized")
|
login_manager.unauthorized = mocker.Mock(
|
||||||
|
name="unauthorized",
|
||||||
|
return_value=Response("Unauthorized", status=401, content_type="application/json"),
|
||||||
|
)
|
||||||
|
|
||||||
@login_manager.user_loader
|
@login_manager.user_loader
|
||||||
def load_user(_user_id: str):
|
def load_user(_user_id: str):
|
||||||
@@ -109,18 +113,43 @@ class TestLoginRequired:
|
|||||||
resolved_user: MockUser | None,
|
resolved_user: MockUser | None,
|
||||||
description: str,
|
description: str,
|
||||||
):
|
):
|
||||||
"""Test that missing or unauthenticated users are redirected."""
|
"""Test that missing or unauthenticated users return the manager response."""
|
||||||
|
|
||||||
resolve_user = resolve_current_user(resolved_user)
|
resolve_user = resolve_current_user(resolved_user)
|
||||||
|
|
||||||
with login_app.test_request_context():
|
with login_app.test_request_context():
|
||||||
result = protected_view()
|
result = protected_view()
|
||||||
|
|
||||||
assert result == "Unauthorized", description
|
assert result is login_app.login_manager.unauthorized.return_value, description
|
||||||
|
assert isinstance(result, Response)
|
||||||
|
assert result.status_code == 401
|
||||||
resolve_user.assert_called_once_with()
|
resolve_user.assert_called_once_with()
|
||||||
login_app.login_manager.unauthorized.assert_called_once_with()
|
login_app.login_manager.unauthorized.assert_called_once_with()
|
||||||
csrf_check.assert_not_called()
|
csrf_check.assert_not_called()
|
||||||
|
|
||||||
|
def test_unauthorized_access_propagates_response_object(
|
||||||
|
self,
|
||||||
|
login_app: Flask,
|
||||||
|
protected_view,
|
||||||
|
csrf_check: MagicMock,
|
||||||
|
resolve_current_user,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test that unauthorized responses are propagated as Flask Response objects."""
|
||||||
|
resolve_user = resolve_current_user(None)
|
||||||
|
response = Response("Unauthorized", status=401, content_type="application/json")
|
||||||
|
mocker.patch.object(
|
||||||
|
login_module, "_get_login_manager", return_value=SimpleNamespace(unauthorized=lambda: response)
|
||||||
|
)
|
||||||
|
|
||||||
|
with login_app.test_request_context():
|
||||||
|
result = protected_view()
|
||||||
|
|
||||||
|
assert result is response
|
||||||
|
assert isinstance(result, Response)
|
||||||
|
resolve_user.assert_called_once_with()
|
||||||
|
csrf_check.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("method", "login_disabled"),
|
("method", "login_disabled"),
|
||||||
[
|
[
|
||||||
@@ -168,10 +197,14 @@ class TestGetUser:
|
|||||||
"""Test that _get_user loads user if not already in g."""
|
"""Test that _get_user loads user if not already in g."""
|
||||||
mock_user = MockUser("test_user")
|
mock_user = MockUser("test_user")
|
||||||
|
|
||||||
def _load_user() -> None:
|
def load_user_from_request_context() -> None:
|
||||||
g._login_user = mock_user
|
g._login_user = mock_user
|
||||||
|
|
||||||
load_user = mocker.patch.object(login_app.login_manager, "_load_user", side_effect=_load_user)
|
load_user = mocker.patch.object(
|
||||||
|
login_app.login_manager,
|
||||||
|
"load_user_from_request_context",
|
||||||
|
side_effect=load_user_from_request_context,
|
||||||
|
)
|
||||||
|
|
||||||
with login_app.test_request_context():
|
with login_app.test_request_context():
|
||||||
user = login_module._get_user()
|
user = login_module._get_user()
|
||||||
|
|||||||
Reference in New Issue
Block a user