Compare commits

..

41 Commits

Author SHA1 Message Date
hj24
b2861e019b fix: merge error 2026-04-02 18:16:31 +08:00
Joel
cad9936c0a Merge branch 'fix/ps-not-send' into deploy/dev 2026-04-02 17:55:04 +08:00
hj24
8c0b596ced Merge branch 'chore-debug-partnerstack' into deploy/dev 2026-04-02 17:54:06 +08:00
hj24
12a0f85b72 feat: clear api 2026-04-02 17:52:55 +08:00
hj24
1fdb653875 feat: debug partnerstack 2026-04-02 17:18:25 +08:00
hj24
4ba8c71962 feat: debug partnerstack 2026-04-02 17:17:40 +08:00
Poojan
5bafb163cc test: add unit tests for services and tasks part-4 (#33223)
Co-authored-by: akashseth-ifp <akash.seth@infocusp.com>
Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
Co-authored-by: Dev Sharma <50591491+cryptus-neoxys@users.noreply.github.com>
Co-authored-by: sahil-infocusp <73810410+sahil-infocusp@users.noreply.github.com>
2026-04-02 08:35:46 +00:00
Stephen Zhou
52b1bc5b09 refactor: split icon collections (#34453) 2026-04-02 07:58:15 +00:00
Stephen Zhou
1873b22e96 refactor: update to tailwind v4 (#34415)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
2026-04-02 07:06:11 +00:00
Akash Kumar
9a8c853a2e test: added unit test for remaining files in core helper folder (#33288)
Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
Co-authored-by: sahil-infocusp <73810410+sahil-infocusp@users.noreply.github.com>
2026-04-02 06:50:58 +00:00
Akash Kumar
e54383d0fe test: added test for api/services/rag_pipeline folder (#33222)
Co-authored-by: sahil-infocusp <73810410+sahil-infocusp@users.noreply.github.com>
2026-04-02 06:40:52 +00:00
Sedo
43c48ba4d7 fix: add tenant/dataset ownership checks to prevent IDOR vulnerabilities (#34436)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-02 05:45:20 +00:00
99
8f9dbf269e chore(api): align Python support with 3.12 (#34419)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-04-02 05:07:32 +00:00
Renzo
cb9ee5903a refactor: select in tag_service (#34441)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-02 05:04:36 +00:00
99
cd406d2794 refactor(api): replace test fixture side-effect imports (#34421) 2026-04-02 04:55:15 +00:00
Joel
1f1c74099f Merge branch 'fix/ps-not-send' into deploy/dev 2026-04-02 12:53:28 +08:00
wangxiaolei
993a301468 fix: fix online_drive is not a valid datasource_type (#34440) 2026-04-02 04:45:02 +00:00
Renzo
399d3f8da5 refactor: model_load_balancing_service and api_tools_manage_service (#34434)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-02 04:38:35 +00:00
yyh
f9d9ad7a38 refactor(web): migrate remaining toast usage (#34433) 2026-04-02 04:16:50 +00:00
Joel
43fedac47b Merge branch 'fix/ps-not-send' into deploy/dev 2026-04-02 11:23:20 +08:00
hj24
a91c1a2af0 Merge branch 'refactor-enhance-billing-info-guard' into deploy/dev 2026-04-02 11:02:00 +08:00
Yansong Zhang
b3870524d4 fix usage get 2026-04-02 09:52:52 +08:00
hj24
919c080452 chore: update comments 2026-04-01 10:35:34 +08:00
hj24
4653ed7ead refactor: enhance billing info response handling 2026-03-31 18:23:32 +08:00
Yansong Zhang
c543188434 fix linter 2026-03-31 15:22:51 +08:00
Yansong Zhang
f319a9e42f fix test case 2026-03-31 15:22:43 +08:00
Yansong Zhang
58241a89a5 fix linter 2026-03-31 14:59:54 +08:00
Yansong Zhang
422bf3506e rebuild quota service 2026-03-31 14:59:45 +08:00
Yansong Zhang
6e745f9e9b fix linter 2026-03-31 09:49:24 +08:00
Yansong Zhang
4e50d55339 fix comment 2026-03-31 09:49:09 +08:00
autofix-ci[bot]
b95cdabe26 [autofix.ci] apply automated fixes 2026-03-30 08:45:37 +00:00
Yansong Zhang
daa47c25bb Merge branch 'feat/new-biliing-quota' of github.com:langgenius/dify into feat/new-biliing-quota 2026-03-30 16:43:13 +08:00
Yansong Zhang
f1bcd6d715 add test case for quota and billing service 2026-03-30 16:41:56 +08:00
hj24
8643ff43f5 Merge branch 'main' into feat/new-biliing-quota 2026-03-30 15:57:49 +08:00
Yansong Zhang
c5f30a47f0 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-30 15:26:38 +08:00
Yansong Zhang
37d438fa19 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-27 16:26:09 +08:00
Yansong Zhang
9503803997 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-23 09:27:39 +08:00
Yansong Zhang
d6476f5434 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-20 15:17:27 +08:00
Yansong Zhang
80b4633e8f fix style check and test 2026-03-20 14:58:31 +08:00
autofix-ci[bot]
3888969af3 [autofix.ci] apply automated fixes 2026-03-20 05:45:30 +00:00
Yansong Zhang
658ac15589 use new quota system 2026-03-20 13:29:22 +08:00
1756 changed files with 23556 additions and 12447 deletions

View File

@@ -64,7 +64,7 @@ export const useUpdateAccessMode = () => {
// Component only adds UI behavior.
updateAccessMode({ appId, mode }, {
onSuccess: () => Toast.notify({ type: 'success', message: '...' }),
onSuccess: () => toast.success('...'),
})
// Avoid putting invalidation knowledge in the component.
@@ -114,10 +114,7 @@ try {
router.push(`/orders/${order.id}`)
}
catch (error) {
Toast.notify({
type: 'error',
message: error instanceof Error ? error.message : 'Unknown error',
})
toast.error(error instanceof Error ? error.message : 'Unknown error')
}
```

2
.gitignore vendored
View File

@@ -212,7 +212,7 @@ api/.vscode
# pnpm
/.pnpm-store
/node_modules
node_modules
.vite-hooks/_
# plugin migrate

1
.npmrc Normal file
View File

@@ -0,0 +1 @@
save-exact=true

View File

@@ -115,12 +115,6 @@ ignore = [
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]
[lint.pyflakes]
allowed-unused-imports = [
"tests.integration_tests",
"tests.unit_tests",
]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]

View File

@@ -10,7 +10,7 @@ import threading
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, TypeVar, final, runtime_checkable
from typing import Any, Protocol, final, runtime_checkable
from pydantic import BaseModel
@@ -188,8 +188,6 @@ class ExecutionContextBuilder:
_capturer: Callable[[], IExecutionContext] | None = None
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
T = TypeVar("T", bound=BaseModel)
class ContextProviderNotFoundError(KeyError):
"""Raised when a tenant-scoped context provider is missing."""

View File

@@ -1,7 +1,4 @@
from contextvars import ContextVar
from typing import Generic, TypeVar
T = TypeVar("T")
class HiddenValue:
@@ -11,7 +8,7 @@ class HiddenValue:
_default = HiddenValue()
class RecyclableContextVar(Generic[T]):
class RecyclableContextVar[T]:
"""
RecyclableContextVar is a wrapper around ContextVar
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now

View File

@@ -1,14 +1,14 @@
from __future__ import annotations
from typing import Any, TypeAlias
from typing import Any
from graphon.file import helpers as file_helpers
from pydantic import BaseModel, ConfigDict, computed_field
from models.model import IconType
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
JSONObject: TypeAlias = dict[str, Any]
type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any]
type JSONObject = dict[str, Any]
class SystemParameters(BaseModel):

View File

@@ -2,7 +2,6 @@ import csv
import io
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource
@@ -20,9 +19,6 @@ from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService
P = ParamSpec("P")
R = TypeVar("R")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -72,9 +68,9 @@ console_ns.schema_model(
)
def admin_required(view: Callable[P, R]):
def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")

View File

@@ -1,7 +1,7 @@
import logging
import uuid
from datetime import datetime
from typing import Any, Literal, TypeAlias
from typing import Any, Literal
from flask import request
from flask_restx import Resource
@@ -152,7 +152,7 @@ class AppTracePayload(BaseModel):
return value
JSONValue: TypeAlias = Any
type JSONValue = Any
class ResponseModel(BaseModel):

View File

@@ -1,7 +1,7 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any, NoReturn, ParamSpec, TypeVar
from typing import Any
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
@@ -192,11 +192,8 @@ workflow_draft_variable_list_model = console_ns.model(
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
)
P = ParamSpec("P")
R = TypeVar("R")
def _api_prerequisite(f: Callable[P, R]):
def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@@ -213,7 +210,7 @@ def _api_prerequisite(f: Callable[P, R]):
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs):
def wrapper(*args: Any, **kwargs: Any):
return f(*args, **kwargs)
return wrapper
@@ -270,7 +267,7 @@ class WorkflowVariableCollectionApi(Resource):
return Response("", 204)
def validate_node_id(node_id: str) -> NoReturn | None:
def validate_node_id(node_id: str) -> None:
if node_id in [
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
@@ -285,7 +282,6 @@ def validate_node_id(node_id: str) -> NoReturn | None:
raise InvalidArgumentError(
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
)
return None
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar, Union
from typing import Any
from sqlalchemy import select
@@ -9,11 +9,6 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models import App, AppMode
P = ParamSpec("P")
R = TypeVar("R")
P1 = ParamSpec("P1")
R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None:
_, current_tenant_id = current_account_with_tenant()
@@ -28,10 +23,14 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P1, R1]):
def get_app_model(
view: Callable[..., Any] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(view_func)
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
def decorated_view(*args: Any, **kwargs: Any):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
@@ -69,10 +68,14 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
return decorator(view)
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
def get_app_model_with_trial(
view: Callable[..., Any] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: Any, **kwargs: Any):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")

View File

@@ -1,8 +1,9 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from typing import Concatenate
from flask import jsonify, request
from flask.typing import ResponseReturnValue
from flask_restx import Resource
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel
@@ -16,10 +17,6 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
from .. import console_ns
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
class OAuthClientPayload(BaseModel):
client_id: str
@@ -39,9 +36,11 @@ class OAuthTokenRequest(BaseModel):
refresh_token: str | None = None
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
def oauth_server_client_id_required[T, **P, R](
view: Callable[Concatenate[T, OAuthProviderApp, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
json_data = request.get_json()
if json_data is None:
raise BadRequest("client_id is required")
@@ -58,9 +57,13 @@ def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderA
return decorated
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
def oauth_server_access_token_required[T, **P, R](
view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R],
) -> Callable[Concatenate[T, OAuthProviderApp, P], R | ResponseReturnValue]:
@wraps(view)
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
def decorated(
self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs
) -> R | ResponseReturnValue:
if not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")

View File

@@ -1,4 +1,6 @@
import base64
import json
from datetime import UTC, datetime, timedelta
from typing import Literal
from flask import request
@@ -9,6 +11,7 @@ from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@@ -84,3 +87,39 @@ class PartnerTenants(Resource):
raise BadRequest("Invalid partner information")
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
_DEBUG_KEY = "billing:debug"
_DEBUG_TTL = timedelta(days=7)
class DebugDataPayload(BaseModel):
type: str = Field(..., min_length=1, description="Data type key")
data: str = Field(..., min_length=1, description="Data value to append")
@console_ns.route("/billing/debug/data")
class DebugData(Resource):
def post(self):
body = DebugDataPayload.model_validate(request.get_json(force=True))
item = json.dumps({
"type": body.type,
"data": body.data,
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
})
redis_client.lpush(_DEBUG_KEY, item)
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
return {"result": "ok"}, 201
def get(self):
recent = request.args.get("recent", 10, type=int)
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
return {
"data": [
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
]
}
def delete(self):
redis_client.delete(_DEBUG_KEY)
return {"result": "ok"}

View File

@@ -173,8 +173,11 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
external_knowledge_api_id, current_tenant_id
)
if external_knowledge_api is None:
raise NotFound("API template not found.")

View File

@@ -1,6 +1,5 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy import select
@@ -9,13 +8,10 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.dataset import Pipeline
P = ParamSpec("P")
R = TypeVar("R")
def get_rag_pipeline(view_func: Callable[P, R]):
def get_rag_pipeline[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from typing import Concatenate
from flask import abort
from flask_restx import Resource
@@ -15,12 +15,8 @@ from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def installed_app_required[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
@@ -49,7 +45,7 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
return decorator
def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def user_allowed_to_access_app[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
@@ -73,7 +69,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def trial_app_required[**P, R](view: Callable[Concatenate[App, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, P], R]):
@wraps(view)
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
@@ -106,7 +102,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
return decorator
def trial_feature_enable(view: Callable[P, R]):
def trial_feature_enable[**P, R](view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
@@ -117,7 +113,7 @@ def trial_feature_enable(view: Callable[P, R]):
return decorated
def explore_banner_enabled(view: Callable[P, R]):
def explore_banner_enabled[**P, R](view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()

View File

@@ -1,6 +1,5 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
@@ -9,17 +8,14 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.account import TenantPluginPermission
P = ParamSpec("P")
R = TypeVar("R")
def plugin_permission_required(
install_required: bool = False,
debug_required: bool = False,
):
def interceptor(view: Callable[P, R]):
def interceptor[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
tenant_id = current_tenant_id

View File

@@ -4,7 +4,6 @@ import os
import time
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request
from sqlalchemy import select
@@ -25,9 +24,6 @@ from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
P = ParamSpec("P")
R = TypeVar("R")
# Field names for decryption
FIELD_NAME_PASSWORD = "password"
FIELD_NAME_CODE = "code"
@@ -37,7 +33,7 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check account initialization
@@ -50,7 +46,7 @@ def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
return decorated
def only_edition_cloud(view: Callable[P, R]):
def only_edition_cloud[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "CLOUD":
@@ -61,7 +57,7 @@ def only_edition_cloud(view: Callable[P, R]):
return decorated
def only_edition_enterprise(view: Callable[P, R]):
def only_edition_enterprise[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ENTERPRISE_ENABLED:
@@ -72,7 +68,7 @@ def only_edition_enterprise(view: Callable[P, R]):
return decorated
def only_edition_self_hosted(view: Callable[P, R]):
def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "SELF_HOSTED":
@@ -83,7 +79,7 @@ def only_edition_self_hosted(view: Callable[P, R]):
return decorated
def cloud_edition_billing_enabled(view: Callable[P, R]):
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
@@ -95,7 +91,7 @@ def cloud_edition_billing_enabled(view: Callable[P, R]):
return decorated
def cloud_edition_billing_resource_check(resource: str):
def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -137,7 +133,9 @@ def cloud_edition_billing_resource_check(resource: str):
return interceptor
def cloud_edition_billing_knowledge_limit_check(resource: str):
def cloud_edition_billing_knowledge_limit_check[**P, R](
resource: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -160,7 +158,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
return interceptor
def cloud_edition_billing_rate_limit_check(resource: str):
def cloud_edition_billing_rate_limit_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -196,7 +194,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
return interceptor
def cloud_utm_record(view: Callable[P, R]):
def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
@@ -215,7 +213,7 @@ def cloud_utm_record(view: Callable[P, R]):
return decorated
def setup_required(view: Callable[P, R]) -> Callable[P, R]:
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check setup
@@ -229,7 +227,7 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]:
return decorated
def enterprise_license_required(view: Callable[P, R]):
def enterprise_license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features()
@@ -241,7 +239,7 @@ def enterprise_license_required(view: Callable[P, R]):
return decorated
def email_password_login_enabled(view: Callable[P, R]):
def email_password_login_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
@@ -254,7 +252,7 @@ def email_password_login_enabled(view: Callable[P, R]):
return decorated
def email_register_enabled(view: Callable[P, R]):
def email_register_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
@@ -267,7 +265,7 @@ def email_register_enabled(view: Callable[P, R]):
return decorated
def enable_change_email(view: Callable[P, R]):
def enable_change_email[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
@@ -280,7 +278,7 @@ def enable_change_email(view: Callable[P, R]):
return decorated
def is_allow_transfer_owner(view: Callable[P, R]):
def is_allow_transfer_owner[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
from libs.workspace_permission import check_workspace_owner_transfer_permission
@@ -293,7 +291,7 @@ def is_allow_transfer_owner(view: Callable[P, R]):
return decorated
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
@@ -305,7 +303,7 @@ def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
return decorated
def edit_permission_required(f: Callable[P, R]):
def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
@@ -323,7 +321,7 @@ def edit_permission_required(f: Callable[P, R]):
return decorated_function
def is_admin_or_owner_required(f: Callable[P, R]):
def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
@@ -339,7 +337,7 @@ def is_admin_or_owner_required(f: Callable[P, R]):
return decorated_function
def annotation_import_rate_limit(view: Callable[P, R]):
def annotation_import_rate_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""
Rate limiting decorator for annotation import operations.
@@ -388,7 +386,7 @@ def annotation_import_rate_limit(view: Callable[P, R]):
return decorated
def annotation_import_concurrency_limit(view: Callable[P, R]):
def annotation_import_concurrency_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""
Concurrency control decorator for annotation import operations.
@@ -455,7 +453,7 @@ def _decrypt_field(field_name: str, error_class: type[Exception], error_message:
payload[field_name] = decoded_value
def decrypt_password_field(view: Callable[P, R]):
def decrypt_password_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""
Decorator to decrypt password field in request payload.
@@ -477,7 +475,7 @@ def decrypt_password_field(view: Callable[P, R]):
return decorated
def decrypt_code_field(view: Callable[P, R]):
def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""
Decorator to decrypt verification code field in request payload.

View File

@@ -1,6 +1,5 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import current_app, request
from flask_login import user_logged_in
@@ -13,9 +12,6 @@ from libs.login import current_user
from models.account import Tenant
from models.model import DefaultEndUserSessionID, EndUser
P = ParamSpec("P")
R = TypeVar("R")
class TenantUserPayload(BaseModel):
tenant_id: str
@@ -65,9 +61,9 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
return user_model
def get_user_tenant(view_func: Callable[P, R]):
def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
user_id = payload.user_id
@@ -97,10 +93,14 @@ def get_user_tenant(view_func: Callable[P, R]):
return decorated_view
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
def decorator(view_func: Callable[P, R]):
def plugin_data[**P, R](
view: Callable[P, R] | None = None,
*,
payload_type: type[BaseModel],
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
try:
data = request.get_json()
except Exception:

View File

@@ -3,10 +3,7 @@ from collections.abc import Callable
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
from flask import abort, request
from configs import dify_config
@@ -14,9 +11,9 @@ from extensions.ext_database import db
from models.model import EndUser
def billing_inner_api_only(view: Callable[P, R]):
def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.INNER_API:
abort(404)
@@ -30,9 +27,9 @@ def billing_inner_api_only(view: Callable[P, R]):
return decorated
def enterprise_inner_api_only(view: Callable[P, R]):
def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.INNER_API:
abort(404)
@@ -46,9 +43,9 @@ def enterprise_inner_api_only(view: Callable[P, R]):
return decorated
def enterprise_inner_api_user_auth(view: Callable[P, R]):
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.INNER_API:
return view(*args, **kwargs)
@@ -82,9 +79,9 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]):
return decorated
def plugin_inner_api_only(view: Callable[P, R]):
def plugin_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.PLUGIN_DAEMON_KEY:
abort(404)

View File

@@ -3,7 +3,7 @@ import time
from collections.abc import Callable
from enum import StrEnum, auto
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast, overload
from typing import Any, cast, overload
from flask import current_app, request
from flask_login import user_logged_in
@@ -23,10 +23,6 @@ from services.api_token_service import ApiTokenCache, fetch_token_with_single_fl
from services.end_user_service import EndUserService
from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
logger = logging.getLogger(__name__)
@@ -46,16 +42,16 @@ class FetchUserArg(BaseModel):
@overload
def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ...
def validate_app_token[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
@overload
def validate_app_token(
def validate_app_token[**P, R](
view: None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def validate_app_token(
def validate_app_token[**P, R](
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@@ -136,7 +132,10 @@ def validate_app_token(
return decorator(view)
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def cloud_edition_billing_resource_check[**P, R](
resource: str,
api_token_type: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]):
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type)
@@ -166,7 +165,10 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
return interceptor
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
def cloud_edition_billing_knowledge_limit_check[**P, R](
resource: str,
api_token_type: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -188,7 +190,10 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
return interceptor
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
def cloud_edition_billing_rate_limit_check[**P, R](
resource: str,
api_token_type: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@@ -225,20 +230,12 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
return interceptor
@overload
def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ...
@overload
def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ...
def validate_dataset_token(
view: Callable[Concatenate[T, P], R] | None = None,
) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]:
def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]:
view: Callable[..., Any] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(view_func)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: Any, **kwargs: Any) -> Any:
api_token = validate_and_get_api_token("dataset")
# get url path dataset_id from positional args or kwargs
@@ -308,7 +305,10 @@ def validate_dataset_token(
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
return view_func(api_token.tenant_id, *args, **kwargs) # type: ignore[arg-type]
if args and isinstance(args[0], Resource):
return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs)
return view_func(api_token.tenant_id, *args, **kwargs)
return decorated

View File

@@ -1,7 +1,7 @@
from collections.abc import Callable
from datetime import UTC, datetime
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from typing import Concatenate
from flask import request
from flask_restx import Resource
@@ -20,14 +20,13 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
P = ParamSpec("P")
R = TypeVar("R")
def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, EndUser, P], R]):
def validate_jwt_token[**P, R](
view: Callable[Concatenate[App, EndUser, P], R] | None = None,
) -> Callable[P, R] | Callable[[Callable[Concatenate[App, EndUser, P], R]], Callable[P, R]]:
def decorator(view: Callable[Concatenate[App, EndUser, P], R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
app_model, end_user = decode_jwt_token()
return view(app_model, end_user, *args, **kwargs)
@@ -38,7 +37,7 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None =
return decorator
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) -> tuple[App, EndUser]:
system_features = FeatureService.get_system_features()
if not app_code:
app_code = str(request.headers.get(HEADER_NAME_APP_CODE))

View File

@@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from typing import TYPE_CHECKING, Any, Literal, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -68,7 +68,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
@@ -81,7 +81,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
@@ -94,7 +94,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
@@ -106,7 +106,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
@@ -239,7 +239,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
conversation: Conversation,
message: Message,
application_generate_entity: AdvancedChatAppGenerateEntity,
@@ -271,9 +271,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
@@ -359,7 +359,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Account | EndUser,
args: LoopNodeRunPayload,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
@@ -439,7 +439,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
*,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
@@ -451,7 +451,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
pause_state_config: PauseStateLayerConfig | None = None,
graph_runtime_state: GraphRuntimeState | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
@@ -653,10 +653,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager,
conversation: ConversationSnapshot,
message: MessageSnapshot,
user: Union[Account, EndUser],
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from typing import Any, Literal, overload
from flask import Flask, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -37,7 +37,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
self,
*,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@@ -48,7 +48,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
self,
*,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@@ -59,21 +59,21 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
self,
*,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
) -> Mapping | Generator[Mapping | str, None, None]: ...
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
) -> Mapping | Generator[Mapping | str, None, None]:
"""
Generate App response.

View File

@@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from typing import Any, Literal, overload
from flask import Flask, copy_current_request_context, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -36,7 +36,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@@ -46,7 +46,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@@ -56,20 +56,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
"""
Generate App response.

View File

@@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from typing import Any, Literal, overload
from flask import Flask, copy_current_request_context, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -36,7 +36,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@@ -46,7 +46,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@@ -56,20 +56,20 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = False,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: ...
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
@@ -244,10 +244,10 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
message_id: str,
user: Union[Account, EndUser],
user: Account | EndUser,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
) -> Mapping | Generator[Mapping | str, None, None]:
"""
Generate App response.

View File

@@ -7,7 +7,7 @@ import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, cast, overload
from typing import Any, Literal, cast, overload
from flask import Flask, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -62,7 +62,7 @@ class PipelineGenerator(BaseAppGenerator):
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@@ -77,7 +77,7 @@ class PipelineGenerator(BaseAppGenerator):
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@@ -92,28 +92,28 @@ class PipelineGenerator(BaseAppGenerator):
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
) -> Mapping[str, Any] | Generator[Mapping | str, None, None]: ...
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: str | None = None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None:
# Add null check for dataset
with Session(db.engine, expire_on_commit=False) as session:
@@ -278,7 +278,7 @@ class PipelineGenerator(BaseAppGenerator):
context: contextvars.Context,
pipeline: Pipeline,
workflow_id: str,
user: Union[Account, EndUser],
user: Account | EndUser,
application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
@@ -286,7 +286,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
workflow_thread_pool_id: str | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
@@ -624,10 +624,10 @@ class PipelineGenerator(BaseAppGenerator):
application_generate_entity: RagPipelineGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
"""
Handle response.
:param application_generate_entity: application generate entity
@@ -668,7 +668,7 @@ class PipelineGenerator(BaseAppGenerator):
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
account: Union[Account, EndUser],
account: Account | EndUser,
batch: str,
document_form: str,
):
@@ -715,7 +715,7 @@ class PipelineGenerator(BaseAppGenerator):
pipeline: Pipeline,
workflow: Workflow,
start_node_id: str,
user: Union[Account, EndUser],
user: Account | EndUser,
) -> list[Mapping[str, Any]]:
"""
Format datasource info list.

View File

@@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from typing import TYPE_CHECKING, Any, Literal, overload
from flask import Flask, current_app
from graphon.graph_engine.layers import GraphEngineLayer
@@ -64,7 +64,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@@ -82,7 +82,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@@ -100,7 +100,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
@@ -110,14 +110,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
@@ -127,7 +127,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
files: Sequence[Mapping[str, Any]] = args.get("files") or []
@@ -237,7 +237,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
application_generate_entity: WorkflowAppGenerateEntity,
graph_runtime_state: GraphRuntimeState,
workflow_execution_repository: WorkflowExecutionRepository,
@@ -245,7 +245,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Resume a paused workflow execution using the persisted runtime state.
"""
@@ -269,7 +269,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
user: Account | EndUser,
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
@@ -280,7 +280,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
@@ -609,10 +609,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Annotated, Literal, Self, TypeAlias
from typing import Annotated, Literal, Self
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
@@ -27,7 +27,7 @@ class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
entity: AdvancedChatAppGenerateEntity
_GenerateEntityUnion: TypeAlias = Annotated[
type _GenerateEntityUnion = Annotated[
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
Field(discriminator="type"),
]

View File

@@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Any, Union, cast
from typing import Any, cast
from graphon.file import FileTransferMethod
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -72,14 +72,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
_application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity
_precomputed_event_type: StreamEvent | None = None
def __init__(
self,
application_generate_entity: Union[
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
],
application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
@@ -117,11 +115,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
def process(
self,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]:
) -> (
ChatbotAppBlockingResponse
| CompletionAppBlockingResponse
| Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]
):
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
@@ -136,7 +134,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
) -> ChatbotAppBlockingResponse | CompletionAppBlockingResponse:
"""
Process blocking response.
:return:
@@ -148,7 +146,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
response: ChatbotAppBlockingResponse | CompletionAppBlockingResponse
if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@@ -183,7 +181,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
) -> Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]:
"""
To stream response.
:return:

View File

@@ -5,14 +5,13 @@ This layer centralizes model-quota deduction outside node implementations.
"""
import logging
from typing import TYPE_CHECKING, cast, final
from typing import TYPE_CHECKING, cast, final, override
from graphon.enums import BuiltinNodeTypes
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
from graphon.nodes.base.node import Node
from typing_extensions import override
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available

View File

@@ -10,7 +10,7 @@ associates with the node span.
import logging
from contextvars import Token
from dataclasses import dataclass
from typing import cast, final
from typing import cast, final, override
from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.graph_engine.layers import GraphEngineLayer
@@ -18,7 +18,6 @@ from graphon.graph_events import GraphNodeEventBase
from graphon.nodes.base.node import Node
from opentelemetry import context as context_api
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
from typing_extensions import override
from configs import dify_config
from extensions.otel.parser import (

View File

@@ -44,7 +44,8 @@ class HumanInputContent(BaseModel):
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
# Keep a runtime alias here: callers and tests expect identity with HumanInputContent.
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent # noqa: UP040
__all__ = [
"ExecutionExtraContentDomainModel",

View File

@@ -2,12 +2,13 @@ import importlib.util
import logging
import sys
from types import ModuleType
from typing import AnyStr
logger = logging.getLogger(__name__)
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
def import_module_from_source[T: (str, bytes)](
*, module_name: str, py_file_path: T, use_lazy_loader: bool = False
) -> ModuleType:
"""
Importing a module from the source file directly
"""

View File

@@ -2,7 +2,6 @@ import os
from collections import OrderedDict
from collections.abc import Callable
from functools import lru_cache
from typing import TypeVar
from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file_cached
@@ -65,10 +64,7 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str])
return position_map
T = TypeVar("T")
def is_filtered(
def is_filtered[T](
include_set: set[str],
exclude_set: set[str],
data: T,
@@ -97,11 +93,11 @@ def is_filtered(
return False
def sort_by_position_map(
def sort_by_position_map[T](
position_map: dict[str, int],
data: list[T],
name_func: Callable[[T], str],
):
) -> list[T]:
"""
Sort the objects by the position map.
If the name of the object is not in the position map, it will be put at the end.
@@ -116,11 +112,11 @@ def sort_by_position_map(
return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf")))
def sort_to_dict_by_position_map(
def sort_to_dict_by_position_map[T](
position_map: dict[str, int],
data: list[T],
name_func: Callable[[T], str],
):
) -> OrderedDict[str, T]:
"""
Sort the objects into a ordered dict by the position map.
If the name of the object is not in the position map, it will be put at the end.

View File

@@ -4,7 +4,7 @@ Proxy requests to avoid SSRF
import logging
import time
from typing import Any, TypeAlias
from typing import Any
import httpx
from pydantic import TypeAdapter, ValidationError
@@ -20,8 +20,8 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
Headers: TypeAlias = dict[str, str]
_HEADERS_ADAPTER = TypeAdapter(Headers)
type Headers = dict[str, str]
_HEADERS_ADAPTER: TypeAdapter[Headers] = TypeAdapter(Headers)
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"

View File

@@ -3,7 +3,7 @@ import queue
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, TypeAlias, final
from typing import Any, final
from urllib.parse import urljoin, urlparse
import httpx
@@ -33,9 +33,9 @@ class _StatusError:
# Type aliases for better readability
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
type ReadQueue = queue.Queue[SessionMessage | Exception | None]
type WriteQueue = queue.Queue[SessionMessage | Exception | None]
type StatusQueue = queue.Queue[_StatusReady | _StatusError]
class SSETransport:

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Generic, TypeVar
from typing import Any, TypeVar
from pydantic import BaseModel
@@ -9,13 +9,12 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
@dataclass
class RequestContext(Generic[SessionT, LifespanContextT]):
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT

View File

@@ -4,7 +4,7 @@ from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, Self, TypeVar
from typing import Any, Self, cast
from httpx import HTTPStatusError
from pydantic import BaseModel
@@ -34,16 +34,10 @@ from core.mcp.types import (
logger = logging.getLogger(__name__)
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResultT: ClientResult | ServerResult]:
"""Handles responding to MCP requests and manages request lifecycle.
This class MUST be used as a context manager to ensure proper cleanup and
@@ -60,7 +54,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""
request: ReceiveRequestT
_session: Any
_session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
def __init__(
@@ -68,7 +62,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
):
self.request_id = request_id
@@ -111,7 +105,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self.completed = True
self._session._send_response(request_id=self.request_id, response=response)
self._session.send_response(request_id=self.request_id, response=response)
def cancel(self):
"""Cancel this request and mark it as completed."""
@@ -120,21 +114,19 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self.completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
self._session._send_response(
self._session.send_response(
request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None),
)
class BaseSession(
Generic[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
],
):
class BaseSession[
SendRequestT: ClientRequest | ServerRequest,
SendNotificationT: ClientNotification | ServerNotification,
SendResultT: ClientResult | ServerResult,
ReceiveRequestT: ClientRequest | ServerRequest,
ReceiveNotificationT: ClientNotification | ServerNotification,
]:
"""
Implements an MCP "session" on top of read/write streams, including features
like request/response linking, notifications, and progress.
@@ -204,13 +196,13 @@ class BaseSession(
# The receiver thread should have already exited due to the None message in the queue
self._executor.shutdown(wait=False)
def send_request(
def send_request[T: BaseModel](
self,
request: SendRequestT,
result_type: type[ReceiveResultT],
result_type: type[T],
request_read_timeout_seconds: timedelta | None = None,
metadata: MessageMetadata | None = None,
) -> ReceiveResultT:
) -> T:
"""
Sends a request and wait for a response. Raises an McpError if the
response contains an error. If a request read timeout is provided, it
@@ -299,7 +291,7 @@ class BaseSession(
)
self._write_stream.put(session_message)
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
def send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
@@ -346,6 +338,7 @@ class BaseSession(
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
validated_request = cast(ReceiveRequestT, validated_request)
responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=message.message.root.id,
@@ -366,6 +359,7 @@ class BaseSession(
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
notification = cast(ReceiveNotificationT, notification)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable
from dataclasses import dataclass
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
from typing import Annotated, Any, Literal
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
from pydantic.networks import AnyUrl, UrlConstraints
@@ -31,7 +31,7 @@ ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
AnyFunction: TypeAlias = Callable[..., Any]
type AnyFunction = Callable[..., Any]
class RequestParams(BaseModel):
@@ -68,12 +68,7 @@ class NotificationParams(BaseModel):
"""
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None)
MethodT = TypeVar("MethodT", bound=str)
class Request(BaseModel, Generic[RequestParamsT, MethodT]):
class Request[RequestParamsT: RequestParams | dict[str, Any] | None, MethodT: str](BaseModel):
"""Base class for JSON-RPC requests."""
method: MethodT
@@ -81,14 +76,14 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow")
class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
class PaginatedRequest[T: str](Request[PaginatedRequestParams | None, T]):
"""Base class for paginated requests,
matching the schema's PaginatedRequest interface."""
params: PaginatedRequestParams | None = None
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
class Notification[NotificationParamsT: NotificationParams | dict[str, Any] | None, MethodT: str](BaseModel):
"""Base class for JSON-RPC notifications."""
method: MethodT
@@ -736,7 +731,7 @@ class ResourceLink(Resource):
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
"""A content block that can be used in prompts and tool results."""
Content: TypeAlias = ContentBlock
type Content = ContentBlock
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""

View File

@@ -6,7 +6,7 @@ import queue
import threading
import time
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, TypedDict
from uuid import UUID, uuid4
from cachetools import LRUCache
@@ -14,7 +14,6 @@ from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import TypedDict
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
@@ -464,7 +463,7 @@ class OpsTraceManager:
@classmethod
def get_ops_trace_instance(
cls,
app_id: Union[UUID, str] | None = None,
app_id: UUID | str | None = None,
):
"""
Get ops trace through model config
@@ -717,7 +716,7 @@ class TraceTask:
self,
trace_type: Any,
message_id: str | None = None,
workflow_execution: Optional["WorkflowExecution"] = None,
workflow_execution: "WorkflowExecution | None" = None,
conversation_id: str | None = None,
user_id: str | None = None,
timer: Any | None = None,

View File

@@ -1,5 +1,4 @@
from collections.abc import Generator, Mapping
from typing import Generic, TypeVar
from pydantic import BaseModel
@@ -19,9 +18,6 @@ class BaseBackwardsInvocation:
yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode()
T = TypeVar("T", bound=dict | Mapping | str | bool | int | BaseModel)
class BaseBackwardsInvocationResponse(BaseModel, Generic[T]):
class BaseBackwardsInvocationResponse[T: dict | Mapping | str | bool | int | BaseModel](BaseModel):
data: T | None = None
error: str = ""

View File

@@ -4,7 +4,7 @@ import enum
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import Any, Generic, TypeVar
from typing import Any
from graphon.model_runtime.entities.model_entities import AIModelEntity
from graphon.model_runtime.entities.provider_entities import ProviderEntity
@@ -19,10 +19,8 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
from core.trigger.entities.entities import TriggerProviderEntity
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
class PluginDaemonBasicResponse(BaseModel, Generic[T]):
class PluginDaemonBasicResponse[T: BaseModel | dict | list | bool | str](BaseModel):
"""
Basic response from plugin daemon.
"""

View File

@@ -2,7 +2,7 @@ import inspect
import json
import logging
from collections.abc import Callable, Generator
from typing import Any, TypeVar, cast
from typing import Any, cast
import httpx
from graphon.model_runtime.errors.invoke import (
@@ -51,8 +51,6 @@ elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
else:
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
T = TypeVar("T", bound=(BaseModel | dict[str, Any] | list[Any] | bool | str))
logger = logging.getLogger(__name__)
_httpx_client: httpx.Client = get_pooled_http_client(
@@ -191,7 +189,7 @@ class BasePluginClient:
logger.exception("Stream request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
def _stream_request_with_model(
def _stream_request_with_model[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
self,
method: str,
path: str,
@@ -207,7 +205,7 @@ class BasePluginClient:
for line in self._stream_request(method, path, params, headers, data, files):
yield type_(**json.loads(line)) # type: ignore
def _request_with_model(
def _request_with_model[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
self,
method: str,
path: str,
@@ -223,7 +221,7 @@ class BasePluginClient:
response = self._request(method, path, headers, data, params, files)
return type_(**response.json()) # type: ignore[return-value]
def _request_with_plugin_daemon_response(
def _request_with_plugin_daemon_response[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
self,
method: str,
path: str,
@@ -278,7 +276,7 @@ class BasePluginClient:
return rep.data
def _request_with_plugin_daemon_response_stream(
def _request_with_plugin_daemon_response_stream[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
self,
method: str,
path: str,

View File

@@ -1,12 +1,9 @@
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import TypeVar, Union
from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage])
@dataclass
class FileChunk:
@@ -22,11 +19,11 @@ class FileChunk:
self.data = bytearray(self.total_length)
def merge_blob_chunks(
response: Generator[MessageType, None, None],
def merge_blob_chunks[T: ToolInvokeMessage | AgentInvokeMessage](
response: Generator[T, None, None],
max_file_size: int = 30 * 1024 * 1024,
max_chunk_size: int = 8192,
) -> Generator[MessageType, None, None]:
) -> Generator[T, None, None]:
"""
Merge streaming blob chunks into complete blob messages.

View File

@@ -1,6 +1,7 @@
from typing import TypedDict
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
from typing_extensions import TypedDict
from core.model_manager import ModelInstance, ModelManager
from core.rag.data_post_processor.reorder import ReorderRunner

View File

@@ -1,10 +1,9 @@
from collections import defaultdict
from typing import Any
from typing import Any, TypedDict
import orjson
from pydantic import BaseModel
from sqlalchemy import select
from typing_extensions import TypedDict
from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler

View File

@@ -1,13 +1,12 @@
import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, NotRequired
from typing import Any, NotRequired, TypedDict
from flask import Flask, current_app
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from typing_extensions import TypedDict
from configs import dify_config
from core.db.session_factory import session_factory

View File

@@ -3,7 +3,7 @@ import logging
import uuid
from collections.abc import Callable
from functools import wraps
from typing import Any, Concatenate, ParamSpec, TypeVar
from typing import Any, Concatenate
from mo_vector.client import MoVectorClient # type: ignore
from pydantic import BaseModel, model_validator
@@ -20,15 +20,12 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T", bound="MatrixoneVector")
def ensure_client(func: Callable[Concatenate[T, P], R]):
def ensure_client[T: MatrixoneVector, **P, R](
func: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(func)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)

View File

@@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any
import qdrant_client
from flask import current_app
@@ -36,8 +36,8 @@ if TYPE_CHECKING:
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
type DictFilter = dict[str, str | int | bool | dict | list]
type MetadataFilter = DictFilter | common_types.Filter
class PathQdrantParams(BaseModel):

View File

@@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any
import httpx
import qdrant_client
@@ -40,8 +40,8 @@ if TYPE_CHECKING:
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
type DictFilter = dict[str, str | int | bool | dict | list]
type MetadataFilter = DictFilter | common_types.Filter
class TidbOnQdrantConfig(BaseModel):

View File

@@ -1,5 +1,6 @@
from typing import TypedDict
from pydantic import BaseModel
from typing_extensions import TypedDict
from models.dataset import DocumentSegment

View File

@@ -12,11 +12,11 @@ import mimetypes
from collections.abc import Generator, Mapping
from io import BufferedReader, BytesIO
from pathlib import Path, PurePath
from typing import Any, Union
from typing import Any
from pydantic import BaseModel, ConfigDict, model_validator
PathLike = Union[str, PurePath]
type PathLike = str | PurePath
class Blob(BaseModel):
@@ -29,7 +29,7 @@ class Blob(BaseModel):
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
"""
data: Union[bytes, str, None] = None # Raw data
data: bytes | str | None = None # Raw data
mimetype: str | None = None # Not to be confused with a file extension
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
# Location where the original content was found
@@ -75,7 +75,7 @@ class Blob(BaseModel):
raise ValueError(f"Unable to get bytes for blob {self}")
@contextlib.contextmanager
def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]:
def as_bytes_io(self) -> Generator[BytesIO | BufferedReader, None, None]:
"""Read data as a byte stream."""
if isinstance(self.data, bytes):
yield BytesIO(self.data)
@@ -117,7 +117,7 @@ class Blob(BaseModel):
@classmethod
def from_data(
cls,
data: Union[str, bytes],
data: str | bytes,
*,
encoding: str = "utf-8",
mime_type: str | None = None,

View File

@@ -1,9 +1,8 @@
import json
import time
from typing import Any, NotRequired, cast
from typing import Any, NotRequired, TypedDict, cast
import httpx
from typing_extensions import TypedDict
from extensions.ext_storage import storage

View File

@@ -1,11 +1,10 @@
import json
from collections.abc import Generator
from typing import Any, Union
from typing import Any, TypedDict
from urllib.parse import urljoin
import httpx
from httpx import Response
from typing_extensions import TypedDict
from core.rag.extractor.watercrawl.exceptions import (
WaterCrawlAuthenticationError,
@@ -142,7 +141,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
def create_crawl_request(
self,
url: Union[list, str] | None = None,
url: list | str | None = None,
spider_options: SpiderOptions | None = None,
page_options: PageOptions | None = None,
plugin_options: dict[str, Any] | None = None,

View File

@@ -1,8 +1,6 @@
from collections.abc import Generator
from datetime import datetime
from typing import Any
from typing_extensions import TypedDict
from typing import Any, TypedDict
from core.rag.extractor.watercrawl.client import PageOptions, SpiderOptions, WaterCrawlAPIClient

View File

@@ -7,12 +7,11 @@ import os
import re
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, NotRequired, Optional
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
from urllib.parse import unquote, urlparse
import httpx
from sqlalchemy import select
from typing_extensions import TypedDict
from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
@@ -118,11 +117,12 @@ class BaseIndexProcessor(ABC):
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional["ModelInstance"],
embedding_model_instance: "ModelInstance | None",
) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule_mode in ["custom", "hierarchical"]:
# The user-defined segmentation rule
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
@@ -148,7 +148,7 @@ class BaseIndexProcessor(ABC):
embedding_model_instance=embedding_model_instance,
)
return character_splitter # type: ignore
return character_splitter
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
"""

View File

@@ -4,19 +4,13 @@ from __future__ import annotations
import codecs
import re
from typing import Any
from collections.abc import Collection
from typing import Any, Literal
from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
from core.model_manager import ModelInstance
from core.rag.splitter.text_splitter import (
TS,
Collection,
Literal,
RecursiveCharacterTextSplitter,
Set,
Union,
)
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
@@ -25,13 +19,13 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
"""
@classmethod
def from_encoder(
cls: type[TS],
def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
cls: type[T],
embedding_model_instance: ModelInstance | None,
allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
allowed_special: Literal["all"] | set[str] = set(),
disallowed_special: Literal["all"] | Collection[str] = "all",
**kwargs: Any,
):
) -> T:
def _token_encoder(texts: list[str]) -> list[int]:
if not texts:
return []

View File

@@ -6,19 +6,12 @@ import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from typing import (
Any,
Literal,
TypeVar,
Union,
)
from typing import Any, Literal
from core.rag.models.document import BaseDocumentTransformer, Document
logger = logging.getLogger(__name__)
TS = TypeVar("TS", bound="TextSplitter")
def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
# Now that we have the separator, split the text
@@ -194,8 +187,8 @@ class TokenTextSplitter(TextSplitter):
self,
encoding_name: str = "gpt2",
model_name: str | None = None,
allowed_special: Union[Literal["all"], Set[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
allowed_special: Literal["all"] | Set[str] = set(),
disallowed_special: Literal["all"] | Collection[str] = "all",
**kwargs: Any,
):
"""Create a new TextSplitter."""

View File

@@ -6,7 +6,6 @@ providing improved performance by offloading database operations to background w
"""
import logging
from typing import Union
from graphon.entities import WorkflowExecution
from sqlalchemy.engine import Engine
@@ -47,7 +46,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
user: Account | EndUser,
app_id: str | None,
triggered_from: WorkflowRunTriggeredFrom | None,
):

View File

@@ -7,7 +7,6 @@ providing improved performance by offloading database operations to background w
import logging
from collections.abc import Sequence
from typing import Union
from graphon.entities import WorkflowNodeExecution
from sqlalchemy.engine import Engine
@@ -54,7 +53,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
user: Account | EndUser,
app_id: str | None,
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
):

View File

@@ -7,7 +7,7 @@ allowing users to configure different repository backends through string paths.
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Protocol, Union
from typing import Literal, Protocol
from graphon.entities import WorkflowExecution, WorkflowNodeExecution
from sqlalchemy.engine import Engine
@@ -61,8 +61,8 @@ class DifyCoreRepositoryFactory:
@classmethod
def create_workflow_execution_repository(
cls,
session_factory: Union[sessionmaker, Engine],
user: Union[Account, EndUser],
session_factory: sessionmaker | Engine,
user: Account | EndUser,
app_id: str,
triggered_from: WorkflowRunTriggeredFrom,
) -> WorkflowExecutionRepository:
@@ -97,8 +97,8 @@ class DifyCoreRepositoryFactory:
@classmethod
def create_workflow_node_execution_repository(
cls,
session_factory: Union[sessionmaker, Engine],
user: Union[Account, EndUser],
session_factory: sessionmaker | Engine,
user: Account | EndUser,
app_id: str,
triggered_from: WorkflowNodeExecutionTriggeredFrom,
) -> WorkflowNodeExecutionRepository:

View File

@@ -4,7 +4,6 @@ SQLAlchemy implementation of the WorkflowExecutionRepository.
import json
import logging
from typing import Union
from graphon.entities import WorkflowExecution
from graphon.enums import WorkflowExecutionStatus, WorkflowType
@@ -40,7 +39,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
user: Account | EndUser,
app_id: str | None,
triggered_from: WorkflowRunTriggeredFrom | None,
):

View File

@@ -7,7 +7,7 @@ import json
import logging
from collections.abc import Callable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any, TypeVar, Union
from typing import Any
import psycopg2.errors
from graphon.entities import WorkflowNodeExecution
@@ -63,7 +63,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
user: Account | EndUser,
app_id: str | None,
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
):
@@ -551,10 +551,7 @@ def _deterministic_json_dump(value: Mapping[str, Any]) -> str:
return json.dumps(value, sort_keys=True)
_T = TypeVar("_T")
def _find_first(seq: Sequence[_T], pred: Callable[[_T], bool]) -> _T | None:
def _find_first[T](seq: Sequence[T], pred: Callable[[T], bool]) -> T | None:
filtered = [i for i in seq if pred(i)]
if filtered:
return filtered[0]

View File

@@ -3,15 +3,15 @@ import re
import threading
from collections import deque
from dataclasses import dataclass
from typing import Any, Union
from typing import Any
from core.schemas.registry import SchemaRegistry
logger = logging.getLogger(__name__)
# Type aliases for better clarity
SchemaType = Union[dict[str, Any], list[Any], str, int, float, bool, None]
SchemaDict = dict[str, Any]
type SchemaType = dict[str, Any] | list[Any] | str | int | float | bool | None
type SchemaDict = dict[str, Any]
# Pre-compiled pattern for better performance
_DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$")
@@ -54,7 +54,7 @@ class QueueItem:
current: Any
parent: Any | None
key: Union[str, int] | None
key: str | int | None
depth: int
ref_path: set[str]

View File

@@ -1,6 +1,5 @@
import hashlib
import logging
from typing import TypeVar
from redis import RedisError
@@ -11,8 +10,6 @@ logger = logging.getLogger(__name__)
TRIGGER_DEBUG_EVENT_TTL = 300
TTriggerDebugEvent = TypeVar("TTriggerDebugEvent", bound="BaseDebugEvent")
class TriggerDebugEventBus:
"""
@@ -81,15 +78,15 @@ class TriggerDebugEventBus:
return 0
@classmethod
def poll(
def poll[T: BaseDebugEvent](
cls,
event_type: type[TTriggerDebugEvent],
event_type: type[T],
pool_key: str,
tenant_id: str,
user_id: str,
app_id: str,
node_id: str,
) -> TTriggerDebugEvent | None:
) -> T | None:
"""
Poll for an event or register to the waiting pool.

View File

@@ -2,7 +2,7 @@ import importlib
import pkgutil
from collections.abc import Callable, Iterator, Mapping, MutableMapping
from functools import lru_cache
from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
from typing import TYPE_CHECKING, Any, cast, final, override
from graphon.entities.base_node_data import BaseNodeData
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
@@ -22,7 +22,6 @@ from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeDat
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
@@ -192,7 +191,7 @@ class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Nod
NODE_TYPE_CLASSES_MAPPING: MutableMapping[NodeType, Mapping[str, type[Node]]] = _LazyNodeTypeClassesMapping()
LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
type LLMCompatibleNodeData = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
def fetch_memory(

View File

@@ -1,56 +1,17 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota consumption operation.
Attributes:
success: Whether the quota charge succeeded
charge_id: UUID for refund, or None if failed/disabled
"""
success: bool
charge_id: str | None
_quota_type: "QuotaType"
def refund(self) -> None:
"""
Refund this quota charge.
Safe to call even if charge failed or was disabled.
This method guarantees no exceptions will be raised.
"""
if self.charge_id:
self._quota_type.refund(self.charge_id)
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
class QuotaType(StrEnum):
"""
Supported quota types for tenant feature usage.
Add additional types here whenever new billable features become available.
"""
# Trigger execution quota
TRIGGER = auto()
# Workflow execution quota
WORKFLOW = auto()
UNLIMITED = auto()
@property
def billing_key(self) -> str:
"""
Get the billing key for the feature.
"""
match self:
case QuotaType.TRIGGER:
return "trigger_event"
@@ -58,152 +19,3 @@ class QuotaType(StrEnum):
return "api_rate_limit"
case _:
raise ValueError(f"Invalid quota type: {self}")
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Consume quota for the feature.
Args:
tenant_id: The tenant identifier
amount: Amount to consume (default: 1)
Returns:
QuotaCharge with success status and charge_id for refund
Raises:
QuotaExceededError: When quota is insufficient
"""
from configs import dify_config
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to consume must be greater than 0")
try:
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
if response.get("result") != "success":
logger.warning(
"Failed to consume quota for %s, feature %s details: %s",
tenant_id,
self.value,
response.get("detail"),
)
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
charge_id = response.get("history_id")
logger.debug(
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
amount,
self.value,
tenant_id,
charge_id,
)
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
except QuotaExceededError:
raise
except Exception:
# fail-safe: allow request on billing errors
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
return unlimited()
def check(self, tenant_id: str, amount: int = 1) -> bool:
"""
Check if tenant has sufficient quota without consuming.
Args:
tenant_id: The tenant identifier
amount: Amount to check (default: 1)
Returns:
True if quota is sufficient, False otherwise
"""
from configs import dify_config
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = self.get_remaining(tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
# fail-safe: allow request on billing errors
return True
def refund(self, charge_id: str) -> None:
"""
Refund quota using charge_id from consume().
This method guarantees no exceptions will be raised.
All errors are logged but silently handled.
Args:
charge_id: The UUID returned from consume()
"""
try:
from configs import dify_config
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not charge_id:
logger.warning("Cannot refund: charge_id is empty")
return
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
if response.get("result") == "success":
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
else:
logger.warning("Refund failed for charge_id: %s", charge_id)
except Exception:
# Catch ALL exceptions - refund must never fail
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
# Don't raise - refund is best-effort and must be silent
def get_remaining(self, tenant_id: str) -> int:
"""
Get remaining quota for the tenant.
Args:
tenant_id: The tenant identifier
Returns:
Remaining quota amount
"""
from services.billing_service import BillingService
try:
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
if isinstance(usage_info, dict):
return usage_info.get("remaining", 0)
# If it returns a simple number, treat it as remaining
return int(usage_info) if usage_info else 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
return -1
def unlimited() -> QuotaCharge:
"""
Return a quota charge for unlimited quota.
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
"""
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)

View File

@@ -3,7 +3,7 @@ import logging
import ssl
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union
from typing import TYPE_CHECKING, Any, Union
import redis
from redis import RedisError
@@ -297,12 +297,7 @@ def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
return RedisBroadcastChannel(_pubsub_redis_client)
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def redis_fallback(default_return: T | None = None): # type: ignore
def redis_fallback[T](default_return: T | None = None): # type: ignore
"""
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
@@ -310,9 +305,9 @@ def redis_fallback(default_return: T | None = None): # type: ignore
default_return: The value to return when a Redis operation fails. Defaults to None.
"""
def decorator(func: Callable[P, R]):
def decorator[**P, R](func: Callable[P, R]) -> Callable[P, R | T | None]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | T | None:
try:
return func(*args, **kwargs)
except RedisError as e:

View File

@@ -2,7 +2,6 @@ import json
import logging
import os
import time
from typing import Union
from graphon.entities import WorkflowExecution
from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
@@ -27,7 +26,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
user: Account | EndUser,
app_id: str | None,
triggered_from: WorkflowRunTriggeredFrom | None,
):

View File

@@ -11,7 +11,7 @@ import os
import time
from collections.abc import Sequence
from datetime import datetime
from typing import Any, Union
from typing import Any
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -109,7 +109,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
user: Account | EndUser,
app_id: str | None,
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
):

View File

@@ -1,6 +1,6 @@
import functools
from collections.abc import Callable
from typing import ParamSpec, TypeVar, cast
from typing import cast
from opentelemetry.trace import get_tracer
@@ -8,9 +8,6 @@ from configs import dify_config
from extensions.otel.decorators.handler import SpanHandler
from extensions.otel.runtime import is_instrument_flag_enabled
P = ParamSpec("P")
R = TypeVar("R")
_HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()}
@@ -21,7 +18,7 @@ def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler:
return _HANDLER_INSTANCES[handler_class]
def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[Callable[P, R]], Callable[P, R]]:
def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator that traces a function with an OpenTelemetry span.

View File

@@ -1,11 +1,9 @@
import inspect
from collections.abc import Callable, Mapping
from typing import Any, TypeVar
from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode
R = TypeVar("R")
class SpanHandler:
"""
@@ -31,9 +29,9 @@ class SpanHandler:
"""
return f"{wrapped.__module__}.{wrapped.__qualname__}"
def _extract_arguments(
def _extract_arguments[T](
self,
wrapped: Callable[..., R],
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> dict[str, Any] | None:
@@ -61,13 +59,13 @@ class SpanHandler:
except Exception:
return None
def wrapper(
def wrapper[T](
self,
tracer: Any,
wrapped: Callable[..., R],
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> R:
) -> T:
"""
Fully control the wrapper behavior.

View File

@@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any, TypeVar
from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.util.types import AttributeValue
@@ -12,19 +12,16 @@ from models.model import Account
logger = logging.getLogger(__name__)
R = TypeVar("R")
class AppGenerateHandler(SpanHandler):
"""Span handler for ``AppGenerateService.generate``."""
def wrapper(
def wrapper[T](
self,
tracer: Any,
wrapped: Callable[..., R],
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> R:
) -> T:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
if not arguments:

View File

@@ -1,12 +1,12 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, TypeAlias
from typing import Any
from graphon.file import File
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
JSONValue: TypeAlias = Any
type JSONValue = Any
class ResponseModel(BaseModel):

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from datetime import datetime
from typing import TypeAlias
from uuid import uuid4
from graphon.file import File
@@ -10,7 +9,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
JSONValueType: TypeAlias = JSONValue
type JSONValueType = JSONValue
class ResponseModel(BaseModel):

View File

@@ -1,12 +1,10 @@
import contextvars
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING
from flask import Flask, g
T = TypeVar("T")
if TYPE_CHECKING:
from models import Account, EndUser

View File

@@ -42,13 +42,7 @@ def current_account_with_tenant():
return user, user.current_tenant_id
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
"""
If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are

View File

@@ -1,26 +1,20 @@
import logging
import sys
import urllib.parse
from dataclasses import dataclass
from typing import NotRequired
from typing import NotRequired, TypedDict
import httpx
from pydantic import TypeAdapter, ValidationError
from core.helper.http_client_pooling import get_pooled_http_client
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
logger = logging.getLogger(__name__)
JsonObject = dict[str, object]
JsonObjectList = list[JsonObject]
type JsonObject = dict[str, object]
type JsonObjectList = list[JsonObject]
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
JSON_OBJECT_ADAPTER: TypeAdapter[JsonObject] = TypeAdapter(JsonObject)
JSON_OBJECT_LIST_ADAPTER: TypeAdapter[JsonObjectList] = TypeAdapter(JsonObjectList)
# Reuse a pooled httpx.Client for OAuth flows (public endpoints, no SSRF proxy).
_http_client: httpx.Client = get_pooled_http_client(

View File

@@ -1,6 +1,5 @@
import sys
import urllib.parse
from typing import Any, Literal
from typing import Any, Literal, TypedDict
import httpx
from flask_login import current_user
@@ -12,11 +11,6 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
class NotionPageSummary(TypedDict):
page_id: str

View File

@@ -113,6 +113,7 @@ class DataSourceType(StrEnum):
WEBSITE_CRAWL = "website_crawl"
LOCAL_FILE = "local_file"
ONLINE_DOCUMENT = "online_document"
ONLINE_DRIVE = "online_drive"
class ProcessRuleMode(StrEnum):

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, cast
from uuid import uuid4
import sqlalchemy as sa
@@ -19,7 +19,6 @@ from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from graphon.file import helpers as file_helpers
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import TypedDict
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS

View File

@@ -1,6 +1,6 @@
import enum
import uuid
from typing import Any, Generic, TypeVar
from typing import Any
import sqlalchemy as sa
from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator
@@ -110,17 +110,14 @@ class AdjustedJSON(TypeDecorator[dict | list | None]):
return value
_E = TypeVar("_E", bound=enum.StrEnum)
class EnumText(TypeDecorator[_E | None], Generic[_E]):
class EnumText[T: enum.StrEnum](TypeDecorator[T | None]):
impl = VARCHAR
cache_ok = True
_length: int
_enum_class: type[_E]
_enum_class: type[T]
def __init__(self, enum_class: type[_E], length: int | None = None):
def __init__(self, enum_class: type[T], length: int | None = None):
self._enum_class = enum_class
max_enum_value_len = max(len(e.value) for e in enum_class)
if length is not None:
@@ -131,25 +128,25 @@ class EnumText(TypeDecorator[_E | None], Generic[_E]):
# leave some rooms for future longer enum values.
self._length = max(max_enum_value_len, 20)
def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None:
def process_bind_param(self, value: T | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
if isinstance(value, self._enum_class):
return value.value
# Since _E is bound to StrEnum which inherits from str, at this point value must be str
# Since T is bound to StrEnum which inherits from str, at this point value must be str
self._enum_class(value)
return value
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
return dialect.type_descriptor(VARCHAR(self._length))
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
def process_result_value(self, value: str | None, dialect: Dialect) -> T | None:
if value is None or value == "":
return None
# Type annotation guarantees value is str at this point
return self._enum_class(value)
def compare_values(self, x: _E | None, y: _E | None) -> bool:
def compare_values(self, x: T | None, y: T | None) -> bool:
if x is None or y is None:
return x is y
return x == y

View File

@@ -1,7 +1,7 @@
[project]
name = "dify-api"
version = "1.13.3"
requires-python = ">=3.11,<3.13"
requires-python = "~=3.12.0"
dependencies = [
"aliyun-log-python-sdk~=0.9.37",
@@ -232,5 +232,5 @@ vdb = [
project-includes = ["."]
project-excludes = [".venv", "migrations/"]
python-platform = "linux"
python-version = "3.11.0"
python-version = "3.12.0"
infer-with-first-use = false

View File

@@ -50,6 +50,6 @@
"reportUntypedFunctionDecorator": "hint",
"reportUnnecessaryTypeIgnoreComment": "hint",
"reportAttributeAccessIssue": "hint",
"pythonVersion": "3.11",
"pythonVersion": "3.12",
"pythonPlatform": "All"
}
}

View File

@@ -5,12 +5,11 @@ import secrets
import uuid
from datetime import UTC, datetime, timedelta
from hashlib import sha256
from typing import Any, cast
from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from typing_extensions import TypedDict
class InvitationData(TypedDict):

View File

@@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
@@ -106,7 +107,7 @@ class AppGenerateService:
quota_charge = unlimited()
if dify_config.BILLING_ENABLED:
try:
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
except QuotaExceededError:
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
@@ -116,6 +117,7 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
quota_charge.commit()
if app_model.mode == AppMode.COMPLETION:
return rate_limit.generate(
CompletionAppGenerator.convert_to_event_stream(

View File

@@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@@ -131,9 +132,10 @@ class AsyncWorkflowService:
trigger_log = trigger_log_repo.create(trigger_log)
session.commit()
# 7. Check and consume quota
# 7. Reserve quota (commit after successful dispatch)
quota_charge = unlimited()
try:
QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
except QuotaExceededError as e:
# Update trigger log status
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
@@ -153,13 +155,18 @@ class AsyncWorkflowService:
# 9. Dispatch to appropriate queue
task_data_dict = task_data.model_dump(mode="json")
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
try:
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED

View File

@@ -1,7 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing_extensions import TypedDict
from typing import Any, TypedDict
class AuthCredentials(TypedDict):

View File

@@ -2,13 +2,12 @@ import json
import logging
import os
from collections.abc import Sequence
from typing import Literal
from typing import Literal, NotRequired, TypedDict
import httpx
from pydantic import TypeAdapter
from sqlalchemy import select
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from typing_extensions import TypedDict
from werkzeug.exceptions import InternalServerError
from core.helper.http_client_pooling import get_pooled_http_client
@@ -33,6 +32,81 @@ class SubscriptionPlan(TypedDict):
expiration_date: int
class QuotaReserveResult(TypedDict):
reservation_id: str
available: int
reserved: int
class QuotaCommitResult(TypedDict):
available: int
reserved: int
refunded: int
class QuotaReleaseResult(TypedDict):
available: int
reserved: int
released: int
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
class _BillingQuota(TypedDict):
size: int
limit: int
class _VectorSpaceQuota(TypedDict):
size: float
limit: int
class _KnowledgeRateLimit(TypedDict):
# NOTE (hj24):
# 1. Return for sandbox users but is null for other plans, it's defined but never used.
# 2. Keep it for compatibility for now, can be deprecated in future versions.
size: NotRequired[int]
# NOTE END
limit: int
class _BillingSubscription(TypedDict):
plan: str
interval: str
education: bool
class BillingInfo(TypedDict):
"""Response of /subscription/info.
NOTE (hj24):
- Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
- To ensure the precision, billing may convert fields like int as str, be careful when use TypeAdapter:
1. validate_python in non-strict mode will coerce it to the expected type
2. In strict mode, it will raise ValidationError
3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
"""
enabled: bool
subscription: _BillingSubscription
members: _BillingQuota
apps: _BillingQuota
vector_space: _VectorSpaceQuota
knowledge_rate_limit: _KnowledgeRateLimit
documents_upload_quota: _BillingQuota
annotation_quota_limit: _BillingQuota
docs_processing: str
can_replace_logo: bool
model_load_balancing_enabled: bool
knowledge_pipeline_publish_enabled: bool
next_credit_reset_date: NotRequired[int]
_billing_info_adapter = TypeAdapter(BillingInfo)
class BillingService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
@@ -45,19 +119,69 @@ class BillingService:
_PLAN_CACHE_TTL = 600
@classmethod
def get_info(cls, tenant_id: str):
def get_info(cls, tenant_id: str) -> BillingInfo:
params = {"tenant_id": tenant_id}
billing_info = cls._send_request("GET", "/subscription/info", params=params)
return billing_info
return _billing_info_adapter.validate_python(billing_info)
@classmethod
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
"""Deprecated: Use get_quota_info instead."""
params = {"tenant_id": tenant_id}
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
return usage_info
@classmethod
def get_quota_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
return cls._send_request("GET", "/quota/info", params=params)
@classmethod
def quota_reserve(
cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
) -> QuotaReserveResult:
"""Reserve quota before task execution."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"request_id": request_id,
"amount": amount,
}
if meta:
payload["meta"] = meta
return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))
@classmethod
def quota_commit(
cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
) -> QuotaCommitResult:
"""Commit a reservation with actual consumption."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
"actual_amount": actual_amount,
}
if meta:
payload["meta"] = meta
return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))
@classmethod
def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
"""Release a reservation (cancel, return frozen quota)."""
return _quota_release_adapter.validate_python(
cls._send_request(
"POST",
"/quota/release",
json={
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
},
)
)
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
params = {"tenant_id": tenant_id}

View File

@@ -1,7 +1,7 @@
import contextlib
import logging
from collections.abc import Callable, Sequence
from typing import Any, Union
from typing import Any
from graphon.variables.types import SegmentType
from sqlalchemy import asc, desc, func, or_, select
@@ -37,7 +37,7 @@ class ConversationService:
*,
session: Session,
app_model: App,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
last_id: str | None,
limit: int,
invoke_from: InvokeFrom,
@@ -119,7 +119,7 @@ class ConversationService:
cls,
app_model: App,
conversation_id: str,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
name: str | None,
auto_generate: bool,
):
@@ -159,7 +159,7 @@ class ConversationService:
return conversation
@classmethod
def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
def get_conversation(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
conversation = db.session.scalar(
select(Conversation)
.where(
@@ -179,7 +179,7 @@ class ConversationService:
return conversation
@classmethod
def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
def delete(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
"""
Delete a conversation only if it belongs to the given user and app context.
@@ -209,7 +209,7 @@ class ConversationService:
cls,
app_model: App,
conversation_id: str,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
limit: int,
last_id: str | None,
variable_name: str | None = None,
@@ -278,7 +278,7 @@ class ConversationService:
app_model: App,
conversation_id: str,
variable_id: str,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
new_value: Any,
):
"""

View File

@@ -274,7 +274,9 @@ class DatasetService:
db.session.flush()
if provider == "external" and external_knowledge_api_id:
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
external_knowledge_api_id, tenant_id
)
if not external_knowledge_api:
raise ValueError("External API template not found.")
if external_knowledge_id is None:

View File

@@ -102,9 +102,9 @@ class ExternalDatasetService:
raise ValueError(f"Forbidden: Authorization failed with api_key: {api_key}")
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
def get_external_knowledge_api(external_knowledge_api_id: str, tenant_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")

View File

@@ -281,7 +281,7 @@ class FeatureService:
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id)
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
features_usage_info = BillingService.get_quota_info(tenant_id)
features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"]
@@ -312,7 +312,10 @@ class FeatureService:
features.apps.limit = billing_info["apps"]["limit"]
if "vector_space" in billing_info:
features.vector_space.size = billing_info["vector_space"]["size"]
# NOTE (hj24): billing API returns vector_space.size as float (e.g. 0.0)
# but LimitationModel.size is int; truncate here for compatibility
features.vector_space.size = int(billing_info["vector_space"]["size"])
# NOTE END
features.vector_space.limit = billing_info["vector_space"]["limit"]
if "documents_upload_quota" in billing_info:
@@ -333,7 +336,11 @@ class FeatureService:
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
if "knowledge_rate_limit" in billing_info:
# NOTE (hj24):
# 1. knowledge_rate_limit size is nullable, currently it's defined but never used, only limit is used.
# 2. So be careful if later we decide to use [size], we cannot assume it is always present.
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
# NOTE END
if "knowledge_pipeline_publish_enabled" in billing_info:
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]

View File

@@ -1,5 +1,4 @@
from collections.abc import Sequence
from typing import Union
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import TypeAdapter
@@ -57,7 +56,7 @@ class MessageService:
def pagination_by_first_id(
cls,
app_model: App,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
conversation_id: str,
first_id: str | None,
limit: int,
@@ -117,7 +116,7 @@ class MessageService:
def pagination_by_last_id(
cls,
app_model: App,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
last_id: str | None,
limit: int,
conversation_id: str | None = None,
@@ -170,7 +169,7 @@ class MessageService:
*,
app_model: App,
message_id: str,
user: Union[Account, EndUser] | None,
user: Account | EndUser | None,
rating: FeedbackRating | None,
content: str | None,
):
@@ -221,7 +220,7 @@ class MessageService:
return [record.to_dict() for record in feedbacks]
@classmethod
def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
def get_message(cls, app_model: App, user: Account | EndUser | None, message_id: str):
message = db.session.scalar(
select(Message)
.where(
@@ -241,7 +240,7 @@ class MessageService:
@classmethod
def get_suggested_questions_after_answer(
cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str, invoke_from: InvokeFrom
cls, app_model: App, user: Account | EndUser | None, message_id: str, invoke_from: InvokeFrom
) -> list[str]:
if not user:
raise ValueError("user cannot be None")

View File

@@ -65,7 +65,7 @@ class MetadataService:
raise ValueError("Metadata name already exists in Built-in fields.")
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
old_name = metadata.name
@@ -101,7 +101,7 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
db.session.delete(metadata)

View File

@@ -110,20 +110,21 @@ class ModelLoadBalancingService:
credential_source_type = CredentialSourceType.CUSTOM_MODEL
# Get load balancing configurations
load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
or_(
LoadBalancingModelConfig.credential_source_type == credential_source_type,
LoadBalancingModelConfig.credential_source_type.is_(None),
),
)
.order_by(LoadBalancingModelConfig.created_at)
.all()
load_balancing_configs = list(
db.session.scalars(
select(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
or_(
LoadBalancingModelConfig.credential_source_type == credential_source_type,
LoadBalancingModelConfig.credential_source_type.is_(None),
),
)
.order_by(LoadBalancingModelConfig.created_at)
).all()
)
if provider_configuration.custom_configuration.provider:
@@ -143,7 +144,7 @@ class ModelLoadBalancingService:
load_balancing_configs.insert(0, inherit_config)
else:
# move the inherit configuration to the first
for i, load_balancing_config in enumerate(load_balancing_configs[:]):
for i, load_balancing_config in enumerate(load_balancing_configs.copy()):
if load_balancing_config.name == "__inherit__":
inherit_config = load_balancing_configs.pop(i)
load_balancing_configs.insert(0, inherit_config)
@@ -235,8 +236,8 @@ class ModelLoadBalancingService:
model_type_enum = ModelType.value_of(model_type)
# Get load balancing configurations
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
load_balancing_model_config = db.session.scalar(
select(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
@@ -244,7 +245,7 @@ class ModelLoadBalancingService:
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
.first()
.limit(1)
)
if not load_balancing_model_config:
@@ -351,26 +352,26 @@ class ModelLoadBalancingService:
if credential_id:
if config_from == "predefined-model":
credential_record = (
db.session.query(ProviderCredential)
.filter_by(
id=credential_id,
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
credential_record = db.session.scalar(
select(ProviderCredential)
.where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == tenant_id,
ProviderCredential.provider_name == provider_configuration.provider.provider,
)
.first()
.limit(1)
)
else:
credential_record = (
db.session.query(ProviderModelCredential)
.filter_by(
id=credential_id,
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_name=model,
model_type=model_type_enum,
credential_record = db.session.scalar(
select(ProviderModelCredential)
.where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == tenant_id,
ProviderModelCredential.provider_name == provider_configuration.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type_enum,
)
.first()
.limit(1)
)
if not credential_record:
raise ValueError(f"Provider credential with id {credential_id} not found")
@@ -510,8 +511,8 @@ class ModelLoadBalancingService:
load_balancing_model_config = None
if config_id:
# Get load balancing config
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
load_balancing_model_config = db.session.scalar(
select(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
@@ -519,7 +520,7 @@ class ModelLoadBalancingService:
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
.first()
.limit(1)
)
if not load_balancing_model_config:

View File

@@ -5,7 +5,7 @@ import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any
from typing import Any, TypedDict
from uuid import uuid4
import click
@@ -14,7 +14,6 @@ import tqdm
from flask import Flask, current_app
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from typing_extensions import TypedDict
from core.agent.entities import AgentToolEntity
from core.helper import marketplace

View File

@@ -0,0 +1,233 @@
from __future__ import annotations
import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from enums.quota_type import QuotaType
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota reservation (Reserve phase).
Lifecycle:
charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id)
try:
do_work()
charge.commit() # Confirm consumption
except:
charge.refund() # Release frozen quota
If neither commit() nor refund() is called, the billing system's
cleanup CronJob will auto-release the reservation within ~75 seconds.
"""
success: bool
charge_id: str | None # reservation_id
_quota_type: QuotaType
_tenant_id: str | None = None
_feature_key: str | None = None
_amount: int = 0
_committed: bool = field(default=False, repr=False)
def commit(self, actual_amount: int | None = None) -> None:
"""
Confirm the consumption with actual amount.
Args:
actual_amount: Actual amount consumed. Defaults to the reserved amount.
If less than reserved, the difference is refunded automatically.
"""
if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
return
try:
from services.billing_service import BillingService
amount = actual_amount if actual_amount is not None else self._amount
BillingService.quota_commit(
tenant_id=self._tenant_id,
feature_key=self._feature_key,
reservation_id=self.charge_id,
actual_amount=amount,
)
self._committed = True
logger.debug(
"Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
self._quota_type,
self._tenant_id,
self.charge_id,
amount,
)
except Exception:
logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)
def refund(self) -> None:
"""
Release the reserved quota (cancel the charge).
Safe to call even if:
- charge failed or was disabled (charge_id is None)
- already committed (Release after Commit is a no-op)
- already refunded (idempotent)
This method guarantees no exceptions will be raised.
"""
if not self.charge_id or not self._tenant_id or not self._feature_key:
return
QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
def unlimited() -> QuotaCharge:
from enums.quota_type import QuotaType
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
class QuotaService:
"""Orchestrates quota reserve / commit / release lifecycle via BillingService."""
@staticmethod
def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Reserve + immediate Commit (one-shot mode).
The returned QuotaCharge supports .refund() which calls Release.
For two-phase usage (e.g. streaming), use reserve() directly.
"""
charge = QuotaService.reserve(quota_type, tenant_id, amount)
if charge.success and charge.charge_id:
charge.commit()
return charge
@staticmethod
def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Reserve quota before task execution (Reserve phase only).
The caller MUST call charge.commit() after the task succeeds,
or charge.refund() if the task fails.
Raises:
QuotaExceededError: When quota is insufficient
"""
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)
logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to reserve must be greater than 0")
request_id = str(uuid.uuid4())
feature_key = quota_type.billing_key
try:
reserve_resp = BillingService.quota_reserve(
tenant_id=tenant_id,
feature_key=feature_key,
request_id=request_id,
amount=amount,
)
reservation_id = reserve_resp.get("reservation_id")
if not reservation_id:
logger.warning(
"Reserve returned no reservation_id for %s, feature %s, response: %s",
tenant_id,
quota_type.value,
reserve_resp,
)
raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)
logger.debug(
"Reserved %d %s quota for tenant %s, reservation_id: %s",
amount,
quota_type.value,
tenant_id,
reservation_id,
)
return QuotaCharge(
success=True,
charge_id=reservation_id,
_quota_type=quota_type,
_tenant_id=tenant_id,
_feature_key=feature_key,
_amount=amount,
)
except QuotaExceededError:
raise
except ValueError:
raise
except Exception:
logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
return unlimited()
@staticmethod
def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = QuotaService.get_remaining(quota_type, tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
return True
@staticmethod
def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
"""Release a reservation. Guarantees no exceptions."""
try:
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not reservation_id:
return
logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
BillingService.quota_release(
tenant_id=tenant_id,
feature_key=feature_key,
reservation_id=reservation_id,
)
except Exception:
logger.exception("Failed to release quota, reservation_id: %s", reservation_id)
@staticmethod
def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
from services.billing_service import BillingService
try:
usage_info = BillingService.get_quota_info(tenant_id)
if isinstance(usage_info, dict):
feature_info = usage_info.get(quota_type.billing_key, {})
if isinstance(feature_info, dict):
limit = feature_info.get("limit", 0)
usage = feature_info.get("usage", 0)
if limit == -1:
return -1
return max(0, limit - usage)
return 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
return -1

View File

@@ -574,7 +574,7 @@ class RagPipelineService:
outputs=workflow_node_execution.outputs,
)
session.commit()
if workflow_node_execution_db_model is not None:
if isinstance(workflow_node_execution_db_model, WorkflowNodeExecutionModel):
enqueue_draft_node_execution_trace(
execution=workflow_node_execution_db_model,
outputs=workflow_node_execution.outputs,

Some files were not shown because too many files have changed in this diff Show More