chore(api): align Python support with 3.12 (#34419)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
99
2026-04-02 13:07:32 +08:00
committed by GitHub
parent cb9ee5903a
commit 8f9dbf269e
97 changed files with 410 additions and 1441 deletions

View File

@@ -10,7 +10,7 @@ import threading
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, TypeVar, final, runtime_checkable
from typing import Any, Protocol, final, runtime_checkable
from pydantic import BaseModel
@@ -188,8 +188,6 @@ class ExecutionContextBuilder:
_capturer: Callable[[], IExecutionContext] | None = None
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
T = TypeVar("T", bound=BaseModel)
class ContextProviderNotFoundError(KeyError):
"""Raised when a tenant-scoped context provider is missing."""

View File

@@ -1,7 +1,4 @@
from contextvars import ContextVar
from typing import Generic, TypeVar
T = TypeVar("T")
class HiddenValue:
@@ -11,7 +8,7 @@ class HiddenValue:
_default = HiddenValue()
class RecyclableContextVar(Generic[T]):
class RecyclableContextVar[T]:
"""
RecyclableContextVar is a wrapper around ContextVar
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now

View File

@@ -1,14 +1,14 @@
from __future__ import annotations
from typing import Any, TypeAlias
from typing import Any
from graphon.file import helpers as file_helpers
from pydantic import BaseModel, ConfigDict, computed_field
from models.model import IconType
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
JSONObject: TypeAlias = dict[str, Any]
type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any]
type JSONObject = dict[str, Any]
class SystemParameters(BaseModel):

View File

@@ -2,7 +2,6 @@ import csv
import io
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource
@@ -20,9 +19,6 @@ from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService
P = ParamSpec("P")
R = TypeVar("R")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -72,9 +68,9 @@ console_ns.schema_model(
)
def admin_required(view: Callable[P, R]):
def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")

View File

@@ -1,7 +1,7 @@
import logging
import uuid
from datetime import datetime
from typing import Any, Literal, TypeAlias
from typing import Any, Literal
from flask import request
from flask_restx import Resource
@@ -152,7 +152,7 @@ class AppTracePayload(BaseModel):
return value
JSONValue: TypeAlias = Any
type JSONValue = Any
class ResponseModel(BaseModel):

View File

@@ -1,7 +1,7 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any, NoReturn, ParamSpec, TypeVar
from typing import Any
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
@@ -192,11 +192,8 @@ workflow_draft_variable_list_model = console_ns.model(
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
)
P = ParamSpec("P")
R = TypeVar("R")
def _api_prerequisite(f: Callable[P, R]):
def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@@ -213,7 +210,7 @@ def _api_prerequisite(f: Callable[P, R]):
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs):
def wrapper(*args: Any, **kwargs: Any):
return f(*args, **kwargs)
return wrapper
@@ -270,7 +267,7 @@ class WorkflowVariableCollectionApi(Resource):
return Response("", 204)
def validate_node_id(node_id: str) -> NoReturn | None:
def validate_node_id(node_id: str) -> None:
if node_id in [
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
@@ -285,7 +282,6 @@ def validate_node_id(node_id: str) -> NoReturn | None:
raise InvalidArgumentError(
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
)
return None
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar, Union
from typing import Any
from sqlalchemy import select
@@ -9,11 +9,6 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models import App, AppMode
P = ParamSpec("P")
R = TypeVar("R")
P1 = ParamSpec("P1")
R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None:
_, current_tenant_id = current_account_with_tenant()
@@ -28,10 +23,14 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P1, R1]):
def get_app_model(
view: Callable[..., Any] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(view_func)
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
def decorated_view(*args: Any, **kwargs: Any):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
@@ -69,10 +68,14 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
return decorator(view)
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
def get_app_model_with_trial(
view: Callable[..., Any] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: Any, **kwargs: Any):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")

View File

@@ -1,8 +1,9 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from typing import Concatenate
from flask import jsonify, request
from flask.typing import ResponseReturnValue
from flask_restx import Resource
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel
@@ -16,10 +17,6 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
from .. import console_ns
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
class OAuthClientPayload(BaseModel):
client_id: str
@@ -39,9 +36,11 @@ class OAuthTokenRequest(BaseModel):
refresh_token: str | None = None
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
def oauth_server_client_id_required[T, **P, R](
view: Callable[Concatenate[T, OAuthProviderApp, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
json_data = request.get_json()
if json_data is None:
raise BadRequest("client_id is required")
@@ -58,9 +57,13 @@ def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderA
return decorated
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
def oauth_server_access_token_required[T, **P, R](
view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R],
) -> Callable[Concatenate[T, OAuthProviderApp, P], R | ResponseReturnValue]:
@wraps(view)
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
def decorated(
self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs
) -> R | ResponseReturnValue:
if not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")

View File

@@ -1,6 +1,5 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy import select
@@ -9,13 +8,10 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.dataset import Pipeline
P = ParamSpec("P")
R = TypeVar("R")
def get_rag_pipeline(view_func: Callable[P, R]):
def get_rag_pipeline[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from typing import Concatenate
from flask import abort
from flask_restx import Resource
@@ -15,12 +15,8 @@ from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def installed_app_required[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
@@ -49,7 +45,7 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
return decorator
def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def user_allowed_to_access_app[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
@@ -73,7 +69,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def trial_app_required[**P, R](view: Callable[Concatenate[App, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, P], R]):
@wraps(view)
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
@@ -106,7 +102,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
return decorator
def trial_feature_enable(view: Callable[P, R]):
def trial_feature_enable[**P, R](view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
@@ -117,7 +113,7 @@ def trial_feature_enable(view: Callable[P, R]):
return decorated
def explore_banner_enabled(view: Callable[P, R]):
def explore_banner_enabled[**P, R](view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()

View File

@@ -1,6 +1,5 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
@@ -9,17 +8,14 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.account import TenantPluginPermission
P = ParamSpec("P")
R = TypeVar("R")
def plugin_permission_required(
install_required: bool = False,
debug_required: bool = False,
):
def interceptor(view: Callable[P, R]):
def interceptor[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
tenant_id = current_tenant_id

View File

@@ -4,7 +4,6 @@ import os
import time
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request
from sqlalchemy import select
@@ -25,9 +24,6 @@ from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
P = ParamSpec("P")
R = TypeVar("R")
# Field names for decryption
FIELD_NAME_PASSWORD = "password"
FIELD_NAME_CODE = "code"
@@ -37,7 +33,7 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check account initialization
@@ -50,7 +46,7 @@ def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
return decorated
def only_edition_cloud(view: Callable[P, R]):
def only_edition_cloud[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "CLOUD":
@@ -61,7 +57,7 @@ def only_edition_cloud(view: Callable[P, R]):
return decorated
def only_edition_enterprise(view: Callable[P, R]):
def only_edition_enterprise[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ENTERPRISE_ENABLED:
@@ -72,7 +68,7 @@ def only_edition_enterprise(view: Callable[P, R]):
return decorated
def only_edition_self_hosted(view: Callable[P, R]):
def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "SELF_HOSTED":
@@ -83,7 +79,7 @@ def only_edition_self_hosted(view: Callable[P, R]):
return decorated
def cloud_edition_billing_enabled(view: Callable[P, R]):
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
@@ -95,7 +91,7 @@ def cloud_edition_billing_enabled(view: Callable[P, R]):
return decorated
def cloud_edition_billing_resource_check(resource: str):
def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -137,7 +133,9 @@ def cloud_edition_billing_resource_check(resource: str):
return interceptor
def cloud_edition_billing_knowledge_limit_check(resource: str):
def cloud_edition_billing_knowledge_limit_check[**P, R](
resource: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -160,7 +158,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
return interceptor
def cloud_edition_billing_rate_limit_check(resource: str):
def cloud_edition_billing_rate_limit_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -196,7 +194,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
return interceptor
def cloud_utm_record(view: Callable[P, R]):
def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
@@ -215,7 +213,7 @@ def cloud_utm_record(view: Callable[P, R]):
return decorated
def setup_required(view: Callable[P, R]) -> Callable[P, R]:
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check setup
@@ -229,7 +227,7 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]:
return decorated
def enterprise_license_required(view: Callable[P, R]):
def enterprise_license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features()
@@ -241,7 +239,7 @@ def enterprise_license_required(view: Callable[P, R]):
return decorated
def email_password_login_enabled(view: Callable[P, R]):
def email_password_login_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
@@ -254,7 +252,7 @@ def email_password_login_enabled(view: Callable[P, R]):
return decorated
def email_register_enabled(view: Callable[P, R]):
def email_register_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
@@ -267,7 +265,7 @@ def email_register_enabled(view: Callable[P, R]):
return decorated
def enable_change_email(view: Callable[P, R]):
def enable_change_email[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
@@ -280,7 +278,7 @@ def enable_change_email(view: Callable[P, R]):
return decorated
def is_allow_transfer_owner(view: Callable[P, R]):
def is_allow_transfer_owner[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
from libs.workspace_permission import check_workspace_owner_transfer_permission
@@ -293,7 +291,7 @@ def is_allow_transfer_owner(view: Callable[P, R]):
return decorated
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
@@ -305,7 +303,7 @@ def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
return decorated
def edit_permission_required(f: Callable[P, R]):
def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
@@ -323,7 +321,7 @@ def edit_permission_required(f: Callable[P, R]):
return decorated_function
def is_admin_or_owner_required(f: Callable[P, R]):
def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
@@ -339,7 +337,7 @@ def is_admin_or_owner_required(f: Callable[P, R]):
return decorated_function
def annotation_import_rate_limit(view: Callable[P, R]):
def annotation_import_rate_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""
Rate limiting decorator for annotation import operations.
@@ -388,7 +386,7 @@ def annotation_import_rate_limit(view: Callable[P, R]):
return decorated
def annotation_import_concurrency_limit(view: Callable[P, R]):
def annotation_import_concurrency_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""
Concurrency control decorator for annotation import operations.
@@ -455,7 +453,7 @@ def _decrypt_field(field_name: str, error_class: type[Exception], error_message:
payload[field_name] = decoded_value
def decrypt_password_field(view: Callable[P, R]):
def decrypt_password_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""
Decorator to decrypt password field in request payload.
@@ -477,7 +475,7 @@ def decrypt_password_field(view: Callable[P, R]):
return decorated
def decrypt_code_field(view: Callable[P, R]):
def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""
Decorator to decrypt verification code field in request payload.

View File

@@ -1,6 +1,5 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import current_app, request
from flask_login import user_logged_in
@@ -13,9 +12,6 @@ from libs.login import current_user
from models.account import Tenant
from models.model import DefaultEndUserSessionID, EndUser
P = ParamSpec("P")
R = TypeVar("R")
class TenantUserPayload(BaseModel):
tenant_id: str
@@ -65,9 +61,9 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
return user_model
def get_user_tenant(view_func: Callable[P, R]):
def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
user_id = payload.user_id
@@ -97,10 +93,14 @@ def get_user_tenant(view_func: Callable[P, R]):
return decorated_view
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
def decorator(view_func: Callable[P, R]):
def plugin_data[**P, R](
view: Callable[P, R] | None = None,
*,
payload_type: type[BaseModel],
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
try:
data = request.get_json()
except Exception:

View File

@@ -3,10 +3,7 @@ from collections.abc import Callable
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
from flask import abort, request
from configs import dify_config
@@ -14,9 +11,9 @@ from extensions.ext_database import db
from models.model import EndUser
def billing_inner_api_only(view: Callable[P, R]):
def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.INNER_API:
abort(404)
@@ -30,9 +27,9 @@ def billing_inner_api_only(view: Callable[P, R]):
return decorated
def enterprise_inner_api_only(view: Callable[P, R]):
def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.INNER_API:
abort(404)
@@ -46,9 +43,9 @@ def enterprise_inner_api_only(view: Callable[P, R]):
return decorated
def enterprise_inner_api_user_auth(view: Callable[P, R]):
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.INNER_API:
return view(*args, **kwargs)
@@ -82,9 +79,9 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]):
return decorated
def plugin_inner_api_only(view: Callable[P, R]):
def plugin_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.PLUGIN_DAEMON_KEY:
abort(404)

View File

@@ -3,7 +3,7 @@ import time
from collections.abc import Callable
from enum import StrEnum, auto
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast, overload
from typing import Any, cast, overload
from flask import current_app, request
from flask_login import user_logged_in
@@ -23,10 +23,6 @@ from services.api_token_service import ApiTokenCache, fetch_token_with_single_fl
from services.end_user_service import EndUserService
from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
logger = logging.getLogger(__name__)
@@ -46,16 +42,16 @@ class FetchUserArg(BaseModel):
@overload
def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ...
def validate_app_token[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
@overload
def validate_app_token(
def validate_app_token[**P, R](
view: None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def validate_app_token(
def validate_app_token[**P, R](
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@@ -136,7 +132,10 @@ def validate_app_token(
return decorator(view)
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def cloud_edition_billing_resource_check[**P, R](
resource: str,
api_token_type: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]):
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type)
@@ -166,7 +165,10 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
return interceptor
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
def cloud_edition_billing_knowledge_limit_check[**P, R](
resource: str,
api_token_type: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -188,7 +190,10 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
return interceptor
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
def cloud_edition_billing_rate_limit_check[**P, R](
resource: str,
api_token_type: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -225,20 +230,12 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
return interceptor
@overload
def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ...
@overload
def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ...
def validate_dataset_token(
view: Callable[Concatenate[T, P], R] | None = None,
) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]:
def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]:
view: Callable[..., Any] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(view_func)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: Any, **kwargs: Any) -> Any:
api_token = validate_and_get_api_token("dataset")
# get url path dataset_id from positional args or kwargs
@@ -308,7 +305,10 @@ def validate_dataset_token(
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
return view_func(api_token.tenant_id, *args, **kwargs) # type: ignore[arg-type]
if args and isinstance(args[0], Resource):
return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs)
return view_func(api_token.tenant_id, *args, **kwargs)
return decorated

View File

@@ -1,7 +1,7 @@
from collections.abc import Callable
from datetime import UTC, datetime
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from typing import Concatenate
from flask import request
from flask_restx import Resource
@@ -20,14 +20,13 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
P = ParamSpec("P")
R = TypeVar("R")
def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, EndUser, P], R]):
def validate_jwt_token[**P, R](
view: Callable[Concatenate[App, EndUser, P], R] | None = None,
) -> Callable[P, R] | Callable[[Callable[Concatenate[App, EndUser, P], R]], Callable[P, R]]:
def decorator(view: Callable[Concatenate[App, EndUser, P], R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
app_model, end_user = decode_jwt_token()
return view(app_model, end_user, *args, **kwargs)
@@ -38,7 +37,7 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None =
return decorator
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) -> tuple[App, EndUser]:
system_features = FeatureService.get_system_features()
if not app_code:
app_code = str(request.headers.get(HEADER_NAME_APP_CODE))

View File

@@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from typing import TYPE_CHECKING, Any, Literal, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -68,7 +68,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
@@ -81,7 +81,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
@@ -94,7 +94,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
@@ -106,7 +106,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
@@ -239,7 +239,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
conversation: Conversation,
message: Message,
application_generate_entity: AdvancedChatAppGenerateEntity,
@@ -271,9 +271,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
@@ -359,7 +359,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Account | EndUser,
args: LoopNodeRunPayload,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
@@ -439,7 +439,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
*,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
@@ -451,7 +451,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
pause_state_config: PauseStateLayerConfig | None = None,
graph_runtime_state: GraphRuntimeState | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
@@ -653,10 +653,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager,
conversation: ConversationSnapshot,
message: MessageSnapshot,
user: Union[Account, EndUser],
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from typing import Any, Literal, overload
from flask import Flask, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -37,7 +37,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
self,
*,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@@ -48,7 +48,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
self,
*,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@@ -59,21 +59,21 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
self,
*,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
) -> Mapping | Generator[Mapping | str, None, None]: ...
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
) -> Mapping | Generator[Mapping | str, None, None]:
"""
Generate App response.

View File

@@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from typing import Any, Literal, overload
from flask import Flask, copy_current_request_context, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -36,7 +36,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@@ -46,7 +46,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@@ -56,20 +56,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
"""
Generate App response.

View File

@@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from typing import Any, Literal, overload
from flask import Flask, copy_current_request_context, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -36,7 +36,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@@ -46,7 +46,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@@ -56,20 +56,20 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = False,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: ...
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
@@ -244,10 +244,10 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
message_id: str,
user: Union[Account, EndUser],
user: Account | EndUser,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
) -> Mapping | Generator[Mapping | str, None, None]:
"""
Generate App response.

View File

@@ -7,7 +7,7 @@ import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, cast, overload
from typing import Any, Literal, cast, overload
from flask import Flask, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -62,7 +62,7 @@ class PipelineGenerator(BaseAppGenerator):
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@@ -77,7 +77,7 @@ class PipelineGenerator(BaseAppGenerator):
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@@ -92,28 +92,28 @@ class PipelineGenerator(BaseAppGenerator):
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
) -> Mapping[str, Any] | Generator[Mapping | str, None, None]: ...
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: str | None = None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None:
# Add null check for dataset
with Session(db.engine, expire_on_commit=False) as session:
@@ -278,7 +278,7 @@ class PipelineGenerator(BaseAppGenerator):
context: contextvars.Context,
pipeline: Pipeline,
workflow_id: str,
user: Union[Account, EndUser],
user: Account | EndUser,
application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
@@ -286,7 +286,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
workflow_thread_pool_id: str | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
@@ -624,10 +624,10 @@ class PipelineGenerator(BaseAppGenerator):
application_generate_entity: RagPipelineGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
"""
Handle response.
:param application_generate_entity: application generate entity
@@ -668,7 +668,7 @@ class PipelineGenerator(BaseAppGenerator):
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
account: Union[Account, EndUser],
account: Account | EndUser,
batch: str,
document_form: str,
):
@@ -715,7 +715,7 @@ class PipelineGenerator(BaseAppGenerator):
pipeline: Pipeline,
workflow: Workflow,
start_node_id: str,
user: Union[Account, EndUser],
user: Account | EndUser,
) -> list[Mapping[str, Any]]:
"""
Format datasource info list.

View File

@@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from typing import TYPE_CHECKING, Any, Literal, overload
from flask import Flask, current_app
from graphon.graph_engine.layers import GraphEngineLayer
@@ -64,7 +64,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@@ -82,7 +82,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@@ -100,7 +100,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
@@ -110,14 +110,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
@@ -127,7 +127,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
files: Sequence[Mapping[str, Any]] = args.get("files") or []
@@ -237,7 +237,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
application_generate_entity: WorkflowAppGenerateEntity,
graph_runtime_state: GraphRuntimeState,
workflow_execution_repository: WorkflowExecutionRepository,
@@ -245,7 +245,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Resume a paused workflow execution using the persisted runtime state.
"""
@@ -269,7 +269,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
@@ -280,7 +280,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
@@ -609,10 +609,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Annotated, Literal, Self, TypeAlias
from typing import Annotated, Literal, Self
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
@@ -27,7 +27,7 @@ class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
entity: AdvancedChatAppGenerateEntity
_GenerateEntityUnion: TypeAlias = Annotated[
type _GenerateEntityUnion = Annotated[
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
Field(discriminator="type"),
]

View File

@@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Any, Union, cast
from typing import Any, cast
from graphon.file import FileTransferMethod
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -72,14 +72,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
_application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity
_precomputed_event_type: StreamEvent | None = None
def __init__(
self,
application_generate_entity: Union[
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
],
application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
@@ -117,11 +115,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
def process(
self,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]:
) -> (
ChatbotAppBlockingResponse
| CompletionAppBlockingResponse
| Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]
):
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
@@ -136,7 +134,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
) -> ChatbotAppBlockingResponse | CompletionAppBlockingResponse:
"""
Process blocking response.
:return:
@@ -148,7 +146,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
response: ChatbotAppBlockingResponse | CompletionAppBlockingResponse
if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@@ -183,7 +181,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
) -> Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]:
"""
To stream response.
:return:

View File

@@ -5,14 +5,13 @@ This layer centralizes model-quota deduction outside node implementations.
"""
import logging
from typing import TYPE_CHECKING, cast, final
from typing import TYPE_CHECKING, cast, final, override
from graphon.enums import BuiltinNodeTypes
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
from graphon.nodes.base.node import Node
from typing_extensions import override
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available

View File

@@ -10,7 +10,7 @@ associates with the node span.
import logging
from contextvars import Token
from dataclasses import dataclass
from typing import cast, final
from typing import cast, final, override
from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.graph_engine.layers import GraphEngineLayer
@@ -18,7 +18,6 @@ from graphon.graph_events import GraphNodeEventBase
from graphon.nodes.base.node import Node
from opentelemetry import context as context_api
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
from typing_extensions import override
from configs import dify_config
from extensions.otel.parser import (

View File

@@ -44,7 +44,8 @@ class HumanInputContent(BaseModel):
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
# Keep a runtime alias here: callers and tests expect identity with HumanInputContent.
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent # noqa: UP040
__all__ = [
"ExecutionExtraContentDomainModel",

View File

@@ -2,12 +2,13 @@ import importlib.util
import logging
import sys
from types import ModuleType
from typing import AnyStr
logger = logging.getLogger(__name__)
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
def import_module_from_source[T: (str, bytes)](
*, module_name: str, py_file_path: T, use_lazy_loader: bool = False
) -> ModuleType:
"""
Importing a module from the source file directly
"""

View File

@@ -2,7 +2,6 @@ import os
from collections import OrderedDict
from collections.abc import Callable
from functools import lru_cache
from typing import TypeVar
from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file_cached
@@ -65,10 +64,7 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str])
return position_map
T = TypeVar("T")
def is_filtered(
def is_filtered[T](
include_set: set[str],
exclude_set: set[str],
data: T,
@@ -97,11 +93,11 @@ def is_filtered(
return False
def sort_by_position_map(
def sort_by_position_map[T](
position_map: dict[str, int],
data: list[T],
name_func: Callable[[T], str],
):
) -> list[T]:
"""
Sort the objects by the position map.
If the name of the object is not in the position map, it will be put at the end.
@@ -116,11 +112,11 @@ def sort_by_position_map(
return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf")))
def sort_to_dict_by_position_map(
def sort_to_dict_by_position_map[T](
position_map: dict[str, int],
data: list[T],
name_func: Callable[[T], str],
):
) -> OrderedDict[str, T]:
"""
Sort the objects into a ordered dict by the position map.
If the name of the object is not in the position map, it will be put at the end.

View File

@@ -4,7 +4,7 @@ Proxy requests to avoid SSRF
import logging
import time
from typing import Any, TypeAlias
from typing import Any
import httpx
from pydantic import TypeAdapter, ValidationError
@@ -20,8 +20,8 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
Headers: TypeAlias = dict[str, str]
_HEADERS_ADAPTER = TypeAdapter(Headers)
type Headers = dict[str, str]
_HEADERS_ADAPTER: TypeAdapter[Headers] = TypeAdapter(Headers)
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"

View File

@@ -3,7 +3,7 @@ import queue
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, TypeAlias, final
from typing import Any, final
from urllib.parse import urljoin, urlparse
import httpx
@@ -33,9 +33,9 @@ class _StatusError:
# Type aliases for better readability
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
type ReadQueue = queue.Queue[SessionMessage | Exception | None]
type WriteQueue = queue.Queue[SessionMessage | Exception | None]
type StatusQueue = queue.Queue[_StatusReady | _StatusError]
class SSETransport:

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Generic, TypeVar
from typing import Any, TypeVar
from pydantic import BaseModel
@@ -9,13 +9,12 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
@dataclass
class RequestContext(Generic[SessionT, LifespanContextT]):
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT

View File

@@ -4,7 +4,7 @@ from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, Self, TypeVar
from typing import Any, Self, cast
from httpx import HTTPStatusError
from pydantic import BaseModel
@@ -34,16 +34,10 @@ from core.mcp.types import (
logger = logging.getLogger(__name__)
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResultT: ClientResult | ServerResult]:
"""Handles responding to MCP requests and manages request lifecycle.
This class MUST be used as a context manager to ensure proper cleanup and
@@ -60,7 +54,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""
request: ReceiveRequestT
_session: Any
_session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
def __init__(
@@ -68,7 +62,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
):
self.request_id = request_id
@@ -111,7 +105,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self.completed = True
self._session._send_response(request_id=self.request_id, response=response)
self._session.send_response(request_id=self.request_id, response=response)
def cancel(self):
"""Cancel this request and mark it as completed."""
@@ -120,21 +114,19 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self.completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
self._session._send_response(
self._session.send_response(
request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None),
)
class BaseSession(
Generic[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
],
):
class BaseSession[
SendRequestT: ClientRequest | ServerRequest,
SendNotificationT: ClientNotification | ServerNotification,
SendResultT: ClientResult | ServerResult,
ReceiveRequestT: ClientRequest | ServerRequest,
ReceiveNotificationT: ClientNotification | ServerNotification,
]:
"""
Implements an MCP "session" on top of read/write streams, including features
like request/response linking, notifications, and progress.
@@ -204,13 +196,13 @@ class BaseSession(
# The receiver thread should have already exited due to the None message in the queue
self._executor.shutdown(wait=False)
def send_request(
def send_request[T: BaseModel](
self,
request: SendRequestT,
result_type: type[ReceiveResultT],
result_type: type[T],
request_read_timeout_seconds: timedelta | None = None,
metadata: MessageMetadata | None = None,
) -> ReceiveResultT:
) -> T:
"""
Sends a request and wait for a response. Raises an McpError if the
response contains an error. If a request read timeout is provided, it
@@ -299,7 +291,7 @@ class BaseSession(
)
self._write_stream.put(session_message)
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
def send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
@@ -346,6 +338,7 @@ class BaseSession(
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
validated_request = cast(ReceiveRequestT, validated_request)
responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=message.message.root.id,
@@ -366,6 +359,7 @@ class BaseSession(
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
notification = cast(ReceiveNotificationT, notification)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable
from dataclasses import dataclass
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
from typing import Annotated, Any, Literal
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
from pydantic.networks import AnyUrl, UrlConstraints
@@ -31,7 +31,7 @@ ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
AnyFunction: TypeAlias = Callable[..., Any]
type AnyFunction = Callable[..., Any]
class RequestParams(BaseModel):
@@ -68,12 +68,7 @@ class NotificationParams(BaseModel):
"""
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None)
MethodT = TypeVar("MethodT", bound=str)
class Request(BaseModel, Generic[RequestParamsT, MethodT]):
class Request[RequestParamsT: RequestParams | dict[str, Any] | None, MethodT: str](BaseModel):
"""Base class for JSON-RPC requests."""
method: MethodT
@@ -81,14 +76,14 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow")
class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
class PaginatedRequest[T: str](Request[PaginatedRequestParams | None, T]):
"""Base class for paginated requests,
matching the schema's PaginatedRequest interface."""
params: PaginatedRequestParams | None = None
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
class Notification[NotificationParamsT: NotificationParams | dict[str, Any] | None, MethodT: str](BaseModel):
"""Base class for JSON-RPC notifications."""
method: MethodT
@@ -736,7 +731,7 @@ class ResourceLink(Resource):
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
"""A content block that can be used in prompts and tool results."""
Content: TypeAlias = ContentBlock
type Content = ContentBlock
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""

View File

@@ -6,7 +6,7 @@ import queue
import threading
import time
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, TypedDict
from uuid import UUID, uuid4
from cachetools import LRUCache
@@ -14,7 +14,6 @@ from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import TypedDict
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
@@ -464,7 +463,7 @@ class OpsTraceManager:
@classmethod
def get_ops_trace_instance(
cls,
app_id: Union[UUID, str] | None = None,
app_id: UUID | str | None = None,
):
"""
Get ops trace through model config
@@ -717,7 +716,7 @@ class TraceTask:
self,
trace_type: Any,
message_id: str | None = None,
workflow_execution: Optional["WorkflowExecution"] = None,
workflow_execution: "WorkflowExecution | None" = None,
conversation_id: str | None = None,
user_id: str | None = None,
timer: Any | None = None,

View File

@@ -1,5 +1,4 @@
from collections.abc import Generator, Mapping
from typing import Generic, TypeVar
from pydantic import BaseModel
@@ -19,9 +18,6 @@ class BaseBackwardsInvocation:
yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode()
T = TypeVar("T", bound=dict | Mapping | str | bool | int | BaseModel)
class BaseBackwardsInvocationResponse(BaseModel, Generic[T]):
class BaseBackwardsInvocationResponse[T: dict | Mapping | str | bool | int | BaseModel](BaseModel):
data: T | None = None
error: str = ""

View File

@@ -4,7 +4,7 @@ import enum
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import Any, Generic, TypeVar
from typing import Any
from graphon.model_runtime.entities.model_entities import AIModelEntity
from graphon.model_runtime.entities.provider_entities import ProviderEntity
@@ -19,10 +19,8 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
from core.trigger.entities.entities import TriggerProviderEntity
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
class PluginDaemonBasicResponse(BaseModel, Generic[T]):
class PluginDaemonBasicResponse[T: BaseModel | dict | list | bool | str](BaseModel):
"""
Basic response from plugin daemon.
"""

View File

@@ -2,7 +2,7 @@ import inspect
import json
import logging
from collections.abc import Callable, Generator
from typing import Any, TypeVar, cast
from typing import Any, cast
import httpx
from graphon.model_runtime.errors.invoke import (
@@ -51,8 +51,6 @@ elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
else:
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
T = TypeVar("T", bound=(BaseModel | dict[str, Any] | list[Any] | bool | str))
logger = logging.getLogger(__name__)
_httpx_client: httpx.Client = get_pooled_http_client(
@@ -191,7 +189,7 @@ class BasePluginClient:
logger.exception("Stream request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
def _stream_request_with_model(
def _stream_request_with_model[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
self,
method: str,
path: str,
@@ -207,7 +205,7 @@ class BasePluginClient:
for line in self._stream_request(method, path, params, headers, data, files):
yield type_(**json.loads(line)) # type: ignore
def _request_with_model(
def _request_with_model[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
self,
method: str,
path: str,
@@ -223,7 +221,7 @@ class BasePluginClient:
response = self._request(method, path, headers, data, params, files)
return type_(**response.json()) # type: ignore[return-value]
def _request_with_plugin_daemon_response(
def _request_with_plugin_daemon_response[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
self,
method: str,
path: str,
@@ -278,7 +276,7 @@ class BasePluginClient:
return rep.data
def _request_with_plugin_daemon_response_stream(
def _request_with_plugin_daemon_response_stream[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
self,
method: str,
path: str,

View File

@@ -1,12 +1,9 @@
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import TypeVar, Union
from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage])
@dataclass
class FileChunk:
@@ -22,11 +19,11 @@ class FileChunk:
self.data = bytearray(self.total_length)
def merge_blob_chunks(
response: Generator[MessageType, None, None],
def merge_blob_chunks[T: ToolInvokeMessage | AgentInvokeMessage](
response: Generator[T, None, None],
max_file_size: int = 30 * 1024 * 1024,
max_chunk_size: int = 8192,
) -> Generator[MessageType, None, None]:
) -> Generator[T, None, None]:
"""
Merge streaming blob chunks into complete blob messages.

View File

@@ -1,6 +1,7 @@
from typing import TypedDict
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
from typing_extensions import TypedDict
from core.model_manager import ModelInstance, ModelManager
from core.rag.data_post_processor.reorder import ReorderRunner

View File

@@ -1,10 +1,9 @@
from collections import defaultdict
from typing import Any
from typing import Any, TypedDict
import orjson
from pydantic import BaseModel
from sqlalchemy import select
from typing_extensions import TypedDict
from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler

View File

@@ -1,13 +1,12 @@
import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, NotRequired
from typing import Any, NotRequired, TypedDict
from flask import Flask, current_app
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from typing_extensions import TypedDict
from configs import dify_config
from core.db.session_factory import session_factory

View File

@@ -3,7 +3,7 @@ import logging
import uuid
from collections.abc import Callable
from functools import wraps
from typing import Any, Concatenate, ParamSpec, TypeVar
from typing import Any, Concatenate
from mo_vector.client import MoVectorClient # type: ignore
from pydantic import BaseModel, model_validator
@@ -20,15 +20,12 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T", bound="MatrixoneVector")
def ensure_client(func: Callable[Concatenate[T, P], R]):
def ensure_client[T: MatrixoneVector, **P, R](
func: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(func)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)

View File

@@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any
import qdrant_client
from flask import current_app
@@ -36,8 +36,8 @@ if TYPE_CHECKING:
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
type DictFilter = dict[str, str | int | bool | dict | list]
type MetadataFilter = DictFilter | common_types.Filter
class PathQdrantParams(BaseModel):

View File

@@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any
import httpx
import qdrant_client
@@ -40,8 +40,8 @@ if TYPE_CHECKING:
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
type DictFilter = dict[str, str | int | bool | dict | list]
type MetadataFilter = DictFilter | common_types.Filter
class TidbOnQdrantConfig(BaseModel):

View File

@@ -1,5 +1,6 @@
from typing import TypedDict
from pydantic import BaseModel
from typing_extensions import TypedDict
from models.dataset import DocumentSegment

View File

@@ -12,11 +12,11 @@ import mimetypes
from collections.abc import Generator, Mapping
from io import BufferedReader, BytesIO
from pathlib import Path, PurePath
from typing import Any, Union
from typing import Any
from pydantic import BaseModel, ConfigDict, model_validator
PathLike = Union[str, PurePath]
type PathLike = str | PurePath
class Blob(BaseModel):
@@ -29,7 +29,7 @@ class Blob(BaseModel):
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
"""
data: Union[bytes, str, None] = None # Raw data
data: bytes | str | None = None # Raw data
mimetype: str | None = None # Not to be confused with a file extension
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
# Location where the original content was found
@@ -75,7 +75,7 @@ class Blob(BaseModel):
raise ValueError(f"Unable to get bytes for blob {self}")
@contextlib.contextmanager
def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]:
def as_bytes_io(self) -> Generator[BytesIO | BufferedReader, None, None]:
"""Read data as a byte stream."""
if isinstance(self.data, bytes):
yield BytesIO(self.data)
@@ -117,7 +117,7 @@ class Blob(BaseModel):
@classmethod
def from_data(
cls,
data: Union[str, bytes],
data: str | bytes,
*,
encoding: str = "utf-8",
mime_type: str | None = None,

View File

@@ -1,9 +1,8 @@
import json
import time
from typing import Any, NotRequired, cast
from typing import Any, NotRequired, TypedDict, cast
import httpx
from typing_extensions import TypedDict
from extensions.ext_storage import storage

View File

@@ -1,11 +1,10 @@
import json
from collections.abc import Generator
from typing import Any, Union
from typing import Any, TypedDict
from urllib.parse import urljoin
import httpx
from httpx import Response
from typing_extensions import TypedDict
from core.rag.extractor.watercrawl.exceptions import (
WaterCrawlAuthenticationError,
@@ -142,7 +141,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
def create_crawl_request(
self,
url: Union[list, str] | None = None,
url: list | str | None = None,
spider_options: SpiderOptions | None = None,
page_options: PageOptions | None = None,
plugin_options: dict[str, Any] | None = None,

View File

@@ -1,8 +1,6 @@
from collections.abc import Generator
from datetime import datetime
from typing import Any
from typing_extensions import TypedDict
from typing import Any, TypedDict
from core.rag.extractor.watercrawl.client import PageOptions, SpiderOptions, WaterCrawlAPIClient

View File

@@ -7,12 +7,11 @@ import os
import re
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, NotRequired, Optional
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
from urllib.parse import unquote, urlparse
import httpx
from sqlalchemy import select
from typing_extensions import TypedDict
from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
@@ -118,11 +117,12 @@ class BaseIndexProcessor(ABC):
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional["ModelInstance"],
embedding_model_instance: "ModelInstance | None",
) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule_mode in ["custom", "hierarchical"]:
# The user-defined segmentation rule
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
@@ -148,7 +148,7 @@ class BaseIndexProcessor(ABC):
embedding_model_instance=embedding_model_instance,
)
return character_splitter # type: ignore
return character_splitter
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
"""

View File

@@ -4,19 +4,13 @@ from __future__ import annotations
import codecs
import re
from typing import Any
from collections.abc import Collection
from typing import Any, Literal
from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
from core.model_manager import ModelInstance
from core.rag.splitter.text_splitter import (
TS,
Collection,
Literal,
RecursiveCharacterTextSplitter,
Set,
Union,
)
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
@@ -25,13 +19,13 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
"""
@classmethod
def from_encoder(
cls: type[TS],
def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
cls: type[T],
embedding_model_instance: ModelInstance | None,
allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
allowed_special: Literal["all"] | set[str] = set(),
disallowed_special: Literal["all"] | Collection[str] = "all",
**kwargs: Any,
):
) -> T:
def _token_encoder(texts: list[str]) -> list[int]:
if not texts:
return []

View File

@@ -6,19 +6,12 @@ import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from typing import (
Any,
Literal,
TypeVar,
Union,
)
from typing import Any, Literal
from core.rag.models.document import BaseDocumentTransformer, Document
logger = logging.getLogger(__name__)
TS = TypeVar("TS", bound="TextSplitter")
def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
# Now that we have the separator, split the text
@@ -194,8 +187,8 @@ class TokenTextSplitter(TextSplitter):
self,
encoding_name: str = "gpt2",
model_name: str | None = None,
allowed_special: Union[Literal["all"], Set[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
allowed_special: Literal["all"] | Set[str] = set(),
disallowed_special: Literal["all"] | Collection[str] = "all",
**kwargs: Any,
):
"""Create a new TextSplitter."""

View File

@@ -6,7 +6,6 @@ providing improved performance by offloading database operations to background w
"""
import logging
from typing import Union
from graphon.entities import WorkflowExecution
from sqlalchemy.engine import Engine
@@ -47,7 +46,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
user: Account | EndUser,
app_id: str | None,
triggered_from: WorkflowRunTriggeredFrom | None,
):

View File

@@ -7,7 +7,6 @@ providing improved performance by offloading database operations to background w
import logging
from collections.abc import Sequence
from typing import Union
from graphon.entities import WorkflowNodeExecution
from sqlalchemy.engine import Engine
@@ -54,7 +53,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
user: Account | EndUser,
app_id: str | None,
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
):

View File

@@ -7,7 +7,7 @@ allowing users to configure different repository backends through string paths.
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Protocol, Union
from typing import Literal, Protocol
from graphon.entities import WorkflowExecution, WorkflowNodeExecution
from sqlalchemy.engine import Engine
@@ -61,8 +61,8 @@ class DifyCoreRepositoryFactory:
@classmethod
def create_workflow_execution_repository(
cls,
session_factory: Union[sessionmaker, Engine],
user: Union[Account, EndUser],
session_factory: sessionmaker | Engine,
user: Account | EndUser,
app_id: str,
triggered_from: WorkflowRunTriggeredFrom,
) -> WorkflowExecutionRepository:
@@ -97,8 +97,8 @@ class DifyCoreRepositoryFactory:
@classmethod
def create_workflow_node_execution_repository(
cls,
session_factory: Union[sessionmaker, Engine],
user: Union[Account, EndUser],
session_factory: sessionmaker | Engine,
user: Account | EndUser,
app_id: str,
triggered_from: WorkflowNodeExecutionTriggeredFrom,
) -> WorkflowNodeExecutionRepository:

View File

@@ -4,7 +4,6 @@ SQLAlchemy implementation of the WorkflowExecutionRepository.
import json
import logging
from typing import Union
from graphon.entities import WorkflowExecution
from graphon.enums import WorkflowExecutionStatus, WorkflowType
@@ -40,7 +39,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
user: Account | EndUser,
app_id: str | None,
triggered_from: WorkflowRunTriggeredFrom | None,
):

View File

@@ -7,7 +7,7 @@ import json
import logging
from collections.abc import Callable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any, TypeVar, Union
from typing import Any
import psycopg2.errors
from graphon.entities import WorkflowNodeExecution
@@ -63,7 +63,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
user: Account | EndUser,
app_id: str | None,
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
):
@@ -551,10 +551,7 @@ def _deterministic_json_dump(value: Mapping[str, Any]) -> str:
return json.dumps(value, sort_keys=True)
_T = TypeVar("_T")
def _find_first(seq: Sequence[_T], pred: Callable[[_T], bool]) -> _T | None:
def _find_first[T](seq: Sequence[T], pred: Callable[[T], bool]) -> T | None:
filtered = [i for i in seq if pred(i)]
if filtered:
return filtered[0]

View File

@@ -3,15 +3,15 @@ import re
import threading
from collections import deque
from dataclasses import dataclass
from typing import Any, Union
from typing import Any
from core.schemas.registry import SchemaRegistry
logger = logging.getLogger(__name__)
# Type aliases for better clarity
SchemaType = Union[dict[str, Any], list[Any], str, int, float, bool, None]
SchemaDict = dict[str, Any]
type SchemaType = dict[str, Any] | list[Any] | str | int | float | bool | None
type SchemaDict = dict[str, Any]
# Pre-compiled pattern for better performance
_DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$")
@@ -54,7 +54,7 @@ class QueueItem:
current: Any
parent: Any | None
key: Union[str, int] | None
key: str | int | None
depth: int
ref_path: set[str]

View File

@@ -1,6 +1,5 @@
import hashlib
import logging
from typing import TypeVar
from redis import RedisError
@@ -11,8 +10,6 @@ logger = logging.getLogger(__name__)
TRIGGER_DEBUG_EVENT_TTL = 300
TTriggerDebugEvent = TypeVar("TTriggerDebugEvent", bound="BaseDebugEvent")
class TriggerDebugEventBus:
"""
@@ -81,15 +78,15 @@ class TriggerDebugEventBus:
return 0
@classmethod
def poll(
def poll[T: BaseDebugEvent](
cls,
event_type: type[TTriggerDebugEvent],
event_type: type[T],
pool_key: str,
tenant_id: str,
user_id: str,
app_id: str,
node_id: str,
) -> TTriggerDebugEvent | None:
) -> T | None:
"""
Poll for an event or register to the waiting pool.

View File

@@ -2,7 +2,7 @@ import importlib
import pkgutil
from collections.abc import Callable, Iterator, Mapping, MutableMapping
from functools import lru_cache
from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
from typing import TYPE_CHECKING, Any, cast, final, override
from graphon.entities.base_node_data import BaseNodeData
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
@@ -22,7 +22,6 @@ from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeDat
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
@@ -192,7 +191,7 @@ class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Nod
NODE_TYPE_CLASSES_MAPPING: MutableMapping[NodeType, Mapping[str, type[Node]]] = _LazyNodeTypeClassesMapping()
LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
type LLMCompatibleNodeData = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
def fetch_memory(

View File

@@ -3,7 +3,7 @@ import logging
import ssl
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union
from typing import TYPE_CHECKING, Any, Union
import redis
from redis import RedisError
@@ -297,12 +297,7 @@ def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
return RedisBroadcastChannel(_pubsub_redis_client)
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def redis_fallback(default_return: T | None = None): # type: ignore
def redis_fallback[T](default_return: T | None = None): # type: ignore
"""
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
@@ -310,9 +305,9 @@ def redis_fallback(default_return: T | None = None): # type: ignore
default_return: The value to return when a Redis operation fails. Defaults to None.
"""
def decorator(func: Callable[P, R]):
def decorator[**P, R](func: Callable[P, R]) -> Callable[P, R | T | None]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | T | None:
try:
return func(*args, **kwargs)
except RedisError as e:

View File

@@ -2,7 +2,6 @@ import json
import logging
import os
import time
from typing import Union
from graphon.entities import WorkflowExecution
from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
@@ -27,7 +26,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
user: Account | EndUser,
app_id: str | None,
triggered_from: WorkflowRunTriggeredFrom | None,
):

View File

@@ -11,7 +11,7 @@ import os
import time
from collections.abc import Sequence
from datetime import datetime
from typing import Any, Union
from typing import Any
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -109,7 +109,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
user: Account | EndUser,
app_id: str | None,
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
):

View File

@@ -1,6 +1,6 @@
import functools
from collections.abc import Callable
from typing import ParamSpec, TypeVar, cast
from typing import cast
from opentelemetry.trace import get_tracer
@@ -8,9 +8,6 @@ from configs import dify_config
from extensions.otel.decorators.handler import SpanHandler
from extensions.otel.runtime import is_instrument_flag_enabled
P = ParamSpec("P")
R = TypeVar("R")
_HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()}
@@ -21,7 +18,7 @@ def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler:
return _HANDLER_INSTANCES[handler_class]
def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[Callable[P, R]], Callable[P, R]]:
def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator that traces a function with an OpenTelemetry span.

View File

@@ -1,11 +1,9 @@
import inspect
from collections.abc import Callable, Mapping
from typing import Any, TypeVar
from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode
R = TypeVar("R")
class SpanHandler:
"""
@@ -31,9 +29,9 @@ class SpanHandler:
"""
return f"{wrapped.__module__}.{wrapped.__qualname__}"
def _extract_arguments(
def _extract_arguments[T](
self,
wrapped: Callable[..., R],
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> dict[str, Any] | None:
@@ -61,13 +59,13 @@ class SpanHandler:
except Exception:
return None
def wrapper(
def wrapper[T](
self,
tracer: Any,
wrapped: Callable[..., R],
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> R:
) -> T:
"""
Fully control the wrapper behavior.

View File

@@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any, TypeVar
from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.util.types import AttributeValue
@@ -12,19 +12,16 @@ from models.model import Account
logger = logging.getLogger(__name__)
R = TypeVar("R")
class AppGenerateHandler(SpanHandler):
"""Span handler for ``AppGenerateService.generate``."""
def wrapper(
def wrapper[T](
self,
tracer: Any,
wrapped: Callable[..., R],
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> R:
) -> T:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
if not arguments:

View File

@@ -1,12 +1,12 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, TypeAlias
from typing import Any
from graphon.file import File
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
JSONValue: TypeAlias = Any
type JSONValue = Any
class ResponseModel(BaseModel):

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from datetime import datetime
from typing import TypeAlias
from uuid import uuid4
from graphon.file import File
@@ -10,7 +9,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
JSONValueType: TypeAlias = JSONValue
type JSONValueType = JSONValue
class ResponseModel(BaseModel):

View File

@@ -1,12 +1,10 @@
import contextvars
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING
from flask import Flask, g
T = TypeVar("T")
if TYPE_CHECKING:
from models import Account, EndUser

View File

@@ -42,13 +42,7 @@ def current_account_with_tenant():
return user, user.current_tenant_id
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
"""
If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are

View File

@@ -1,26 +1,20 @@
import logging
import sys
import urllib.parse
from dataclasses import dataclass
from typing import NotRequired
from typing import NotRequired, TypedDict
import httpx
from pydantic import TypeAdapter, ValidationError
from core.helper.http_client_pooling import get_pooled_http_client
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
logger = logging.getLogger(__name__)
JsonObject = dict[str, object]
JsonObjectList = list[JsonObject]
type JsonObject = dict[str, object]
type JsonObjectList = list[JsonObject]
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
JSON_OBJECT_ADAPTER: TypeAdapter[JsonObject] = TypeAdapter(JsonObject)
JSON_OBJECT_LIST_ADAPTER: TypeAdapter[JsonObjectList] = TypeAdapter(JsonObjectList)
# Reuse a pooled httpx.Client for OAuth flows (public endpoints, no SSRF proxy).
_http_client: httpx.Client = get_pooled_http_client(

View File

@@ -1,6 +1,5 @@
import sys
import urllib.parse
from typing import Any, Literal
from typing import Any, Literal, TypedDict
import httpx
from flask_login import current_user
@@ -12,11 +11,6 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
class NotionPageSummary(TypedDict):
page_id: str

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, cast
from uuid import uuid4
import sqlalchemy as sa
@@ -19,7 +19,6 @@ from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from graphon.file import helpers as file_helpers
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import TypedDict
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS

View File

@@ -1,6 +1,6 @@
import enum
import uuid
from typing import Any, Generic, TypeVar
from typing import Any
import sqlalchemy as sa
from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator
@@ -110,17 +110,14 @@ class AdjustedJSON(TypeDecorator[dict | list | None]):
return value
_E = TypeVar("_E", bound=enum.StrEnum)
class EnumText(TypeDecorator[_E | None], Generic[_E]):
class EnumText[T: enum.StrEnum](TypeDecorator[T | None]):
impl = VARCHAR
cache_ok = True
_length: int
_enum_class: type[_E]
_enum_class: type[T]
def __init__(self, enum_class: type[_E], length: int | None = None):
def __init__(self, enum_class: type[T], length: int | None = None):
self._enum_class = enum_class
max_enum_value_len = max(len(e.value) for e in enum_class)
if length is not None:
@@ -131,25 +128,25 @@ class EnumText(TypeDecorator[_E | None], Generic[_E]):
# leave some rooms for future longer enum values.
self._length = max(max_enum_value_len, 20)
def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None:
def process_bind_param(self, value: T | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
if isinstance(value, self._enum_class):
return value.value
# Since _E is bound to StrEnum which inherits from str, at this point value must be str
# Since T is bound to StrEnum which inherits from str, at this point value must be str
self._enum_class(value)
return value
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
return dialect.type_descriptor(VARCHAR(self._length))
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
def process_result_value(self, value: str | None, dialect: Dialect) -> T | None:
if value is None or value == "":
return None
# Type annotation guarantees value is str at this point
return self._enum_class(value)
def compare_values(self, x: _E | None, y: _E | None) -> bool:
def compare_values(self, x: T | None, y: T | None) -> bool:
if x is None or y is None:
return x is y
return x == y

View File

@@ -1,7 +1,7 @@
[project]
name = "dify-api"
version = "1.13.3"
requires-python = ">=3.11,<3.13"
requires-python = "~=3.12.0"
dependencies = [
"aliyun-log-python-sdk~=0.9.37",
@@ -232,5 +232,5 @@ vdb = [
project-includes = ["."]
project-excludes = [".venv", "migrations/"]
python-platform = "linux"
python-version = "3.11.0"
python-version = "3.12.0"
infer-with-first-use = false

View File

@@ -50,6 +50,6 @@
"reportUntypedFunctionDecorator": "hint",
"reportUnnecessaryTypeIgnoreComment": "hint",
"reportAttributeAccessIssue": "hint",
"pythonVersion": "3.11",
"pythonVersion": "3.12",
"pythonPlatform": "All"
}
}

View File

@@ -5,12 +5,11 @@ import secrets
import uuid
from datetime import UTC, datetime, timedelta
from hashlib import sha256
from typing import Any, cast
from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from typing_extensions import TypedDict
class InvitationData(TypedDict):

View File

@@ -1,7 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing_extensions import TypedDict
from typing import Any, TypedDict
class AuthCredentials(TypedDict):

View File

@@ -2,13 +2,12 @@ import json
import logging
import os
from collections.abc import Sequence
from typing import Literal
from typing import Literal, TypedDict
import httpx
from pydantic import TypeAdapter
from sqlalchemy import select
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from typing_extensions import TypedDict
from werkzeug.exceptions import InternalServerError
from core.helper.http_client_pooling import get_pooled_http_client

View File

@@ -1,7 +1,7 @@
import contextlib
import logging
from collections.abc import Callable, Sequence
from typing import Any, Union
from typing import Any
from graphon.variables.types import SegmentType
from sqlalchemy import asc, desc, func, or_, select
@@ -37,7 +37,7 @@ class ConversationService:
*,
session: Session,
app_model: App,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
last_id: str | None,
limit: int,
invoke_from: InvokeFrom,
@@ -119,7 +119,7 @@ class ConversationService:
cls,
app_model: App,
conversation_id: str,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
name: str | None,
auto_generate: bool,
):
@@ -159,7 +159,7 @@ class ConversationService:
return conversation
@classmethod
def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
def get_conversation(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
conversation = db.session.scalar(
select(Conversation)
.where(
@@ -179,7 +179,7 @@ class ConversationService:
return conversation
@classmethod
def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
def delete(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
"""
Delete a conversation only if it belongs to the given user and app context.
@@ -209,7 +209,7 @@ class ConversationService:
cls,
app_model: App,
conversation_id: str,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
limit: int,
last_id: str | None,
variable_name: str | None = None,
@@ -278,7 +278,7 @@ class ConversationService:
app_model: App,
conversation_id: str,
variable_id: str,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
new_value: Any,
):
"""

View File

@@ -1,5 +1,4 @@
from collections.abc import Sequence
from typing import Union
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import TypeAdapter
@@ -57,7 +56,7 @@ class MessageService:
def pagination_by_first_id(
cls,
app_model: App,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
conversation_id: str,
first_id: str | None,
limit: int,
@@ -117,7 +116,7 @@ class MessageService:
def pagination_by_last_id(
cls,
app_model: App,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
last_id: str | None,
limit: int,
conversation_id: str | None = None,
@@ -170,7 +169,7 @@ class MessageService:
*,
app_model: App,
message_id: str,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
rating: FeedbackRating | None,
content: str | None,
):
@@ -221,7 +220,7 @@ class MessageService:
return [record.to_dict() for record in feedbacks]
@classmethod
def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
def get_message(cls, app_model: App, user: Account | EndUser | None, message_id: str):
message = db.session.scalar(
select(Message)
.where(
@@ -241,7 +240,7 @@ class MessageService:
@classmethod
def get_suggested_questions_after_answer(
cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str, invoke_from: InvokeFrom
cls, app_model: App, user: Account | EndUser | None, message_id: str, invoke_from: InvokeFrom
) -> list[str]:
if not user:
raise ValueError("user cannot be None")

View File

@@ -5,7 +5,7 @@ import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any
from typing import Any, TypedDict
from uuid import uuid4
import click
@@ -14,7 +14,6 @@ import tqdm
from flask import Flask, current_app
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from typing_extensions import TypedDict
from core.agent.entities import AgentToolEntity
from core.helper import marketplace

View File

@@ -13,13 +13,12 @@ from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from typing import Any, cast
from typing import Any, TypedDict, cast
import click
from pydantic import TypeAdapter
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.engine import CursorResult
from typing_extensions import TypedDict
class _TableInfo(TypedDict, total=False):

View File

@@ -1,5 +1,3 @@
from typing import Union
from sqlalchemy import select
from extensions.ext_database import db
@@ -14,7 +12,7 @@ from services.message_service import MessageService
class SavedMessageService:
@classmethod
def pagination_by_last_id(
cls, app_model: App, user: Union[Account, EndUser] | None, last_id: str | None, limit: int
cls, app_model: App, user: Account | EndUser | None, last_id: str | None, limit: int
) -> InfiniteScrollPagination:
if not user:
raise ValueError("User is required")
@@ -34,7 +32,7 @@ class SavedMessageService:
)
@classmethod
def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
def save(cls, app_model: App, user: Account | EndUser | None, message_id: str):
if not user:
return
saved_message = db.session.scalar(
@@ -64,7 +62,7 @@ class SavedMessageService:
db.session.commit()
@classmethod
def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
def delete(cls, app_model: App, user: Account | EndUser | None, message_id: str):
if not user:
return
saved_message = db.session.scalar(

View File

@@ -1,11 +1,10 @@
import json
import logging
from typing import Any, cast
from typing import Any, TypedDict, cast
from graphon.model_runtime.utils.encoders import jsonable_encoder
from httpx import get
from sqlalchemy import select
from typing_extensions import TypedDict
from core.entities.provider_entities import ProviderConfig
from core.tools.__base.tool_runtime import ToolRuntime

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any, Generic, TypeAlias, TypeVar, overload
from typing import Any, overload
from graphon.file import File
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
@@ -43,12 +43,9 @@ class _PCKeys:
CHILD_CONTENTS = "child_contents"
_T = TypeVar("_T")
@dataclasses.dataclass(frozen=True)
class _PartResult(Generic[_T]):
value: _T
class _PartResult[T]:
value: T
value_size: int
truncated: bool
@@ -61,7 +58,7 @@ class UnknownTypeError(Exception):
pass
JSONTypes: TypeAlias = int | float | str | list[object] | dict[str, object] | None | bool
type JSONTypes = int | float | str | list[object] | dict[str, object] | None | bool
@dataclasses.dataclass(frozen=True)

View File

@@ -1,5 +1,3 @@
from typing import Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -20,7 +18,7 @@ class WebConversationService:
*,
session: Session,
app_model: App,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
last_id: str | None,
limit: int,
invoke_from: InvokeFrom,
@@ -61,7 +59,7 @@ class WebConversationService:
)
@classmethod
def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
def pin(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
if not user:
return
pinned_conversation = db.session.scalar(
@@ -93,7 +91,7 @@ class WebConversationService:
db.session.commit()
@classmethod
def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
def unpin(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
if not user:
return
pinned_conversation = db.session.scalar(

View File

@@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, TypedDict
from graphon.file import FileUploadConfig
from graphon.model_runtime.entities.llm_entities import LLMMode
@@ -7,7 +7,6 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
from graphon.nodes import BuiltinNodeTypes
from graphon.variables.input_entities import VariableEntity
from sqlalchemy import select
from typing_extensions import TypedDict
from core.app.app_config.entities import (
DatasetEntity,

View File

@@ -1,12 +1,11 @@
import json
import uuid
from datetime import datetime
from typing import Any
from typing import Any, TypedDict
from graphon.enums import WorkflowExecutionStatus
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from typing_extensions import TypedDict
from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
from models.enums import AppTriggerType, CreatorUserRole

View File

@@ -3,7 +3,7 @@ import logging
import uuid
from collections.abc import Generator, Mapping
from enum import StrEnum
from typing import Annotated, Any, TypeAlias, Union
from typing import Annotated, Any
from celery import shared_task
from flask import current_app, json
@@ -68,7 +68,7 @@ def _get_user_type_descriminator(value: Any):
return None
User: TypeAlias = Annotated[
type User = Annotated[
(Annotated[_Account, Tag(_UserType.ACCOUNT)] | Annotated[_EndUser, Tag(_UserType.END_USER)]),
Discriminator(_get_user_type_descriminator),
]
@@ -93,7 +93,7 @@ class AppExecutionParams(BaseModel):
cls,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,

View File

@@ -12,7 +12,7 @@ import os
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Protocol, TypeVar
from typing import Protocol
import psycopg2
import pytest
@@ -48,11 +48,8 @@ class _CloserProtocol(Protocol):
pass
_Closer = TypeVar("_Closer", bound=_CloserProtocol)
@contextmanager
def _auto_close(closer: _Closer) -> Generator[_Closer, None, None]:
def _auto_close[T: _CloserProtocol](closer: T) -> Generator[T, None, None]:
yield closer
closer.close()

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable, Iterable
from enum import StrEnum
from typing import Any, NamedTuple, TypeVar
from typing import Any, NamedTuple
import pytest
import sqlalchemy as sa
@@ -58,10 +58,7 @@ class _ColumnTest(_Base):
long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False)
_T = TypeVar("_T")
def _first(it: Iterable[_T]) -> _T:
def _first[T](it: Iterable[T]) -> T:
ls = list(it)
if not ls:
raise ValueError("List is empty")

View File

@@ -6,7 +6,7 @@ and data_source_detail_dict for all data_source_type values, including "local_fi
"""
import json
from typing import Generic, Literal, NotRequired, TypedDict, TypeVar, Union
from typing import Literal, NotRequired, TypedDict
from models.dataset import Document
@@ -31,12 +31,10 @@ class WebsiteCrawlInfo(TypedDict):
job_id: str
RawInfo = Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo]
T_type = TypeVar("T_type", bound=str)
T_info = TypeVar("T_info", bound=Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo])
type RawInfo = LocalFileInfo | UploadFileInfo | NotionImportInfo | WebsiteCrawlInfo
class Case(TypedDict, Generic[T_type, T_info]):
class Case[T_type: str, T_info: RawInfo](TypedDict):
data_source_type: T_type
data_source_info: str
expected_raw: T_info
@@ -47,7 +45,7 @@ UploadFileCase = Case[Literal["upload_file"], UploadFileInfo]
NotionImportCase = Case[Literal["notion_import"], NotionImportInfo]
WebsiteCrawlCase = Case[Literal["website_crawl"], WebsiteCrawlInfo]
AnyCase = Union[LocalFileCase, UploadFileCase, NotionImportCase, WebsiteCrawlCase]
type AnyCase = LocalFileCase | UploadFileCase | NotionImportCase | WebsiteCrawlCase
case_1: LocalFileCase = {

View File

@@ -77,7 +77,6 @@ class TestAuthType:
def test_auth_type_immutability(self):
"""Test that enum values cannot be modified"""
# In Python 3.11+, enum members are read-only
with pytest.raises(AttributeError):
AuthType.FIRECRAWL = "modified"

View File

@@ -1,6 +1,5 @@
import json
import operator
from typing import TypeVar
from unittest.mock import Mock, patch
import httpx
@@ -16,10 +15,8 @@ from core.tools.entities.tool_entities import (
ToolInvokeMessage,
)
_T = TypeVar("_T")
def _get_message_by_type(msgs: list[ToolInvokeMessage], msg_type: type[_T]) -> ToolInvokeMessage | None:
def _get_message_by_type[T](msgs: list[ToolInvokeMessage], msg_type: type[T]) -> ToolInvokeMessage | None:
return next((i for i in msgs if isinstance(i.message, msg_type)), None)

892
api/uv.lock generated

File diff suppressed because it is too large Load Diff