chore(api): align Python support with 3.12 (#34419)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
99
2026-04-02 13:07:32 +08:00
committed by GitHub
parent cb9ee5903a
commit 8f9dbf269e
97 changed files with 410 additions and 1441 deletions

View File

@@ -10,7 +10,7 @@ import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, TypeVar, final, runtime_checkable from typing import Any, Protocol, final, runtime_checkable
from pydantic import BaseModel from pydantic import BaseModel
@@ -188,8 +188,6 @@ class ExecutionContextBuilder:
_capturer: Callable[[], IExecutionContext] | None = None _capturer: Callable[[], IExecutionContext] | None = None
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {} _tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
T = TypeVar("T", bound=BaseModel)
class ContextProviderNotFoundError(KeyError): class ContextProviderNotFoundError(KeyError):
"""Raised when a tenant-scoped context provider is missing.""" """Raised when a tenant-scoped context provider is missing."""

View File

@@ -1,7 +1,4 @@
from contextvars import ContextVar from contextvars import ContextVar
from typing import Generic, TypeVar
T = TypeVar("T")
class HiddenValue: class HiddenValue:
@@ -11,7 +8,7 @@ class HiddenValue:
_default = HiddenValue() _default = HiddenValue()
class RecyclableContextVar(Generic[T]): class RecyclableContextVar[T]:
""" """
RecyclableContextVar is a wrapper around ContextVar RecyclableContextVar is a wrapper around ContextVar
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now

View File

@@ -1,14 +1,14 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, TypeAlias from typing import Any
from graphon.file import helpers as file_helpers from graphon.file import helpers as file_helpers
from pydantic import BaseModel, ConfigDict, computed_field from pydantic import BaseModel, ConfigDict, computed_field
from models.model import IconType from models.model import IconType
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any] type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any]
JSONObject: TypeAlias = dict[str, Any] type JSONObject = dict[str, Any]
class SystemParameters(BaseModel): class SystemParameters(BaseModel):

View File

@@ -2,7 +2,6 @@ import csv
import io import io
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
@@ -20,9 +19,6 @@ from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService from services.billing_service import BillingService
P = ParamSpec("P")
R = TypeVar("R")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -72,9 +68,9 @@ console_ns.schema_model(
) )
def admin_required(view: Callable[P, R]): def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.ADMIN_API_KEY: if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.") raise Unauthorized("API key is invalid.")

View File

@@ -1,7 +1,7 @@
import logging import logging
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Any, Literal, TypeAlias from typing import Any, Literal
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
@@ -152,7 +152,7 @@ class AppTracePayload(BaseModel):
return value return value
JSONValue: TypeAlias = Any type JSONValue = Any
class ResponseModel(BaseModel): class ResponseModel(BaseModel):

View File

@@ -1,7 +1,7 @@
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Any, NoReturn, ParamSpec, TypeVar from typing import Any
from flask import Response, request from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with from flask_restx import Resource, fields, marshal, marshal_with
@@ -192,11 +192,8 @@ workflow_draft_variable_list_model = console_ns.model(
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy "WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
) )
P = ParamSpec("P")
R = TypeVar("R")
def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
def _api_prerequisite(f: Callable[P, R]):
"""Common prerequisites for all draft workflow variable APIs. """Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied: It ensures the following conditions are satisfied:
@@ -213,7 +210,7 @@ def _api_prerequisite(f: Callable[P, R]):
@edit_permission_required @edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f) @wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs): def wrapper(*args: Any, **kwargs: Any):
return f(*args, **kwargs) return f(*args, **kwargs)
return wrapper return wrapper
@@ -270,7 +267,7 @@ class WorkflowVariableCollectionApi(Resource):
return Response("", 204) return Response("", 204)
def validate_node_id(node_id: str) -> NoReturn | None: def validate_node_id(node_id: str) -> None:
if node_id in [ if node_id in [
CONVERSATION_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID,
@@ -285,7 +282,6 @@ def validate_node_id(node_id: str) -> NoReturn | None:
raise InvalidArgumentError( raise InvalidArgumentError(
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}", f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
) )
return None
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar, Union from typing import Any
from sqlalchemy import select from sqlalchemy import select
@@ -9,11 +9,6 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from models import App, AppMode from models import App, AppMode
P = ParamSpec("P")
R = TypeVar("R")
P1 = ParamSpec("P1")
R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None: def _load_app_model(app_id: str) -> App | None:
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
@@ -28,10 +23,14 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None): def get_app_model(
def decorator(view_func: Callable[P1, R1]): view: Callable[..., Any] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: P1.args, **kwargs: P1.kwargs): def decorated_view(*args: Any, **kwargs: Any):
if not kwargs.get("app_id"): if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters") raise ValueError("missing app_id in path parameters")
@@ -69,10 +68,14 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
return decorator(view) return decorator(view)
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None): def get_app_model_with_trial(
def decorator(view_func: Callable[P, R]): view: Callable[..., Any] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs): def decorated_view(*args: Any, **kwargs: Any):
if not kwargs.get("app_id"): if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters") raise ValueError("missing app_id in path parameters")

View File

@@ -1,8 +1,9 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar from typing import Concatenate
from flask import jsonify, request from flask import jsonify, request
from flask.typing import ResponseReturnValue
from flask_restx import Resource from flask_restx import Resource
from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel from pydantic import BaseModel
@@ -16,10 +17,6 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
from .. import console_ns from .. import console_ns
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
class OAuthClientPayload(BaseModel): class OAuthClientPayload(BaseModel):
client_id: str client_id: str
@@ -39,9 +36,11 @@ class OAuthTokenRequest(BaseModel):
refresh_token: str | None = None refresh_token: str | None = None
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]): def oauth_server_client_id_required[T, **P, R](
view: Callable[Concatenate[T, OAuthProviderApp, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(view) @wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs): def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
json_data = request.get_json() json_data = request.get_json()
if json_data is None: if json_data is None:
raise BadRequest("client_id is required") raise BadRequest("client_id is required")
@@ -58,9 +57,13 @@ def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderA
return decorated return decorated
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]): def oauth_server_access_token_required[T, **P, R](
view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R],
) -> Callable[Concatenate[T, OAuthProviderApp, P], R | ResponseReturnValue]:
@wraps(view) @wraps(view)
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs): def decorated(
self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs
) -> R | ResponseReturnValue:
if not isinstance(oauth_provider_app, OAuthProviderApp): if not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app") raise BadRequest("Invalid oauth_provider_app")

View File

@@ -1,6 +1,5 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy import select from sqlalchemy import select
@@ -9,13 +8,10 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from models.dataset import Pipeline from models.dataset import Pipeline
P = ParamSpec("P")
R = TypeVar("R")
def get_rag_pipeline[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
def get_rag_pipeline(view_func: Callable[P, R]):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs): def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("pipeline_id"): if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters") raise ValueError("missing pipeline_id in path parameters")

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar from typing import Concatenate
from flask import abort from flask import abort
from flask_restx import Resource from flask_restx import Resource
@@ -15,12 +15,8 @@ from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def installed_app_required[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]): def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
@@ -49,7 +45,7 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
return decorator return decorator
def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | None = None): def user_allowed_to_access_app[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]): def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
@@ -73,7 +69,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator return decorator
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None): def trial_app_required[**P, R](view: Callable[Concatenate[App, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, P], R]): def decorator(view: Callable[Concatenate[App, P], R]):
@wraps(view) @wraps(view)
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs): def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
@@ -106,7 +102,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
return decorator return decorator
def trial_feature_enable(view: Callable[P, R]): def trial_feature_enable[**P, R](view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
@@ -117,7 +113,7 @@ def trial_feature_enable(view: Callable[P, R]):
return decorated return decorated
def explore_banner_enabled(view: Callable[P, R]): def explore_banner_enabled[**P, R](view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()

View File

@@ -1,6 +1,5 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@@ -9,17 +8,14 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from models.account import TenantPluginPermission from models.account import TenantPluginPermission
P = ParamSpec("P")
R = TypeVar("R")
def plugin_permission_required( def plugin_permission_required(
install_required: bool = False, install_required: bool = False,
debug_required: bool = False, debug_required: bool = False,
): ):
def interceptor(view: Callable[P, R]): def interceptor[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
user = current_user user = current_user
tenant_id = current_tenant_id tenant_id = current_tenant_id

View File

@@ -4,7 +4,6 @@ import os
import time import time
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request from flask import abort, request
from sqlalchemy import select from sqlalchemy import select
@@ -25,9 +24,6 @@ from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
P = ParamSpec("P")
R = TypeVar("R")
# Field names for decryption # Field names for decryption
FIELD_NAME_PASSWORD = "password" FIELD_NAME_PASSWORD = "password"
FIELD_NAME_CODE = "code" FIELD_NAME_CODE = "code"
@@ -37,7 +33,7 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code" ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]: def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R: def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check account initialization # check account initialization
@@ -50,7 +46,7 @@ def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
return decorated return decorated
def only_edition_cloud(view: Callable[P, R]): def only_edition_cloud[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "CLOUD": if dify_config.EDITION != "CLOUD":
@@ -61,7 +57,7 @@ def only_edition_cloud(view: Callable[P, R]):
return decorated return decorated
def only_edition_enterprise(view: Callable[P, R]): def only_edition_enterprise[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ENTERPRISE_ENABLED: if not dify_config.ENTERPRISE_ENABLED:
@@ -72,7 +68,7 @@ def only_edition_enterprise(view: Callable[P, R]):
return decorated return decorated
def only_edition_self_hosted(view: Callable[P, R]): def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "SELF_HOSTED": if dify_config.EDITION != "SELF_HOSTED":
@@ -83,7 +79,7 @@ def only_edition_self_hosted(view: Callable[P, R]):
return decorated return decorated
def cloud_edition_billing_enabled(view: Callable[P, R]): def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
@@ -95,7 +91,7 @@ def cloud_edition_billing_enabled(view: Callable[P, R]):
return decorated return decorated
def cloud_edition_billing_resource_check(resource: str): def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -137,7 +133,9 @@ def cloud_edition_billing_resource_check(resource: str):
return interceptor return interceptor
def cloud_edition_billing_knowledge_limit_check(resource: str): def cloud_edition_billing_knowledge_limit_check[**P, R](
resource: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -160,7 +158,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
return interceptor return interceptor
def cloud_edition_billing_rate_limit_check(resource: str): def cloud_edition_billing_rate_limit_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -196,7 +194,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
return interceptor return interceptor
def cloud_utm_record(view: Callable[P, R]): def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
@@ -215,7 +213,7 @@ def cloud_utm_record(view: Callable[P, R]):
return decorated return decorated
def setup_required(view: Callable[P, R]) -> Callable[P, R]: def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R: def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check setup # check setup
@@ -229,7 +227,7 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]:
return decorated return decorated
def enterprise_license_required(view: Callable[P, R]): def enterprise_license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features() settings = FeatureService.get_system_features()
@@ -241,7 +239,7 @@ def enterprise_license_required(view: Callable[P, R]):
return decorated return decorated
def email_password_login_enabled(view: Callable[P, R]): def email_password_login_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
@@ -254,7 +252,7 @@ def email_password_login_enabled(view: Callable[P, R]):
return decorated return decorated
def email_register_enabled(view: Callable[P, R]): def email_register_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
@@ -267,7 +265,7 @@ def email_register_enabled(view: Callable[P, R]):
return decorated return decorated
def enable_change_email(view: Callable[P, R]): def enable_change_email[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
@@ -280,7 +278,7 @@ def enable_change_email(view: Callable[P, R]):
return decorated return decorated
def is_allow_transfer_owner(view: Callable[P, R]): def is_allow_transfer_owner[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
from libs.workspace_permission import check_workspace_owner_transfer_permission from libs.workspace_permission import check_workspace_owner_transfer_permission
@@ -293,7 +291,7 @@ def is_allow_transfer_owner(view: Callable[P, R]):
return decorated return decorated
def knowledge_pipeline_publish_enabled(view: Callable[P, R]): def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
@@ -305,7 +303,7 @@ def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
return decorated return decorated
def edit_permission_required(f: Callable[P, R]): def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
@wraps(f) @wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs): def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@@ -323,7 +321,7 @@ def edit_permission_required(f: Callable[P, R]):
return decorated_function return decorated_function
def is_admin_or_owner_required(f: Callable[P, R]): def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
@wraps(f) @wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs): def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@@ -339,7 +337,7 @@ def is_admin_or_owner_required(f: Callable[P, R]):
return decorated_function return decorated_function
def annotation_import_rate_limit(view: Callable[P, R]): def annotation_import_rate_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
""" """
Rate limiting decorator for annotation import operations. Rate limiting decorator for annotation import operations.
@@ -388,7 +386,7 @@ def annotation_import_rate_limit(view: Callable[P, R]):
return decorated return decorated
def annotation_import_concurrency_limit(view: Callable[P, R]): def annotation_import_concurrency_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
""" """
Concurrency control decorator for annotation import operations. Concurrency control decorator for annotation import operations.
@@ -455,7 +453,7 @@ def _decrypt_field(field_name: str, error_class: type[Exception], error_message:
payload[field_name] = decoded_value payload[field_name] = decoded_value
def decrypt_password_field(view: Callable[P, R]): def decrypt_password_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
""" """
Decorator to decrypt password field in request payload. Decorator to decrypt password field in request payload.
@@ -477,7 +475,7 @@ def decrypt_password_field(view: Callable[P, R]):
return decorated return decorated
def decrypt_code_field(view: Callable[P, R]): def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
""" """
Decorator to decrypt verification code field in request payload. Decorator to decrypt verification code field in request payload.

View File

@@ -1,6 +1,5 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
@@ -13,9 +12,6 @@ from libs.login import current_user
from models.account import Tenant from models.account import Tenant
from models.model import DefaultEndUserSessionID, EndUser from models.model import DefaultEndUserSessionID, EndUser
P = ParamSpec("P")
R = TypeVar("R")
class TenantUserPayload(BaseModel): class TenantUserPayload(BaseModel):
tenant_id: str tenant_id: str
@@ -65,9 +61,9 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
return user_model return user_model
def get_user_tenant(view_func: Callable[P, R]): def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs): def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {}) payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
user_id = payload.user_id user_id = payload.user_id
@@ -97,10 +93,14 @@ def get_user_tenant(view_func: Callable[P, R]):
return decorated_view return decorated_view
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]): def plugin_data[**P, R](
def decorator(view_func: Callable[P, R]): view: Callable[P, R] | None = None,
*,
payload_type: type[BaseModel],
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs): def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
try: try:
data = request.get_json() data = request.get_json()
except Exception: except Exception:

View File

@@ -3,10 +3,7 @@ from collections.abc import Callable
from functools import wraps from functools import wraps
from hashlib import sha1 from hashlib import sha1
from hmac import new as hmac_new from hmac import new as hmac_new
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
from flask import abort, request from flask import abort, request
from configs import dify_config from configs import dify_config
@@ -14,9 +11,9 @@ from extensions.ext_database import db
from models.model import EndUser from models.model import EndUser
def billing_inner_api_only(view: Callable[P, R]): def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.INNER_API: if not dify_config.INNER_API:
abort(404) abort(404)
@@ -30,9 +27,9 @@ def billing_inner_api_only(view: Callable[P, R]):
return decorated return decorated
def enterprise_inner_api_only(view: Callable[P, R]): def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.INNER_API: if not dify_config.INNER_API:
abort(404) abort(404)
@@ -46,9 +43,9 @@ def enterprise_inner_api_only(view: Callable[P, R]):
return decorated return decorated
def enterprise_inner_api_user_auth(view: Callable[P, R]): def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.INNER_API: if not dify_config.INNER_API:
return view(*args, **kwargs) return view(*args, **kwargs)
@@ -82,9 +79,9 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]):
return decorated return decorated
def plugin_inner_api_only(view: Callable[P, R]): def plugin_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.PLUGIN_DAEMON_KEY: if not dify_config.PLUGIN_DAEMON_KEY:
abort(404) abort(404)

View File

@@ -3,7 +3,7 @@ import time
from collections.abc import Callable from collections.abc import Callable
from enum import StrEnum, auto from enum import StrEnum, auto
from functools import wraps from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast, overload from typing import Any, cast, overload
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
@@ -23,10 +23,6 @@ from services.api_token_service import ApiTokenCache, fetch_token_with_single_fl
from services.end_user_service import EndUserService from services.end_user_service import EndUserService
from services.feature_service import FeatureService from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -46,16 +42,16 @@ class FetchUserArg(BaseModel):
@overload @overload
def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ... def validate_app_token[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
@overload @overload
def validate_app_token( def validate_app_token[**P, R](
view: None = None, *, fetch_user_arg: FetchUserArg | None = None view: None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[[Callable[P, R]], Callable[P, R]]: ... ) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def validate_app_token( def validate_app_token[**P, R](
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: ) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]: def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@@ -136,7 +132,10 @@ def validate_app_token(
return decorator(view) return decorator(view)
def cloud_edition_billing_resource_check(resource: str, api_token_type: str): def cloud_edition_billing_resource_check[**P, R](
resource: str,
api_token_type: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]): def interceptor(view: Callable[P, R]):
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type) api_token = validate_and_get_api_token(api_token_type)
@@ -166,7 +165,10 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
return interceptor return interceptor
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): def cloud_edition_billing_knowledge_limit_check[**P, R](
resource: str,
api_token_type: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -188,7 +190,10 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
return interceptor return interceptor
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): def cloud_edition_billing_rate_limit_check[**P, R](
resource: str,
api_token_type: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -225,20 +230,12 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
return interceptor return interceptor
@overload
def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ...
@overload
def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ...
def validate_dataset_token( def validate_dataset_token(
view: Callable[Concatenate[T, P], R] | None = None, view: Callable[..., Any] | None = None,
) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]: def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(view_func) @wraps(view_func)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R: def decorated(*args: Any, **kwargs: Any) -> Any:
api_token = validate_and_get_api_token("dataset") api_token = validate_and_get_api_token("dataset")
# get url path dataset_id from positional args or kwargs # get url path dataset_id from positional args or kwargs
@@ -308,7 +305,10 @@ def validate_dataset_token(
raise Unauthorized("Tenant owner account does not exist.") raise Unauthorized("Tenant owner account does not exist.")
else: else:
raise Unauthorized("Tenant does not exist.") raise Unauthorized("Tenant does not exist.")
return view_func(api_token.tenant_id, *args, **kwargs) # type: ignore[arg-type] if args and isinstance(args[0], Resource):
return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs)
return view_func(api_token.tenant_id, *args, **kwargs)
return decorated return decorated

View File

@@ -1,7 +1,7 @@
from collections.abc import Callable from collections.abc import Callable
from datetime import UTC, datetime from datetime import UTC, datetime
from functools import wraps from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar from typing import Concatenate
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
@@ -20,14 +20,13 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService from services.webapp_auth_service import WebAppAuthService
P = ParamSpec("P")
R = TypeVar("R")
def validate_jwt_token[**P, R](
def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None = None): view: Callable[Concatenate[App, EndUser, P], R] | None = None,
def decorator(view: Callable[Concatenate[App, EndUser, P], R]): ) -> Callable[P, R] | Callable[[Callable[Concatenate[App, EndUser, P], R]], Callable[P, R]]:
def decorator(view: Callable[Concatenate[App, EndUser, P], R]) -> Callable[P, R]:
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
app_model, end_user = decode_jwt_token() app_model, end_user = decode_jwt_token()
return view(app_model, end_user, *args, **kwargs) return view(app_model, end_user, *args, **kwargs)
@@ -38,7 +37,7 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None =
return decorator return decorator
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None): def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) -> tuple[App, EndUser]:
system_features = FeatureService.get_system_features() system_features = FeatureService.get_system_features()
if not app_code: if not app_code:
app_code = str(request.headers.get(HEADER_NAME_APP_CODE)) app_code = str(request.headers.get(HEADER_NAME_APP_CODE))

View File

@@ -5,7 +5,7 @@ import logging
import threading import threading
import uuid import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Union, overload from typing import TYPE_CHECKING, Any, Literal, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
@@ -68,7 +68,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self, self,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
workflow_run_id: str, workflow_run_id: str,
@@ -81,7 +81,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self, self,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
workflow_run_id: str, workflow_run_id: str,
@@ -94,7 +94,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self, self,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
workflow_run_id: str, workflow_run_id: str,
@@ -106,7 +106,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self, self,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
workflow_run_id: str, workflow_run_id: str,
@@ -239,7 +239,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
application_generate_entity: AdvancedChatAppGenerateEntity, application_generate_entity: AdvancedChatAppGenerateEntity,
@@ -271,9 +271,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user: Account | EndUser, user: Account | EndUser,
args: Mapping, args: Mapping[str, Any],
streaming: bool = True, streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
""" """
Generate App response. Generate App response.
@@ -359,7 +359,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Account | EndUser, user: Account | EndUser,
args: LoopNodeRunPayload, args: LoopNodeRunPayload,
streaming: bool = True, streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
""" """
Generate App response. Generate App response.
@@ -439,7 +439,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self, self,
*, *,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository, workflow_execution_repository: WorkflowExecutionRepository,
@@ -451,7 +451,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
pause_state_config: PauseStateLayerConfig | None = None, pause_state_config: PauseStateLayerConfig | None = None,
graph_runtime_state: GraphRuntimeState | None = None, graph_runtime_state: GraphRuntimeState | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (), graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
""" """
Generate App response. Generate App response.
@@ -653,10 +653,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: ConversationSnapshot, conversation: ConversationSnapshot,
message: MessageSnapshot, message: MessageSnapshot,
user: Union[Account, EndUser], user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory, draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False, stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: ) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
""" """
Handle response. Handle response.
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity

View File

@@ -3,7 +3,7 @@ import logging
import threading import threading
import uuid import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload from typing import Any, Literal, overload
from flask import Flask, current_app from flask import Flask, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -37,7 +37,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
self, self,
*, *,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[False], streaming: Literal[False],
@@ -48,7 +48,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
self, self,
*, *,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[True], streaming: Literal[True],
@@ -59,21 +59,21 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
self, self,
*, *,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool, streaming: bool,
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ... ) -> Mapping | Generator[Mapping | str, None, None]: ...
def generate( def generate(
self, self,
*, *,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ) -> Mapping | Generator[Mapping | str, None, None]:
""" """
Generate App response. Generate App response.

View File

@@ -3,7 +3,7 @@ import logging
import threading import threading
import uuid import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload from typing import Any, Literal, overload
from flask import Flask, copy_current_request_context, current_app from flask import Flask, copy_current_request_context, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -36,7 +36,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
def generate( def generate(
self, self,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[True], streaming: Literal[True],
@@ -46,7 +46,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
def generate( def generate(
self, self,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[False], streaming: Literal[False],
@@ -56,20 +56,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
def generate( def generate(
self, self,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool, streaming: bool,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ... ) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
def generate( def generate(
self, self,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
""" """
Generate App response. Generate App response.

View File

@@ -3,7 +3,7 @@ import logging
import threading import threading
import uuid import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload from typing import Any, Literal, overload
from flask import Flask, copy_current_request_context, current_app from flask import Flask, copy_current_request_context, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -36,7 +36,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
def generate( def generate(
self, self,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[True], streaming: Literal[True],
@@ -46,7 +46,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
def generate( def generate(
self, self,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[False], streaming: Literal[False],
@@ -56,20 +56,20 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
def generate( def generate(
self, self,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = False, streaming: bool = False,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ... ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: ...
def generate( def generate(
self, self,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
""" """
Generate App response. Generate App response.
@@ -244,10 +244,10 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self, self,
app_model: App, app_model: App,
message_id: str, message_id: str,
user: Union[Account, EndUser], user: Account | EndUser,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ) -> Mapping | Generator[Mapping | str, None, None]:
""" """
Generate App response. Generate App response.

View File

@@ -7,7 +7,7 @@ import threading
import time import time
import uuid import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, cast, overload from typing import Any, Literal, cast, overload
from flask import Flask, current_app from flask import Flask, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -62,7 +62,7 @@ class PipelineGenerator(BaseAppGenerator):
*, *,
pipeline: Pipeline, pipeline: Pipeline,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[True], streaming: Literal[True],
@@ -77,7 +77,7 @@ class PipelineGenerator(BaseAppGenerator):
*, *,
pipeline: Pipeline, pipeline: Pipeline,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[False], streaming: Literal[False],
@@ -92,28 +92,28 @@ class PipelineGenerator(BaseAppGenerator):
*, *,
pipeline: Pipeline, pipeline: Pipeline,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool, streaming: bool,
call_depth: int, call_depth: int,
workflow_thread_pool_id: str | None, workflow_thread_pool_id: str | None,
is_retry: bool = False, is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... ) -> Mapping[str, Any] | Generator[Mapping | str, None, None]: ...
def generate( def generate(
self, self,
*, *,
pipeline: Pipeline, pipeline: Pipeline,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
call_depth: int = 0, call_depth: int = 0,
workflow_thread_pool_id: str | None = None, workflow_thread_pool_id: str | None = None,
is_retry: bool = False, is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None:
# Add null check for dataset # Add null check for dataset
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
@@ -278,7 +278,7 @@ class PipelineGenerator(BaseAppGenerator):
context: contextvars.Context, context: contextvars.Context,
pipeline: Pipeline, pipeline: Pipeline,
workflow_id: str, workflow_id: str,
user: Union[Account, EndUser], user: Account | EndUser,
application_generate_entity: RagPipelineGenerateEntity, application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository, workflow_execution_repository: WorkflowExecutionRepository,
@@ -286,7 +286,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool = True, streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
workflow_thread_pool_id: str | None = None, workflow_thread_pool_id: str | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
""" """
Generate App response. Generate App response.
@@ -624,10 +624,10 @@ class PipelineGenerator(BaseAppGenerator):
application_generate_entity: RagPipelineGenerateEntity, application_generate_entity: RagPipelineGenerateEntity,
workflow: Workflow, workflow: Workflow,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
user: Union[Account, EndUser], user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory, draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False, stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
""" """
Handle response. Handle response.
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@@ -668,7 +668,7 @@ class PipelineGenerator(BaseAppGenerator):
datasource_info: Mapping[str, Any], datasource_info: Mapping[str, Any],
created_from: str, created_from: str,
position: int, position: int,
account: Union[Account, EndUser], account: Account | EndUser,
batch: str, batch: str,
document_form: str, document_form: str,
): ):
@@ -715,7 +715,7 @@ class PipelineGenerator(BaseAppGenerator):
pipeline: Pipeline, pipeline: Pipeline,
workflow: Workflow, workflow: Workflow,
start_node_id: str, start_node_id: str,
user: Union[Account, EndUser], user: Account | EndUser,
) -> list[Mapping[str, Any]]: ) -> list[Mapping[str, Any]]:
""" """
Format datasource info list. Format datasource info list.

View File

@@ -5,7 +5,7 @@ import logging
import threading import threading
import uuid import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Union, overload from typing import TYPE_CHECKING, Any, Literal, overload
from flask import Flask, current_app from flask import Flask, current_app
from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_engine.layers import GraphEngineLayer
@@ -64,7 +64,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[True], streaming: Literal[True],
@@ -82,7 +82,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[False], streaming: Literal[False],
@@ -100,7 +100,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool, streaming: bool,
@@ -110,14 +110,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
root_node_id: str | None = None, root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (), graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None, pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ... ) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
def generate( def generate(
self, self,
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
@@ -127,7 +127,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
root_node_id: str | None = None, root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (), graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None, pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
files: Sequence[Mapping[str, Any]] = args.get("files") or [] files: Sequence[Mapping[str, Any]] = args.get("files") or []
@@ -237,7 +237,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
workflow_execution_repository: WorkflowExecutionRepository, workflow_execution_repository: WorkflowExecutionRepository,
@@ -245,7 +245,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
graph_engine_layers: Sequence[GraphEngineLayer] = (), graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None, pause_state_config: PauseStateLayerConfig | None = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
""" """
Resume a paused workflow execution using the persisted runtime state. Resume a paused workflow execution using the persisted runtime state.
""" """
@@ -269,7 +269,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository, workflow_execution_repository: WorkflowExecutionRepository,
@@ -280,7 +280,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
graph_engine_layers: Sequence[GraphEngineLayer] = (), graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None, graph_runtime_state: GraphRuntimeState | None = None,
pause_state_config: PauseStateLayerConfig | None = None, pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
""" """
Generate App response. Generate App response.
@@ -609,10 +609,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow, workflow: Workflow,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
user: Union[Account, EndUser], user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory, draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False, stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
""" """
Handle response. Handle response.
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Annotated, Literal, Self, TypeAlias from typing import Annotated, Literal, Self
from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
@@ -27,7 +27,7 @@ class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
entity: AdvancedChatAppGenerateEntity entity: AdvancedChatAppGenerateEntity
_GenerateEntityUnion: TypeAlias = Annotated[ type _GenerateEntityUnion = Annotated[
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper, _WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
Field(discriminator="type"), Field(discriminator="type"),
] ]

View File

@@ -2,7 +2,7 @@ import logging
import time import time
from collections.abc import Generator from collections.abc import Generator
from threading import Thread from threading import Thread
from typing import Any, Union, cast from typing import Any, cast
from graphon.file import FileTransferMethod from graphon.file import FileTransferMethod
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -72,14 +72,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
""" """
_task_state: EasyUITaskState _task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] _application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity
_precomputed_event_type: StreamEvent | None = None _precomputed_event_type: StreamEvent | None = None
def __init__( def __init__(
self, self,
application_generate_entity: Union[ application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity,
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
],
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
@@ -117,11 +115,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
def process( def process(
self, self,
) -> Union[ ) -> (
ChatbotAppBlockingResponse, ChatbotAppBlockingResponse
CompletionAppBlockingResponse, | CompletionAppBlockingResponse
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], | Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]
]: ):
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread # start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
@@ -136,7 +134,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
def _to_blocking_response( def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None] self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]: ) -> ChatbotAppBlockingResponse | CompletionAppBlockingResponse:
""" """
Process blocking response. Process blocking response.
:return: :return:
@@ -148,7 +146,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
extras = {"usage": self._task_state.llm_result.usage.model_dump()} extras = {"usage": self._task_state.llm_result.usage.model_dump()}
if self._task_state.metadata: if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump() extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] response: ChatbotAppBlockingResponse | CompletionAppBlockingResponse
if self._conversation_mode == AppMode.COMPLETION: if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse( response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@@ -183,7 +181,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
def _to_stream_response( def _to_stream_response(
self, generator: Generator[StreamResponse, None, None] self, generator: Generator[StreamResponse, None, None]
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]: ) -> Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]:
""" """
To stream response. To stream response.
:return: :return:

View File

@@ -5,14 +5,13 @@ This layer centralizes model-quota deduction outside node implementations.
""" """
import logging import logging
from typing import TYPE_CHECKING, cast, final from typing import TYPE_CHECKING, cast, final, override
from graphon.enums import BuiltinNodeTypes from graphon.enums import BuiltinNodeTypes
from graphon.graph_engine.entities.commands import AbortCommand, CommandType from graphon.graph_engine.entities.commands import AbortCommand, CommandType
from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
from graphon.nodes.base.node import Node from graphon.nodes.base.node import Node
from typing_extensions import override
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available from core.app.llm import deduct_llm_quota, ensure_llm_quota_available

View File

@@ -10,7 +10,7 @@ associates with the node span.
import logging import logging
from contextvars import Token from contextvars import Token
from dataclasses import dataclass from dataclasses import dataclass
from typing import cast, final from typing import cast, final, override
from graphon.enums import BuiltinNodeTypes, NodeType from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_engine.layers import GraphEngineLayer
@@ -18,7 +18,6 @@ from graphon.graph_events import GraphNodeEventBase
from graphon.nodes.base.node import Node from graphon.nodes.base.node import Node
from opentelemetry import context as context_api from opentelemetry import context as context_api
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
from typing_extensions import override
from configs import dify_config from configs import dify_config
from extensions.otel.parser import ( from extensions.otel.parser import (

View File

@@ -44,7 +44,8 @@ class HumanInputContent(BaseModel):
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT) type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent # Keep a runtime alias here: callers and tests expect identity with HumanInputContent.
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent # noqa: UP040
__all__ = [ __all__ = [
"ExecutionExtraContentDomainModel", "ExecutionExtraContentDomainModel",

View File

@@ -2,12 +2,13 @@ import importlib.util
import logging import logging
import sys import sys
from types import ModuleType from types import ModuleType
from typing import AnyStr
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType: def import_module_from_source[T: (str, bytes)](
*, module_name: str, py_file_path: T, use_lazy_loader: bool = False
) -> ModuleType:
""" """
Importing a module from the source file directly Importing a module from the source file directly
""" """

View File

@@ -2,7 +2,6 @@ import os
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable from collections.abc import Callable
from functools import lru_cache from functools import lru_cache
from typing import TypeVar
from configs import dify_config from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file_cached from core.tools.utils.yaml_utils import load_yaml_file_cached
@@ -65,10 +64,7 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str])
return position_map return position_map
T = TypeVar("T") def is_filtered[T](
def is_filtered(
include_set: set[str], include_set: set[str],
exclude_set: set[str], exclude_set: set[str],
data: T, data: T,
@@ -97,11 +93,11 @@ def is_filtered(
return False return False
def sort_by_position_map( def sort_by_position_map[T](
position_map: dict[str, int], position_map: dict[str, int],
data: list[T], data: list[T],
name_func: Callable[[T], str], name_func: Callable[[T], str],
): ) -> list[T]:
""" """
Sort the objects by the position map. Sort the objects by the position map.
If the name of the object is not in the position map, it will be put at the end. If the name of the object is not in the position map, it will be put at the end.
@@ -116,11 +112,11 @@ def sort_by_position_map(
return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf"))) return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf")))
def sort_to_dict_by_position_map( def sort_to_dict_by_position_map[T](
position_map: dict[str, int], position_map: dict[str, int],
data: list[T], data: list[T],
name_func: Callable[[T], str], name_func: Callable[[T], str],
): ) -> OrderedDict[str, T]:
""" """
Sort the objects into a ordered dict by the position map. Sort the objects into a ordered dict by the position map.
If the name of the object is not in the position map, it will be put at the end. If the name of the object is not in the position map, it will be put at the end.

View File

@@ -4,7 +4,7 @@ Proxy requests to avoid SSRF
import logging import logging
import time import time
from typing import Any, TypeAlias from typing import Any
import httpx import httpx
from pydantic import TypeAdapter, ValidationError from pydantic import TypeAdapter, ValidationError
@@ -20,8 +20,8 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
BACKOFF_FACTOR = 0.5 BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504] STATUS_FORCELIST = [429, 500, 502, 503, 504]
Headers: TypeAlias = dict[str, str] type Headers = dict[str, str]
_HEADERS_ADAPTER = TypeAdapter(Headers) _HEADERS_ADAPTER: TypeAdapter[Headers] = TypeAdapter(Headers)
_SSL_VERIFIED_POOL_KEY = "ssrf:verified" _SSL_VERIFIED_POOL_KEY = "ssrf:verified"
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified" _SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"

View File

@@ -3,7 +3,7 @@ import queue
from collections.abc import Generator from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, TypeAlias, final from typing import Any, final
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
import httpx import httpx
@@ -33,9 +33,9 @@ class _StatusError:
# Type aliases for better readability # Type aliases for better readability
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None] type ReadQueue = queue.Queue[SessionMessage | Exception | None]
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None] type WriteQueue = queue.Queue[SessionMessage | Exception | None]
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError] type StatusQueue = queue.Queue[_StatusReady | _StatusError]
class SSETransport: class SSETransport:

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import StrEnum from enum import StrEnum
from typing import Any, Generic, TypeVar from typing import Any, TypeVar
from pydantic import BaseModel from pydantic import BaseModel
@@ -9,13 +9,12 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION] SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT") LifespanContextT = TypeVar("LifespanContextT")
@dataclass @dataclass
class RequestContext(Generic[SessionT, LifespanContextT]): class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
request_id: RequestId request_id: RequestId
meta: RequestParams.Meta | None meta: RequestParams.Meta | None
session: SessionT session: SessionT

View File

@@ -4,7 +4,7 @@ from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from datetime import timedelta from datetime import timedelta
from types import TracebackType from types import TracebackType
from typing import Any, Generic, Self, TypeVar from typing import Any, Self, cast
from httpx import HTTPStatusError from httpx import HTTPStatusError
from pydantic import BaseModel from pydantic import BaseModel
@@ -34,16 +34,10 @@ from core.mcp.types import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0 DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
class RequestResponder(Generic[ReceiveRequestT, SendResultT]): class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResultT: ClientResult | ServerResult]:
"""Handles responding to MCP requests and manages request lifecycle. """Handles responding to MCP requests and manages request lifecycle.
This class MUST be used as a context manager to ensure proper cleanup and This class MUST be used as a context manager to ensure proper cleanup and
@@ -60,7 +54,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
""" """
request: ReceiveRequestT request: ReceiveRequestT
_session: Any _session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any] _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
def __init__( def __init__(
@@ -68,7 +62,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
request_id: RequestId, request_id: RequestId,
request_meta: RequestParams.Meta | None, request_meta: RequestParams.Meta | None,
request: ReceiveRequestT, request: ReceiveRequestT,
session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""", session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
): ):
self.request_id = request_id self.request_id = request_id
@@ -111,7 +105,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self.completed = True self.completed = True
self._session._send_response(request_id=self.request_id, response=response) self._session.send_response(request_id=self.request_id, response=response)
def cancel(self): def cancel(self):
"""Cancel this request and mark it as completed.""" """Cancel this request and mark it as completed."""
@@ -120,21 +114,19 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self.completed = True # Mark as completed so it's removed from in_flight self.completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation # Send an error response to indicate cancellation
self._session._send_response( self._session.send_response(
request_id=self.request_id, request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None), response=ErrorData(code=0, message="Request cancelled", data=None),
) )
class BaseSession( class BaseSession[
Generic[ SendRequestT: ClientRequest | ServerRequest,
SendRequestT, SendNotificationT: ClientNotification | ServerNotification,
SendNotificationT, SendResultT: ClientResult | ServerResult,
SendResultT, ReceiveRequestT: ClientRequest | ServerRequest,
ReceiveRequestT, ReceiveNotificationT: ClientNotification | ServerNotification,
ReceiveNotificationT, ]:
],
):
""" """
Implements an MCP "session" on top of read/write streams, including features Implements an MCP "session" on top of read/write streams, including features
like request/response linking, notifications, and progress. like request/response linking, notifications, and progress.
@@ -204,13 +196,13 @@ class BaseSession(
# The receiver thread should have already exited due to the None message in the queue # The receiver thread should have already exited due to the None message in the queue
self._executor.shutdown(wait=False) self._executor.shutdown(wait=False)
def send_request( def send_request[T: BaseModel](
self, self,
request: SendRequestT, request: SendRequestT,
result_type: type[ReceiveResultT], result_type: type[T],
request_read_timeout_seconds: timedelta | None = None, request_read_timeout_seconds: timedelta | None = None,
metadata: MessageMetadata | None = None, metadata: MessageMetadata | None = None,
) -> ReceiveResultT: ) -> T:
""" """
Sends a request and wait for a response. Raises an McpError if the Sends a request and wait for a response. Raises an McpError if the
response contains an error. If a request read timeout is provided, it response contains an error. If a request read timeout is provided, it
@@ -299,7 +291,7 @@ class BaseSession(
) )
self._write_stream.put(session_message) self._write_stream.put(session_message)
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData): def send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
if isinstance(response, ErrorData): if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
@@ -346,6 +338,7 @@ class BaseSession(
validated_request = self._receive_request_type.model_validate( validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
) )
validated_request = cast(ReceiveRequestT, validated_request)
responder = RequestResponder[ReceiveRequestT, SendResultT]( responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=message.message.root.id, request_id=message.message.root.id,
@@ -366,6 +359,7 @@ class BaseSession(
notification = self._receive_notification_type.model_validate( notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
) )
notification = cast(ReceiveNotificationT, notification)
# Handle cancellation notifications # Handle cancellation notifications
if isinstance(notification.root, CancelledNotification): if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId cancelled_id = notification.root.params.requestId

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar from typing import Annotated, Any, Literal
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
from pydantic.networks import AnyUrl, UrlConstraints from pydantic.networks import AnyUrl, UrlConstraints
@@ -31,7 +31,7 @@ ProgressToken = str | int
Cursor = str Cursor = str
Role = Literal["user", "assistant"] Role = Literal["user", "assistant"]
RequestId = Annotated[int | str, Field(union_mode="left_to_right")] RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
AnyFunction: TypeAlias = Callable[..., Any] type AnyFunction = Callable[..., Any]
class RequestParams(BaseModel): class RequestParams(BaseModel):
@@ -68,12 +68,7 @@ class NotificationParams(BaseModel):
""" """
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None) class Request[RequestParamsT: RequestParams | dict[str, Any] | None, MethodT: str](BaseModel):
NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None)
MethodT = TypeVar("MethodT", bound=str)
class Request(BaseModel, Generic[RequestParamsT, MethodT]):
"""Base class for JSON-RPC requests.""" """Base class for JSON-RPC requests."""
method: MethodT method: MethodT
@@ -81,14 +76,14 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]): class PaginatedRequest[T: str](Request[PaginatedRequestParams | None, T]):
"""Base class for paginated requests, """Base class for paginated requests,
matching the schema's PaginatedRequest interface.""" matching the schema's PaginatedRequest interface."""
params: PaginatedRequestParams | None = None params: PaginatedRequestParams | None = None
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]): class Notification[NotificationParamsT: NotificationParams | dict[str, Any] | None, MethodT: str](BaseModel):
"""Base class for JSON-RPC notifications.""" """Base class for JSON-RPC notifications."""
method: MethodT method: MethodT
@@ -736,7 +731,7 @@ class ResourceLink(Resource):
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
"""A content block that can be used in prompts and tool results.""" """A content block that can be used in prompts and tool results."""
Content: TypeAlias = ContentBlock type Content = ContentBlock
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly.""" # """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""

View File

@@ -6,7 +6,7 @@ import queue
import threading import threading
import time import time
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, TypedDict
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from cachetools import LRUCache from cachetools import LRUCache
@@ -14,7 +14,6 @@ from flask import current_app
from pydantic import TypeAdapter from pydantic import TypeAdapter
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import TypedDict
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import ( from core.ops.entities.config_entity import (
@@ -464,7 +463,7 @@ class OpsTraceManager:
@classmethod @classmethod
def get_ops_trace_instance( def get_ops_trace_instance(
cls, cls,
app_id: Union[UUID, str] | None = None, app_id: UUID | str | None = None,
): ):
""" """
Get ops trace through model config Get ops trace through model config
@@ -717,7 +716,7 @@ class TraceTask:
self, self,
trace_type: Any, trace_type: Any,
message_id: str | None = None, message_id: str | None = None,
workflow_execution: Optional["WorkflowExecution"] = None, workflow_execution: "WorkflowExecution | None" = None,
conversation_id: str | None = None, conversation_id: str | None = None,
user_id: str | None = None, user_id: str | None = None,
timer: Any | None = None, timer: Any | None = None,

View File

@@ -1,5 +1,4 @@
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Generic, TypeVar
from pydantic import BaseModel from pydantic import BaseModel
@@ -19,9 +18,6 @@ class BaseBackwardsInvocation:
yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode() yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode()
T = TypeVar("T", bound=dict | Mapping | str | bool | int | BaseModel) class BaseBackwardsInvocationResponse[T: dict | Mapping | str | bool | int | BaseModel](BaseModel):
class BaseBackwardsInvocationResponse(BaseModel, Generic[T]):
data: T | None = None data: T | None = None
error: str = "" error: str = ""

View File

@@ -4,7 +4,7 @@ import enum
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import Any, Generic, TypeVar from typing import Any
from graphon.model_runtime.entities.model_entities import AIModelEntity from graphon.model_runtime.entities.model_entities import AIModelEntity
from graphon.model_runtime.entities.provider_entities import ProviderEntity from graphon.model_runtime.entities.provider_entities import ProviderEntity
@@ -19,10 +19,8 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.entities.entities import TriggerProviderEntity
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
class PluginDaemonBasicResponse[T: BaseModel | dict | list | bool | str](BaseModel):
class PluginDaemonBasicResponse(BaseModel, Generic[T]):
""" """
Basic response from plugin daemon. Basic response from plugin daemon.
""" """

View File

@@ -2,7 +2,7 @@ import inspect
import json import json
import logging import logging
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from typing import Any, TypeVar, cast from typing import Any, cast
import httpx import httpx
from graphon.model_runtime.errors.invoke import ( from graphon.model_runtime.errors.invoke import (
@@ -51,8 +51,6 @@ elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
else: else:
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config) plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
T = TypeVar("T", bound=(BaseModel | dict[str, Any] | list[Any] | bool | str))
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_httpx_client: httpx.Client = get_pooled_http_client( _httpx_client: httpx.Client = get_pooled_http_client(
@@ -191,7 +189,7 @@ class BasePluginClient:
logger.exception("Stream request to Plugin Daemon Service failed") logger.exception("Stream request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed") raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
def _stream_request_with_model( def _stream_request_with_model[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
self, self,
method: str, method: str,
path: str, path: str,
@@ -207,7 +205,7 @@ class BasePluginClient:
for line in self._stream_request(method, path, params, headers, data, files): for line in self._stream_request(method, path, params, headers, data, files):
yield type_(**json.loads(line)) # type: ignore yield type_(**json.loads(line)) # type: ignore
def _request_with_model( def _request_with_model[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
self, self,
method: str, method: str,
path: str, path: str,
@@ -223,7 +221,7 @@ class BasePluginClient:
response = self._request(method, path, headers, data, params, files) response = self._request(method, path, headers, data, params, files)
return type_(**response.json()) # type: ignore[return-value] return type_(**response.json()) # type: ignore[return-value]
def _request_with_plugin_daemon_response( def _request_with_plugin_daemon_response[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
self, self,
method: str, method: str,
path: str, path: str,
@@ -278,7 +276,7 @@ class BasePluginClient:
return rep.data return rep.data
def _request_with_plugin_daemon_response_stream( def _request_with_plugin_daemon_response_stream[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
self, self,
method: str, method: str,
path: str, path: str,

View File

@@ -1,12 +1,9 @@
from collections.abc import Generator from collections.abc import Generator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TypeVar, Union
from core.agent.entities import AgentInvokeMessage from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage])
@dataclass @dataclass
class FileChunk: class FileChunk:
@@ -22,11 +19,11 @@ class FileChunk:
self.data = bytearray(self.total_length) self.data = bytearray(self.total_length)
def merge_blob_chunks( def merge_blob_chunks[T: ToolInvokeMessage | AgentInvokeMessage](
response: Generator[MessageType, None, None], response: Generator[T, None, None],
max_file_size: int = 30 * 1024 * 1024, max_file_size: int = 30 * 1024 * 1024,
max_chunk_size: int = 8192, max_chunk_size: int = 8192,
) -> Generator[MessageType, None, None]: ) -> Generator[T, None, None]:
""" """
Merge streaming blob chunks into complete blob messages. Merge streaming blob chunks into complete blob messages.

View File

@@ -1,6 +1,7 @@
from typing import TypedDict
from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
from typing_extensions import TypedDict
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.data_post_processor.reorder import ReorderRunner

View File

@@ -1,10 +1,9 @@
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any, TypedDict
import orjson import orjson
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import select from sqlalchemy import select
from typing_extensions import TypedDict
from configs import dify_config from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler

View File

@@ -1,13 +1,12 @@
import concurrent.futures import concurrent.futures
import logging import logging
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Any, NotRequired from typing import Any, NotRequired, TypedDict
from flask import Flask, current_app from flask import Flask, current_app
from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session, load_only from sqlalchemy.orm import Session, load_only
from typing_extensions import TypedDict
from configs import dify_config from configs import dify_config
from core.db.session_factory import session_factory from core.db.session_factory import session_factory

View File

@@ -3,7 +3,7 @@ import logging
import uuid import uuid
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Any, Concatenate, ParamSpec, TypeVar from typing import Any, Concatenate
from mo_vector.client import MoVectorClient # type: ignore from mo_vector.client import MoVectorClient # type: ignore
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
@@ -20,15 +20,12 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T", bound="MatrixoneVector") def ensure_client[T: MatrixoneVector, **P, R](
func: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R]:
def ensure_client(func: Callable[Concatenate[T, P], R]):
@wraps(func) @wraps(func)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs): def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
if self.client is None: if self.client is None:
self.client = self._get_client(None, False) self.client = self._get_client(None, False)
return func(self, *args, **kwargs) return func(self, *args, **kwargs)

View File

@@ -3,7 +3,7 @@ import os
import uuid import uuid
from collections.abc import Generator, Iterable, Sequence from collections.abc import Generator, Iterable, Sequence
from itertools import islice from itertools import islice
from typing import TYPE_CHECKING, Any, Union from typing import TYPE_CHECKING, Any
import qdrant_client import qdrant_client
from flask import current_app from flask import current_app
@@ -36,8 +36,8 @@ if TYPE_CHECKING:
from qdrant_client.conversions import common_types from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]] type DictFilter = dict[str, str | int | bool | dict | list]
MetadataFilter = Union[DictFilter, common_types.Filter] type MetadataFilter = DictFilter | common_types.Filter
class PathQdrantParams(BaseModel): class PathQdrantParams(BaseModel):

View File

@@ -3,7 +3,7 @@ import os
import uuid import uuid
from collections.abc import Generator, Iterable, Sequence from collections.abc import Generator, Iterable, Sequence
from itertools import islice from itertools import islice
from typing import TYPE_CHECKING, Any, Union from typing import TYPE_CHECKING, Any
import httpx import httpx
import qdrant_client import qdrant_client
@@ -40,8 +40,8 @@ if TYPE_CHECKING:
from qdrant_client.conversions import common_types from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]] type DictFilter = dict[str, str | int | bool | dict | list]
MetadataFilter = Union[DictFilter, common_types.Filter] type MetadataFilter = DictFilter | common_types.Filter
class TidbOnQdrantConfig(BaseModel): class TidbOnQdrantConfig(BaseModel):

View File

@@ -1,5 +1,6 @@
from typing import TypedDict
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import TypedDict
from models.dataset import DocumentSegment from models.dataset import DocumentSegment

View File

@@ -12,11 +12,11 @@ import mimetypes
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from io import BufferedReader, BytesIO from io import BufferedReader, BytesIO
from pathlib import Path, PurePath from pathlib import Path, PurePath
from typing import Any, Union from typing import Any
from pydantic import BaseModel, ConfigDict, model_validator from pydantic import BaseModel, ConfigDict, model_validator
PathLike = Union[str, PurePath] type PathLike = str | PurePath
class Blob(BaseModel): class Blob(BaseModel):
@@ -29,7 +29,7 @@ class Blob(BaseModel):
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
""" """
data: Union[bytes, str, None] = None # Raw data data: bytes | str | None = None # Raw data
mimetype: str | None = None # Not to be confused with a file extension mimetype: str | None = None # Not to be confused with a file extension
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
# Location where the original content was found # Location where the original content was found
@@ -75,7 +75,7 @@ class Blob(BaseModel):
raise ValueError(f"Unable to get bytes for blob {self}") raise ValueError(f"Unable to get bytes for blob {self}")
@contextlib.contextmanager @contextlib.contextmanager
def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]: def as_bytes_io(self) -> Generator[BytesIO | BufferedReader, None, None]:
"""Read data as a byte stream.""" """Read data as a byte stream."""
if isinstance(self.data, bytes): if isinstance(self.data, bytes):
yield BytesIO(self.data) yield BytesIO(self.data)
@@ -117,7 +117,7 @@ class Blob(BaseModel):
@classmethod @classmethod
def from_data( def from_data(
cls, cls,
data: Union[str, bytes], data: str | bytes,
*, *,
encoding: str = "utf-8", encoding: str = "utf-8",
mime_type: str | None = None, mime_type: str | None = None,

View File

@@ -1,9 +1,8 @@
import json import json
import time import time
from typing import Any, NotRequired, cast from typing import Any, NotRequired, TypedDict, cast
import httpx import httpx
from typing_extensions import TypedDict
from extensions.ext_storage import storage from extensions.ext_storage import storage

View File

@@ -1,11 +1,10 @@
import json import json
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Union from typing import Any, TypedDict
from urllib.parse import urljoin from urllib.parse import urljoin
import httpx import httpx
from httpx import Response from httpx import Response
from typing_extensions import TypedDict
from core.rag.extractor.watercrawl.exceptions import ( from core.rag.extractor.watercrawl.exceptions import (
WaterCrawlAuthenticationError, WaterCrawlAuthenticationError,
@@ -142,7 +141,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
def create_crawl_request( def create_crawl_request(
self, self,
url: Union[list, str] | None = None, url: list | str | None = None,
spider_options: SpiderOptions | None = None, spider_options: SpiderOptions | None = None,
page_options: PageOptions | None = None, page_options: PageOptions | None = None,
plugin_options: dict[str, Any] | None = None, plugin_options: dict[str, Any] | None = None,

View File

@@ -1,8 +1,6 @@
from collections.abc import Generator from collections.abc import Generator
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any, TypedDict
from typing_extensions import TypedDict
from core.rag.extractor.watercrawl.client import PageOptions, SpiderOptions, WaterCrawlAPIClient from core.rag.extractor.watercrawl.client import PageOptions, SpiderOptions, WaterCrawlAPIClient

View File

@@ -7,12 +7,11 @@ import os
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, NotRequired, Optional from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
from urllib.parse import unquote, urlparse from urllib.parse import unquote, urlparse
import httpx import httpx
from sqlalchemy import select from sqlalchemy import select
from typing_extensions import TypedDict
from configs import dify_config from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail from core.entities.knowledge_entities import PreviewDetail
@@ -118,11 +117,12 @@ class BaseIndexProcessor(ABC):
max_tokens: int, max_tokens: int,
chunk_overlap: int, chunk_overlap: int,
separator: str, separator: str,
embedding_model_instance: Optional["ModelInstance"], embedding_model_instance: "ModelInstance | None",
) -> TextSplitter: ) -> TextSplitter:
""" """
Get the NodeParser object according to the processing rule. Get the NodeParser object according to the processing rule.
""" """
character_splitter: TextSplitter
if processing_rule_mode in ["custom", "hierarchical"]: if processing_rule_mode in ["custom", "hierarchical"]:
# The user-defined segmentation rule # The user-defined segmentation rule
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
@@ -148,7 +148,7 @@ class BaseIndexProcessor(ABC):
embedding_model_instance=embedding_model_instance, embedding_model_instance=embedding_model_instance,
) )
return character_splitter # type: ignore return character_splitter
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]: def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
""" """

View File

@@ -4,19 +4,13 @@ from __future__ import annotations
import codecs import codecs
import re import re
from typing import Any from collections.abc import Collection
from typing import Any, Literal
from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.rag.splitter.text_splitter import ( from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
TS,
Collection,
Literal,
RecursiveCharacterTextSplitter,
Set,
Union,
)
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
@@ -25,13 +19,13 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
""" """
@classmethod @classmethod
def from_encoder( def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
cls: type[TS], cls: type[T],
embedding_model_instance: ModelInstance | None, embedding_model_instance: ModelInstance | None,
allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037 allowed_special: Literal["all"] | set[str] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037 disallowed_special: Literal["all"] | Collection[str] = "all",
**kwargs: Any, **kwargs: Any,
): ) -> T:
def _token_encoder(texts: list[str]) -> list[int]: def _token_encoder(texts: list[str]) -> list[int]:
if not texts: if not texts:
return [] return []

View File

@@ -6,19 +6,12 @@ import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import Any, Literal
Any,
Literal,
TypeVar,
Union,
)
from core.rag.models.document import BaseDocumentTransformer, Document from core.rag.models.document import BaseDocumentTransformer, Document
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TS = TypeVar("TS", bound="TextSplitter")
def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]: def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
# Now that we have the separator, split the text # Now that we have the separator, split the text
@@ -194,8 +187,8 @@ class TokenTextSplitter(TextSplitter):
self, self,
encoding_name: str = "gpt2", encoding_name: str = "gpt2",
model_name: str | None = None, model_name: str | None = None,
allowed_special: Union[Literal["all"], Set[str]] = set(), allowed_special: Literal["all"] | Set[str] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all", disallowed_special: Literal["all"] | Collection[str] = "all",
**kwargs: Any, **kwargs: Any,
): ):
"""Create a new TextSplitter.""" """Create a new TextSplitter."""

View File

@@ -6,7 +6,6 @@ providing improved performance by offloading database operations to background w
""" """
import logging import logging
from typing import Union
from graphon.entities import WorkflowExecution from graphon.entities import WorkflowExecution
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
@@ -47,7 +46,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__( def __init__(
self, self,
session_factory: sessionmaker | Engine, session_factory: sessionmaker | Engine,
user: Union[Account, EndUser], user: Account | EndUser,
app_id: str | None, app_id: str | None,
triggered_from: WorkflowRunTriggeredFrom | None, triggered_from: WorkflowRunTriggeredFrom | None,
): ):

View File

@@ -7,7 +7,6 @@ providing improved performance by offloading database operations to background w
import logging import logging
from collections.abc import Sequence from collections.abc import Sequence
from typing import Union
from graphon.entities import WorkflowNodeExecution from graphon.entities import WorkflowNodeExecution
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
@@ -54,7 +53,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
def __init__( def __init__(
self, self,
session_factory: sessionmaker | Engine, session_factory: sessionmaker | Engine,
user: Union[Account, EndUser], user: Account | EndUser,
app_id: str | None, app_id: str | None,
triggered_from: WorkflowNodeExecutionTriggeredFrom | None, triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
): ):

View File

@@ -7,7 +7,7 @@ allowing users to configure different repository backends through string paths.
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal, Protocol, Union from typing import Literal, Protocol
from graphon.entities import WorkflowExecution, WorkflowNodeExecution from graphon.entities import WorkflowExecution, WorkflowNodeExecution
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
@@ -61,8 +61,8 @@ class DifyCoreRepositoryFactory:
@classmethod @classmethod
def create_workflow_execution_repository( def create_workflow_execution_repository(
cls, cls,
session_factory: Union[sessionmaker, Engine], session_factory: sessionmaker | Engine,
user: Union[Account, EndUser], user: Account | EndUser,
app_id: str, app_id: str,
triggered_from: WorkflowRunTriggeredFrom, triggered_from: WorkflowRunTriggeredFrom,
) -> WorkflowExecutionRepository: ) -> WorkflowExecutionRepository:
@@ -97,8 +97,8 @@ class DifyCoreRepositoryFactory:
@classmethod @classmethod
def create_workflow_node_execution_repository( def create_workflow_node_execution_repository(
cls, cls,
session_factory: Union[sessionmaker, Engine], session_factory: sessionmaker | Engine,
user: Union[Account, EndUser], user: Account | EndUser,
app_id: str, app_id: str,
triggered_from: WorkflowNodeExecutionTriggeredFrom, triggered_from: WorkflowNodeExecutionTriggeredFrom,
) -> WorkflowNodeExecutionRepository: ) -> WorkflowNodeExecutionRepository:

View File

@@ -4,7 +4,6 @@ SQLAlchemy implementation of the WorkflowExecutionRepository.
import json import json
import logging import logging
from typing import Union
from graphon.entities import WorkflowExecution from graphon.entities import WorkflowExecution
from graphon.enums import WorkflowExecutionStatus, WorkflowType from graphon.enums import WorkflowExecutionStatus, WorkflowType
@@ -40,7 +39,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__( def __init__(
self, self,
session_factory: sessionmaker | Engine, session_factory: sessionmaker | Engine,
user: Union[Account, EndUser], user: Account | EndUser,
app_id: str | None, app_id: str | None,
triggered_from: WorkflowRunTriggeredFrom | None, triggered_from: WorkflowRunTriggeredFrom | None,
): ):

View File

@@ -7,7 +7,7 @@ import json
import logging import logging
from collections.abc import Callable, Mapping, Sequence from collections.abc import Callable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Any, TypeVar, Union from typing import Any
import psycopg2.errors import psycopg2.errors
from graphon.entities import WorkflowNodeExecution from graphon.entities import WorkflowNodeExecution
@@ -63,7 +63,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
def __init__( def __init__(
self, self,
session_factory: sessionmaker | Engine, session_factory: sessionmaker | Engine,
user: Union[Account, EndUser], user: Account | EndUser,
app_id: str | None, app_id: str | None,
triggered_from: WorkflowNodeExecutionTriggeredFrom | None, triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
): ):
@@ -551,10 +551,7 @@ def _deterministic_json_dump(value: Mapping[str, Any]) -> str:
return json.dumps(value, sort_keys=True) return json.dumps(value, sort_keys=True)
_T = TypeVar("_T") def _find_first[T](seq: Sequence[T], pred: Callable[[T], bool]) -> T | None:
def _find_first(seq: Sequence[_T], pred: Callable[[_T], bool]) -> _T | None:
filtered = [i for i in seq if pred(i)] filtered = [i for i in seq if pred(i)]
if filtered: if filtered:
return filtered[0] return filtered[0]

View File

@@ -3,15 +3,15 @@ import re
import threading import threading
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union from typing import Any
from core.schemas.registry import SchemaRegistry from core.schemas.registry import SchemaRegistry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Type aliases for better clarity # Type aliases for better clarity
SchemaType = Union[dict[str, Any], list[Any], str, int, float, bool, None] type SchemaType = dict[str, Any] | list[Any] | str | int | float | bool | None
SchemaDict = dict[str, Any] type SchemaDict = dict[str, Any]
# Pre-compiled pattern for better performance # Pre-compiled pattern for better performance
_DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$") _DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$")
@@ -54,7 +54,7 @@ class QueueItem:
current: Any current: Any
parent: Any | None parent: Any | None
key: Union[str, int] | None key: str | int | None
depth: int depth: int
ref_path: set[str] ref_path: set[str]

View File

@@ -1,6 +1,5 @@
import hashlib import hashlib
import logging import logging
from typing import TypeVar
from redis import RedisError from redis import RedisError
@@ -11,8 +10,6 @@ logger = logging.getLogger(__name__)
TRIGGER_DEBUG_EVENT_TTL = 300 TRIGGER_DEBUG_EVENT_TTL = 300
TTriggerDebugEvent = TypeVar("TTriggerDebugEvent", bound="BaseDebugEvent")
class TriggerDebugEventBus: class TriggerDebugEventBus:
""" """
@@ -81,15 +78,15 @@ class TriggerDebugEventBus:
return 0 return 0
@classmethod @classmethod
def poll( def poll[T: BaseDebugEvent](
cls, cls,
event_type: type[TTriggerDebugEvent], event_type: type[T],
pool_key: str, pool_key: str,
tenant_id: str, tenant_id: str,
user_id: str, user_id: str,
app_id: str, app_id: str,
node_id: str, node_id: str,
) -> TTriggerDebugEvent | None: ) -> T | None:
""" """
Poll for an event or register to the waiting pool. Poll for an event or register to the waiting pool.

View File

@@ -2,7 +2,7 @@ import importlib
import pkgutil import pkgutil
from collections.abc import Callable, Iterator, Mapping, MutableMapping from collections.abc import Callable, Iterator, Mapping, MutableMapping
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING, Any, TypeAlias, cast, final from typing import TYPE_CHECKING, Any, cast, final, override
from graphon.entities.base_node_data import BaseNodeData from graphon.entities.base_node_data import BaseNodeData
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
@@ -22,7 +22,6 @@ from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeDat
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config from configs import dify_config
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
@@ -192,7 +191,7 @@ class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Nod
NODE_TYPE_CLASSES_MAPPING: MutableMapping[NodeType, Mapping[str, type[Node]]] = _LazyNodeTypeClassesMapping() NODE_TYPE_CLASSES_MAPPING: MutableMapping[NodeType, Mapping[str, type[Node]]] = _LazyNodeTypeClassesMapping()
LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData type LLMCompatibleNodeData = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
def fetch_memory( def fetch_memory(

View File

@@ -3,7 +3,7 @@ import logging
import ssl import ssl
from collections.abc import Callable from collections.abc import Callable
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union from typing import TYPE_CHECKING, Any, Union
import redis import redis
from redis import RedisError from redis import RedisError
@@ -297,12 +297,7 @@ def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
return RedisBroadcastChannel(_pubsub_redis_client) return RedisBroadcastChannel(_pubsub_redis_client)
P = ParamSpec("P") def redis_fallback[T](default_return: T | None = None): # type: ignore
R = TypeVar("R")
T = TypeVar("T")
def redis_fallback(default_return: T | None = None): # type: ignore
""" """
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable. decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
@@ -310,9 +305,9 @@ def redis_fallback(default_return: T | None = None): # type: ignore
default_return: The value to return when a Redis operation fails. Defaults to None. default_return: The value to return when a Redis operation fails. Defaults to None.
""" """
def decorator(func: Callable[P, R]): def decorator[**P, R](func: Callable[P, R]) -> Callable[P, R | T | None]:
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs): def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | T | None:
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
except RedisError as e: except RedisError as e:

View File

@@ -2,7 +2,6 @@ import json
import logging import logging
import os import os
import time import time
from typing import Union
from graphon.entities import WorkflowExecution from graphon.entities import WorkflowExecution
from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
@@ -27,7 +26,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__( def __init__(
self, self,
session_factory: sessionmaker | Engine, session_factory: sessionmaker | Engine,
user: Union[Account, EndUser], user: Account | EndUser,
app_id: str | None, app_id: str | None,
triggered_from: WorkflowRunTriggeredFrom | None, triggered_from: WorkflowRunTriggeredFrom | None,
): ):

View File

@@ -11,7 +11,7 @@ import os
import time import time
from collections.abc import Sequence from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from typing import Any, Union from typing import Any
from graphon.entities import WorkflowNodeExecution from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -109,7 +109,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
def __init__( def __init__(
self, self,
session_factory: sessionmaker | Engine, session_factory: sessionmaker | Engine,
user: Union[Account, EndUser], user: Account | EndUser,
app_id: str | None, app_id: str | None,
triggered_from: WorkflowNodeExecutionTriggeredFrom | None, triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
): ):

View File

@@ -1,6 +1,6 @@
import functools import functools
from collections.abc import Callable from collections.abc import Callable
from typing import ParamSpec, TypeVar, cast from typing import cast
from opentelemetry.trace import get_tracer from opentelemetry.trace import get_tracer
@@ -8,9 +8,6 @@ from configs import dify_config
from extensions.otel.decorators.handler import SpanHandler from extensions.otel.decorators.handler import SpanHandler
from extensions.otel.runtime import is_instrument_flag_enabled from extensions.otel.runtime import is_instrument_flag_enabled
P = ParamSpec("P")
R = TypeVar("R")
_HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()} _HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()}
@@ -21,7 +18,7 @@ def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler:
return _HANDLER_INSTANCES[handler_class] return _HANDLER_INSTANCES[handler_class]
def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[Callable[P, R]], Callable[P, R]]: def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callable[[Callable[P, R]], Callable[P, R]]:
""" """
Decorator that traces a function with an OpenTelemetry span. Decorator that traces a function with an OpenTelemetry span.

View File

@@ -1,11 +1,9 @@
import inspect import inspect
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from typing import Any, TypeVar from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.trace import SpanKind, Status, StatusCode
R = TypeVar("R")
class SpanHandler: class SpanHandler:
""" """
@@ -31,9 +29,9 @@ class SpanHandler:
""" """
return f"{wrapped.__module__}.{wrapped.__qualname__}" return f"{wrapped.__module__}.{wrapped.__qualname__}"
def _extract_arguments( def _extract_arguments[T](
self, self,
wrapped: Callable[..., R], wrapped: Callable[..., T],
args: tuple[object, ...], args: tuple[object, ...],
kwargs: Mapping[str, object], kwargs: Mapping[str, object],
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
@@ -61,13 +59,13 @@ class SpanHandler:
except Exception: except Exception:
return None return None
def wrapper( def wrapper[T](
self, self,
tracer: Any, tracer: Any,
wrapped: Callable[..., R], wrapped: Callable[..., T],
args: tuple[object, ...], args: tuple[object, ...],
kwargs: Mapping[str, object], kwargs: Mapping[str, object],
) -> R: ) -> T:
""" """
Fully control the wrapper behavior. Fully control the wrapper behavior.

View File

@@ -1,6 +1,6 @@
import logging import logging
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from typing import Any, TypeVar from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.util.types import AttributeValue from opentelemetry.util.types import AttributeValue
@@ -12,19 +12,16 @@ from models.model import Account
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
R = TypeVar("R")
class AppGenerateHandler(SpanHandler): class AppGenerateHandler(SpanHandler):
"""Span handler for ``AppGenerateService.generate``.""" """Span handler for ``AppGenerateService.generate``."""
def wrapper( def wrapper[T](
self, self,
tracer: Any, tracer: Any,
wrapped: Callable[..., R], wrapped: Callable[..., T],
args: tuple[object, ...], args: tuple[object, ...],
kwargs: Mapping[str, object], kwargs: Mapping[str, object],
) -> R: ) -> T:
try: try:
arguments = self._extract_arguments(wrapped, args, kwargs) arguments = self._extract_arguments(wrapped, args, kwargs)
if not arguments: if not arguments:

View File

@@ -1,12 +1,12 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import Any, TypeAlias from typing import Any
from graphon.file import File from graphon.file import File
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
JSONValue: TypeAlias = Any type JSONValue = Any
class ResponseModel(BaseModel): class ResponseModel(BaseModel):

View File

@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import TypeAlias
from uuid import uuid4 from uuid import uuid4
from graphon.file import File from graphon.file import File
@@ -10,7 +9,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
from fields.conversation_fields import AgentThought, JSONValue, MessageFile from fields.conversation_fields import AgentThought, JSONValue, MessageFile
JSONValueType: TypeAlias = JSONValue type JSONValueType = JSONValue
class ResponseModel(BaseModel): class ResponseModel(BaseModel):

View File

@@ -1,12 +1,10 @@
import contextvars import contextvars
from collections.abc import Iterator from collections.abc import Iterator
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING
from flask import Flask, g from flask import Flask, g
T = TypeVar("T")
if TYPE_CHECKING: if TYPE_CHECKING:
from models import Account, EndUser from models import Account, EndUser

View File

@@ -42,13 +42,7 @@ def current_account_with_tenant():
return user, user.current_tenant_id return user, user.current_tenant_id
from typing import ParamSpec, TypeVar def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
P = ParamSpec("P")
R = TypeVar("R")
def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
""" """
If you decorate a view with this, it will ensure that the current user is If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are logged in and authenticated before calling the actual view. (If they are

View File

@@ -1,26 +1,20 @@
import logging import logging
import sys
import urllib.parse import urllib.parse
from dataclasses import dataclass from dataclasses import dataclass
from typing import NotRequired from typing import NotRequired, TypedDict
import httpx import httpx
from pydantic import TypeAdapter, ValidationError from pydantic import TypeAdapter, ValidationError
from core.helper.http_client_pooling import get_pooled_http_client from core.helper.http_client_pooling import get_pooled_http_client
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
JsonObject = dict[str, object] type JsonObject = dict[str, object]
JsonObjectList = list[JsonObject] type JsonObjectList = list[JsonObject]
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject) JSON_OBJECT_ADAPTER: TypeAdapter[JsonObject] = TypeAdapter(JsonObject)
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList) JSON_OBJECT_LIST_ADAPTER: TypeAdapter[JsonObjectList] = TypeAdapter(JsonObjectList)
# Reuse a pooled httpx.Client for OAuth flows (public endpoints, no SSRF proxy). # Reuse a pooled httpx.Client for OAuth flows (public endpoints, no SSRF proxy).
_http_client: httpx.Client = get_pooled_http_client( _http_client: httpx.Client = get_pooled_http_client(

View File

@@ -1,6 +1,5 @@
import sys
import urllib.parse import urllib.parse
from typing import Any, Literal from typing import Any, Literal, TypedDict
import httpx import httpx
from flask_login import current_user from flask_login import current_user
@@ -12,11 +11,6 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding from models.source import DataSourceOauthBinding
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
class NotionPageSummary(TypedDict): class NotionPageSummary(TypedDict):
page_id: str page_id: str

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from decimal import Decimal from decimal import Decimal
from enum import StrEnum, auto from enum import StrEnum, auto
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, cast
from uuid import uuid4 from uuid import uuid4
import sqlalchemy as sa import sqlalchemy as sa
@@ -19,7 +19,6 @@ from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from graphon.file import helpers as file_helpers from graphon.file import helpers as file_helpers
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import TypedDict
from configs import dify_config from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS from constants import DEFAULT_FILE_NUMBER_LIMITS

View File

@@ -1,6 +1,6 @@
import enum import enum
import uuid import uuid
from typing import Any, Generic, TypeVar from typing import Any
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator
@@ -110,17 +110,14 @@ class AdjustedJSON(TypeDecorator[dict | list | None]):
return value return value
_E = TypeVar("_E", bound=enum.StrEnum) class EnumText[T: enum.StrEnum](TypeDecorator[T | None]):
class EnumText(TypeDecorator[_E | None], Generic[_E]):
impl = VARCHAR impl = VARCHAR
cache_ok = True cache_ok = True
_length: int _length: int
_enum_class: type[_E] _enum_class: type[T]
def __init__(self, enum_class: type[_E], length: int | None = None): def __init__(self, enum_class: type[T], length: int | None = None):
self._enum_class = enum_class self._enum_class = enum_class
max_enum_value_len = max(len(e.value) for e in enum_class) max_enum_value_len = max(len(e.value) for e in enum_class)
if length is not None: if length is not None:
@@ -131,25 +128,25 @@ class EnumText(TypeDecorator[_E | None], Generic[_E]):
# leave some rooms for future longer enum values. # leave some rooms for future longer enum values.
self._length = max(max_enum_value_len, 20) self._length = max(max_enum_value_len, 20)
def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None: def process_bind_param(self, value: T | str | None, dialect: Dialect) -> str | None:
if value is None: if value is None:
return value return value
if isinstance(value, self._enum_class): if isinstance(value, self._enum_class):
return value.value return value.value
# Since _E is bound to StrEnum which inherits from str, at this point value must be str # Since T is bound to StrEnum which inherits from str, at this point value must be str
self._enum_class(value) self._enum_class(value)
return value return value
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
return dialect.type_descriptor(VARCHAR(self._length)) return dialect.type_descriptor(VARCHAR(self._length))
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None: def process_result_value(self, value: str | None, dialect: Dialect) -> T | None:
if value is None or value == "": if value is None or value == "":
return None return None
# Type annotation guarantees value is str at this point # Type annotation guarantees value is str at this point
return self._enum_class(value) return self._enum_class(value)
def compare_values(self, x: _E | None, y: _E | None) -> bool: def compare_values(self, x: T | None, y: T | None) -> bool:
if x is None or y is None: if x is None or y is None:
return x is y return x is y
return x == y return x == y

View File

@@ -1,7 +1,7 @@
[project] [project]
name = "dify-api" name = "dify-api"
version = "1.13.3" version = "1.13.3"
requires-python = ">=3.11,<3.13" requires-python = "~=3.12.0"
dependencies = [ dependencies = [
"aliyun-log-python-sdk~=0.9.37", "aliyun-log-python-sdk~=0.9.37",
@@ -232,5 +232,5 @@ vdb = [
project-includes = ["."] project-includes = ["."]
project-excludes = [".venv", "migrations/"] project-excludes = [".venv", "migrations/"]
python-platform = "linux" python-platform = "linux"
python-version = "3.11.0" python-version = "3.12.0"
infer-with-first-use = false infer-with-first-use = false

View File

@@ -50,6 +50,6 @@
"reportUntypedFunctionDecorator": "hint", "reportUntypedFunctionDecorator": "hint",
"reportUnnecessaryTypeIgnoreComment": "hint", "reportUnnecessaryTypeIgnoreComment": "hint",
"reportAttributeAccessIssue": "hint", "reportAttributeAccessIssue": "hint",
"pythonVersion": "3.11", "pythonVersion": "3.12",
"pythonPlatform": "All" "pythonPlatform": "All"
} }

View File

@@ -5,12 +5,11 @@ import secrets
import uuid import uuid
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from hashlib import sha256 from hashlib import sha256
from typing import Any, cast from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter from pydantic import BaseModel, TypeAdapter
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing_extensions import TypedDict
class InvitationData(TypedDict): class InvitationData(TypedDict):

View File

@@ -1,7 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any, TypedDict
from typing_extensions import TypedDict
class AuthCredentials(TypedDict): class AuthCredentials(TypedDict):

View File

@@ -2,13 +2,12 @@ import json
import logging import logging
import os import os
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal from typing import Literal, TypedDict
import httpx import httpx
from pydantic import TypeAdapter from pydantic import TypeAdapter
from sqlalchemy import select from sqlalchemy import select
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from typing_extensions import TypedDict
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
from core.helper.http_client_pooling import get_pooled_http_client from core.helper.http_client_pooling import get_pooled_http_client

View File

@@ -1,7 +1,7 @@
import contextlib import contextlib
import logging import logging
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Any, Union from typing import Any
from graphon.variables.types import SegmentType from graphon.variables.types import SegmentType
from sqlalchemy import asc, desc, func, or_, select from sqlalchemy import asc, desc, func, or_, select
@@ -37,7 +37,7 @@ class ConversationService:
*, *,
session: Session, session: Session,
app_model: App, app_model: App,
user: Union[Account, EndUser] | None, user: Account | EndUser | None,
last_id: str | None, last_id: str | None,
limit: int, limit: int,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@@ -119,7 +119,7 @@ class ConversationService:
cls, cls,
app_model: App, app_model: App,
conversation_id: str, conversation_id: str,
user: Union[Account, EndUser] | None, user: Account | EndUser | None,
name: str | None, name: str | None,
auto_generate: bool, auto_generate: bool,
): ):
@@ -159,7 +159,7 @@ class ConversationService:
return conversation return conversation
@classmethod @classmethod
def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): def get_conversation(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
conversation = db.session.scalar( conversation = db.session.scalar(
select(Conversation) select(Conversation)
.where( .where(
@@ -179,7 +179,7 @@ class ConversationService:
return conversation return conversation
@classmethod @classmethod
def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): def delete(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
""" """
Delete a conversation only if it belongs to the given user and app context. Delete a conversation only if it belongs to the given user and app context.
@@ -209,7 +209,7 @@ class ConversationService:
cls, cls,
app_model: App, app_model: App,
conversation_id: str, conversation_id: str,
user: Union[Account, EndUser] | None, user: Account | EndUser | None,
limit: int, limit: int,
last_id: str | None, last_id: str | None,
variable_name: str | None = None, variable_name: str | None = None,
@@ -278,7 +278,7 @@ class ConversationService:
app_model: App, app_model: App,
conversation_id: str, conversation_id: str,
variable_id: str, variable_id: str,
user: Union[Account, EndUser] | None, user: Account | EndUser | None,
new_value: Any, new_value: Any,
): ):
""" """

View File

@@ -1,5 +1,4 @@
from collections.abc import Sequence from collections.abc import Sequence
from typing import Union
from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import TypeAdapter from pydantic import TypeAdapter
@@ -57,7 +56,7 @@ class MessageService:
def pagination_by_first_id( def pagination_by_first_id(
cls, cls,
app_model: App, app_model: App,
user: Union[Account, EndUser] | None, user: Account | EndUser | None,
conversation_id: str, conversation_id: str,
first_id: str | None, first_id: str | None,
limit: int, limit: int,
@@ -117,7 +116,7 @@ class MessageService:
def pagination_by_last_id( def pagination_by_last_id(
cls, cls,
app_model: App, app_model: App,
user: Union[Account, EndUser] | None, user: Account | EndUser | None,
last_id: str | None, last_id: str | None,
limit: int, limit: int,
conversation_id: str | None = None, conversation_id: str | None = None,
@@ -170,7 +169,7 @@ class MessageService:
*, *,
app_model: App, app_model: App,
message_id: str, message_id: str,
user: Union[Account, EndUser] | None, user: Account | EndUser | None,
rating: FeedbackRating | None, rating: FeedbackRating | None,
content: str | None, content: str | None,
): ):
@@ -221,7 +220,7 @@ class MessageService:
return [record.to_dict() for record in feedbacks] return [record.to_dict() for record in feedbacks]
@classmethod @classmethod
def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): def get_message(cls, app_model: App, user: Account | EndUser | None, message_id: str):
message = db.session.scalar( message = db.session.scalar(
select(Message) select(Message)
.where( .where(
@@ -241,7 +240,7 @@ class MessageService:
@classmethod @classmethod
def get_suggested_questions_after_answer( def get_suggested_questions_after_answer(
cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str, invoke_from: InvokeFrom cls, app_model: App, user: Account | EndUser | None, message_id: str, invoke_from: InvokeFrom
) -> list[str]: ) -> list[str]:
if not user: if not user:
raise ValueError("user cannot be None") raise ValueError("user cannot be None")

View File

@@ -5,7 +5,7 @@ import time
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, TypedDict
from uuid import uuid4 from uuid import uuid4
import click import click
@@ -14,7 +14,6 @@ import tqdm
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import TypeAdapter from pydantic import TypeAdapter
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing_extensions import TypedDict
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.helper import marketplace from core.helper import marketplace

View File

@@ -13,13 +13,12 @@ from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, cast from typing import Any, TypedDict, cast
import click import click
from pydantic import TypeAdapter from pydantic import TypeAdapter
from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.engine import CursorResult from sqlalchemy.engine import CursorResult
from typing_extensions import TypedDict
class _TableInfo(TypedDict, total=False): class _TableInfo(TypedDict, total=False):

View File

@@ -1,5 +1,3 @@
from typing import Union
from sqlalchemy import select from sqlalchemy import select
from extensions.ext_database import db from extensions.ext_database import db
@@ -14,7 +12,7 @@ from services.message_service import MessageService
class SavedMessageService: class SavedMessageService:
@classmethod @classmethod
def pagination_by_last_id( def pagination_by_last_id(
cls, app_model: App, user: Union[Account, EndUser] | None, last_id: str | None, limit: int cls, app_model: App, user: Account | EndUser | None, last_id: str | None, limit: int
) -> InfiniteScrollPagination: ) -> InfiniteScrollPagination:
if not user: if not user:
raise ValueError("User is required") raise ValueError("User is required")
@@ -34,7 +32,7 @@ class SavedMessageService:
) )
@classmethod @classmethod
def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): def save(cls, app_model: App, user: Account | EndUser | None, message_id: str):
if not user: if not user:
return return
saved_message = db.session.scalar( saved_message = db.session.scalar(
@@ -64,7 +62,7 @@ class SavedMessageService:
db.session.commit() db.session.commit()
@classmethod @classmethod
def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str): def delete(cls, app_model: App, user: Account | EndUser | None, message_id: str):
if not user: if not user:
return return
saved_message = db.session.scalar( saved_message = db.session.scalar(

View File

@@ -1,11 +1,10 @@
import json import json
import logging import logging
from typing import Any, cast from typing import Any, TypedDict, cast
from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.model_runtime.utils.encoders import jsonable_encoder
from httpx import get from httpx import get
from sqlalchemy import select from sqlalchemy import select
from typing_extensions import TypedDict
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import dataclasses import dataclasses
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Generic, TypeAlias, TypeVar, overload from typing import Any, overload
from graphon.file import File from graphon.file import File
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
@@ -43,12 +43,9 @@ class _PCKeys:
CHILD_CONTENTS = "child_contents" CHILD_CONTENTS = "child_contents"
_T = TypeVar("_T")
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class _PartResult(Generic[_T]): class _PartResult[T]:
value: _T value: T
value_size: int value_size: int
truncated: bool truncated: bool
@@ -61,7 +58,7 @@ class UnknownTypeError(Exception):
pass pass
JSONTypes: TypeAlias = int | float | str | list[object] | dict[str, object] | None | bool type JSONTypes = int | float | str | list[object] | dict[str, object] | None | bool
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)

View File

@@ -1,5 +1,3 @@
from typing import Union
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -20,7 +18,7 @@ class WebConversationService:
*, *,
session: Session, session: Session,
app_model: App, app_model: App,
user: Union[Account, EndUser] | None, user: Account | EndUser | None,
last_id: str | None, last_id: str | None,
limit: int, limit: int,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@@ -61,7 +59,7 @@ class WebConversationService:
) )
@classmethod @classmethod
def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): def pin(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
if not user: if not user:
return return
pinned_conversation = db.session.scalar( pinned_conversation = db.session.scalar(
@@ -93,7 +91,7 @@ class WebConversationService:
db.session.commit() db.session.commit()
@classmethod @classmethod
def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): def unpin(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
if not user: if not user:
return return
pinned_conversation = db.session.scalar( pinned_conversation = db.session.scalar(

View File

@@ -1,5 +1,5 @@
import json import json
from typing import Any from typing import Any, TypedDict
from graphon.file import FileUploadConfig from graphon.file import FileUploadConfig
from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.entities.llm_entities import LLMMode
@@ -7,7 +7,6 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
from graphon.nodes import BuiltinNodeTypes from graphon.nodes import BuiltinNodeTypes
from graphon.variables.input_entities import VariableEntity from graphon.variables.input_entities import VariableEntity
from sqlalchemy import select from sqlalchemy import select
from typing_extensions import TypedDict
from core.app.app_config.entities import ( from core.app.app_config.entities import (
DatasetEntity, DatasetEntity,

View File

@@ -1,12 +1,11 @@
import json import json
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any, TypedDict
from graphon.enums import WorkflowExecutionStatus from graphon.enums import WorkflowExecutionStatus
from sqlalchemy import and_, func, or_, select from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing_extensions import TypedDict
from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
from models.enums import AppTriggerType, CreatorUserRole from models.enums import AppTriggerType, CreatorUserRole

View File

@@ -3,7 +3,7 @@ import logging
import uuid import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from enum import StrEnum from enum import StrEnum
from typing import Annotated, Any, TypeAlias, Union from typing import Annotated, Any
from celery import shared_task from celery import shared_task
from flask import current_app, json from flask import current_app, json
@@ -68,7 +68,7 @@ def _get_user_type_descriminator(value: Any):
return None return None
User: TypeAlias = Annotated[ type User = Annotated[
(Annotated[_Account, Tag(_UserType.ACCOUNT)] | Annotated[_EndUser, Tag(_UserType.END_USER)]), (Annotated[_Account, Tag(_UserType.ACCOUNT)] | Annotated[_EndUser, Tag(_UserType.END_USER)]),
Discriminator(_get_user_type_descriminator), Discriminator(_get_user_type_descriminator),
] ]
@@ -93,7 +93,7 @@ class AppExecutionParams(BaseModel):
cls, cls,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,

View File

@@ -12,7 +12,7 @@ import os
from collections.abc import Generator from collections.abc import Generator
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Protocol, TypeVar from typing import Protocol
import psycopg2 import psycopg2
import pytest import pytest
@@ -48,11 +48,8 @@ class _CloserProtocol(Protocol):
pass pass
_Closer = TypeVar("_Closer", bound=_CloserProtocol)
@contextmanager @contextmanager
def _auto_close(closer: _Closer) -> Generator[_Closer, None, None]: def _auto_close[T: _CloserProtocol](closer: T) -> Generator[T, None, None]:
yield closer yield closer
closer.close() closer.close()

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from enum import StrEnum from enum import StrEnum
from typing import Any, NamedTuple, TypeVar from typing import Any, NamedTuple
import pytest import pytest
import sqlalchemy as sa import sqlalchemy as sa
@@ -58,10 +58,7 @@ class _ColumnTest(_Base):
long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False) long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False)
_T = TypeVar("_T") def _first[T](it: Iterable[T]) -> T:
def _first(it: Iterable[_T]) -> _T:
ls = list(it) ls = list(it)
if not ls: if not ls:
raise ValueError("List is empty") raise ValueError("List is empty")

View File

@@ -6,7 +6,7 @@ and data_source_detail_dict for all data_source_type values, including "local_fi
""" """
import json import json
from typing import Generic, Literal, NotRequired, TypedDict, TypeVar, Union from typing import Literal, NotRequired, TypedDict
from models.dataset import Document from models.dataset import Document
@@ -31,12 +31,10 @@ class WebsiteCrawlInfo(TypedDict):
job_id: str job_id: str
RawInfo = Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo] type RawInfo = LocalFileInfo | UploadFileInfo | NotionImportInfo | WebsiteCrawlInfo
T_type = TypeVar("T_type", bound=str)
T_info = TypeVar("T_info", bound=Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo])
class Case(TypedDict, Generic[T_type, T_info]): class Case[T_type: str, T_info: RawInfo](TypedDict):
data_source_type: T_type data_source_type: T_type
data_source_info: str data_source_info: str
expected_raw: T_info expected_raw: T_info
@@ -47,7 +45,7 @@ UploadFileCase = Case[Literal["upload_file"], UploadFileInfo]
NotionImportCase = Case[Literal["notion_import"], NotionImportInfo] NotionImportCase = Case[Literal["notion_import"], NotionImportInfo]
WebsiteCrawlCase = Case[Literal["website_crawl"], WebsiteCrawlInfo] WebsiteCrawlCase = Case[Literal["website_crawl"], WebsiteCrawlInfo]
AnyCase = Union[LocalFileCase, UploadFileCase, NotionImportCase, WebsiteCrawlCase] type AnyCase = LocalFileCase | UploadFileCase | NotionImportCase | WebsiteCrawlCase
case_1: LocalFileCase = { case_1: LocalFileCase = {

View File

@@ -77,7 +77,6 @@ class TestAuthType:
def test_auth_type_immutability(self): def test_auth_type_immutability(self):
"""Test that enum values cannot be modified""" """Test that enum values cannot be modified"""
# In Python 3.11+, enum members are read-only
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
AuthType.FIRECRAWL = "modified" AuthType.FIRECRAWL = "modified"

View File

@@ -1,6 +1,5 @@
import json import json
import operator import operator
from typing import TypeVar
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import httpx import httpx
@@ -16,10 +15,8 @@ from core.tools.entities.tool_entities import (
ToolInvokeMessage, ToolInvokeMessage,
) )
_T = TypeVar("_T")
def _get_message_by_type[T](msgs: list[ToolInvokeMessage], msg_type: type[T]) -> ToolInvokeMessage | None:
def _get_message_by_type(msgs: list[ToolInvokeMessage], msg_type: type[_T]) -> ToolInvokeMessage | None:
return next((i for i in msgs if isinstance(i.message, msg_type)), None) return next((i for i in msgs if isinstance(i.message, msg_type)), None)

892
api/uv.lock generated

File diff suppressed because it is too large Load Diff