mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 10:12:43 +08:00
chore(api): align Python support with 3.12 (#34419)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
@@ -10,7 +10,7 @@ import threading
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable, Generator
|
||||||
from contextlib import AbstractContextManager, contextmanager
|
from contextlib import AbstractContextManager, contextmanager
|
||||||
from typing import Any, Protocol, TypeVar, final, runtime_checkable
|
from typing import Any, Protocol, final, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -188,8 +188,6 @@ class ExecutionContextBuilder:
|
|||||||
_capturer: Callable[[], IExecutionContext] | None = None
|
_capturer: Callable[[], IExecutionContext] | None = None
|
||||||
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
|
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
class ContextProviderNotFoundError(KeyError):
|
class ContextProviderNotFoundError(KeyError):
|
||||||
"""Raised when a tenant-scoped context provider is missing."""
|
"""Raised when a tenant-scoped context provider is missing."""
|
||||||
|
|||||||
@@ -1,7 +1,4 @@
|
|||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Generic, TypeVar
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class HiddenValue:
|
class HiddenValue:
|
||||||
@@ -11,7 +8,7 @@ class HiddenValue:
|
|||||||
_default = HiddenValue()
|
_default = HiddenValue()
|
||||||
|
|
||||||
|
|
||||||
class RecyclableContextVar(Generic[T]):
|
class RecyclableContextVar[T]:
|
||||||
"""
|
"""
|
||||||
RecyclableContextVar is a wrapper around ContextVar
|
RecyclableContextVar is a wrapper around ContextVar
|
||||||
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
|
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, TypeAlias
|
from typing import Any
|
||||||
|
|
||||||
from graphon.file import helpers as file_helpers
|
from graphon.file import helpers as file_helpers
|
||||||
from pydantic import BaseModel, ConfigDict, computed_field
|
from pydantic import BaseModel, ConfigDict, computed_field
|
||||||
|
|
||||||
from models.model import IconType
|
from models.model import IconType
|
||||||
|
|
||||||
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||||
JSONObject: TypeAlias = dict[str, Any]
|
type JSONObject = dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class SystemParameters(BaseModel):
|
class SystemParameters(BaseModel):
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import csv
|
|||||||
import io
|
import io
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar
|
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@@ -20,9 +19,6 @@ from libs.token import extract_access_token
|
|||||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
@@ -72,9 +68,9 @@ console_ns.schema_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def admin_required(view: Callable[P, R]):
|
def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
if not dify_config.ADMIN_API_KEY:
|
if not dify_config.ADMIN_API_KEY:
|
||||||
raise Unauthorized("API key is invalid.")
|
raise Unauthorized("API key is invalid.")
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Literal, TypeAlias
|
from typing import Any, Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@@ -152,7 +152,7 @@ class AppTracePayload(BaseModel):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
JSONValue: TypeAlias = Any
|
type JSONValue = Any
|
||||||
|
|
||||||
|
|
||||||
class ResponseModel(BaseModel):
|
class ResponseModel(BaseModel):
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, NoReturn, ParamSpec, TypeVar
|
from typing import Any
|
||||||
|
|
||||||
from flask import Response, request
|
from flask import Response, request
|
||||||
from flask_restx import Resource, fields, marshal, marshal_with
|
from flask_restx import Resource, fields, marshal, marshal_with
|
||||||
@@ -192,11 +192,8 @@ workflow_draft_variable_list_model = console_ns.model(
|
|||||||
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
|
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
|
||||||
)
|
)
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
def _api_prerequisite(f: Callable[P, R]):
|
|
||||||
"""Common prerequisites for all draft workflow variable APIs.
|
"""Common prerequisites for all draft workflow variable APIs.
|
||||||
|
|
||||||
It ensures the following conditions are satisfied:
|
It ensures the following conditions are satisfied:
|
||||||
@@ -213,7 +210,7 @@ def _api_prerequisite(f: Callable[P, R]):
|
|||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def wrapper(*args: P.args, **kwargs: P.kwargs):
|
def wrapper(*args: Any, **kwargs: Any):
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
@@ -270,7 +267,7 @@ class WorkflowVariableCollectionApi(Resource):
|
|||||||
return Response("", 204)
|
return Response("", 204)
|
||||||
|
|
||||||
|
|
||||||
def validate_node_id(node_id: str) -> NoReturn | None:
|
def validate_node_id(node_id: str) -> None:
|
||||||
if node_id in [
|
if node_id in [
|
||||||
CONVERSATION_VARIABLE_NODE_ID,
|
CONVERSATION_VARIABLE_NODE_ID,
|
||||||
SYSTEM_VARIABLE_NODE_ID,
|
SYSTEM_VARIABLE_NODE_ID,
|
||||||
@@ -285,7 +282,6 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
|||||||
raise InvalidArgumentError(
|
raise InvalidArgumentError(
|
||||||
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||||
)
|
)
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar, Union
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
@@ -9,11 +9,6 @@ from extensions.ext_database import db
|
|||||||
from libs.login import current_account_with_tenant
|
from libs.login import current_account_with_tenant
|
||||||
from models import App, AppMode
|
from models import App, AppMode
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
P1 = ParamSpec("P1")
|
|
||||||
R1 = TypeVar("R1")
|
|
||||||
|
|
||||||
|
|
||||||
def _load_app_model(app_id: str) -> App | None:
|
def _load_app_model(app_id: str) -> App | None:
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
@@ -28,10 +23,14 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
|
|||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
||||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
def get_app_model(
|
||||||
def decorator(view_func: Callable[P1, R1]):
|
view: Callable[..., Any] | None = None,
|
||||||
|
*,
|
||||||
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
|
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||||
|
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
|
def decorated_view(*args: Any, **kwargs: Any):
|
||||||
if not kwargs.get("app_id"):
|
if not kwargs.get("app_id"):
|
||||||
raise ValueError("missing app_id in path parameters")
|
raise ValueError("missing app_id in path parameters")
|
||||||
|
|
||||||
@@ -69,10 +68,14 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
|
|||||||
return decorator(view)
|
return decorator(view)
|
||||||
|
|
||||||
|
|
||||||
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
def get_app_model_with_trial(
|
||||||
def decorator(view_func: Callable[P, R]):
|
view: Callable[..., Any] | None = None,
|
||||||
|
*,
|
||||||
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
|
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||||
|
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
def decorated_view(*args: Any, **kwargs: Any):
|
||||||
if not kwargs.get("app_id"):
|
if not kwargs.get("app_id"):
|
||||||
raise ValueError("missing app_id in path parameters")
|
raise ValueError("missing app_id in path parameters")
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Concatenate, ParamSpec, TypeVar
|
from typing import Concatenate
|
||||||
|
|
||||||
from flask import jsonify, request
|
from flask import jsonify, request
|
||||||
|
from flask.typing import ResponseReturnValue
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -16,10 +17,6 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
|
|||||||
|
|
||||||
from .. import console_ns
|
from .. import console_ns
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthClientPayload(BaseModel):
|
class OAuthClientPayload(BaseModel):
|
||||||
client_id: str
|
client_id: str
|
||||||
@@ -39,9 +36,11 @@ class OAuthTokenRequest(BaseModel):
|
|||||||
refresh_token: str | None = None
|
refresh_token: str | None = None
|
||||||
|
|
||||||
|
|
||||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
def oauth_server_client_id_required[T, **P, R](
|
||||||
|
view: Callable[Concatenate[T, OAuthProviderApp, P], R],
|
||||||
|
) -> Callable[Concatenate[T, P], R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
json_data = request.get_json()
|
json_data = request.get_json()
|
||||||
if json_data is None:
|
if json_data is None:
|
||||||
raise BadRequest("client_id is required")
|
raise BadRequest("client_id is required")
|
||||||
@@ -58,9 +57,13 @@ def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderA
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
|
def oauth_server_access_token_required[T, **P, R](
|
||||||
|
view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R],
|
||||||
|
) -> Callable[Concatenate[T, OAuthProviderApp, P], R | ResponseReturnValue]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
|
def decorated(
|
||||||
|
self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs
|
||||||
|
) -> R | ResponseReturnValue:
|
||||||
if not isinstance(oauth_provider_app, OAuthProviderApp):
|
if not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||||
raise BadRequest("Invalid oauth_provider_app")
|
raise BadRequest("Invalid oauth_provider_app")
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
@@ -9,13 +8,10 @@ from extensions.ext_database import db
|
|||||||
from libs.login import current_account_with_tenant
|
from libs.login import current_account_with_tenant
|
||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
def get_rag_pipeline[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||||
def get_rag_pipeline(view_func: Callable[P, R]):
|
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
if not kwargs.get("pipeline_id"):
|
if not kwargs.get("pipeline_id"):
|
||||||
raise ValueError("missing pipeline_id in path parameters")
|
raise ValueError("missing pipeline_id in path parameters")
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Concatenate, ParamSpec, TypeVar
|
from typing import Concatenate
|
||||||
|
|
||||||
from flask import abort
|
from flask import abort
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@@ -15,12 +15,8 @@ from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
|
|||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
def installed_app_required[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||||
def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
|
||||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||||
@@ -49,7 +45,7 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
def user_allowed_to_access_app[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||||
@@ -73,7 +69,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
def trial_app_required[**P, R](view: Callable[Concatenate[App, P], R] | None = None):
|
||||||
def decorator(view: Callable[Concatenate[App, P], R]):
|
def decorator(view: Callable[Concatenate[App, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||||
@@ -106,7 +102,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def trial_feature_enable(view: Callable[P, R]):
|
def trial_feature_enable[**P, R](view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_system_features()
|
features = FeatureService.get_system_features()
|
||||||
@@ -117,7 +113,7 @@ def trial_feature_enable(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def explore_banner_enabled(view: Callable[P, R]):
|
def explore_banner_enabled[**P, R](view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_system_features()
|
features = FeatureService.get_system_features()
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar
|
|
||||||
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
@@ -9,17 +8,14 @@ from extensions.ext_database import db
|
|||||||
from libs.login import current_account_with_tenant
|
from libs.login import current_account_with_tenant
|
||||||
from models.account import TenantPluginPermission
|
from models.account import TenantPluginPermission
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
|
||||||
def plugin_permission_required(
|
def plugin_permission_required(
|
||||||
install_required: bool = False,
|
install_required: bool = False,
|
||||||
debug_required: bool = False,
|
debug_required: bool = False,
|
||||||
):
|
):
|
||||||
def interceptor(view: Callable[P, R]):
|
def interceptor[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
user = current_user
|
user = current_user
|
||||||
tenant_id = current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import os
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar
|
|
||||||
|
|
||||||
from flask import abort, request
|
from flask import abort, request
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@@ -25,9 +24,6 @@ from services.operation_service import OperationService
|
|||||||
|
|
||||||
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
# Field names for decryption
|
# Field names for decryption
|
||||||
FIELD_NAME_PASSWORD = "password"
|
FIELD_NAME_PASSWORD = "password"
|
||||||
FIELD_NAME_CODE = "code"
|
FIELD_NAME_CODE = "code"
|
||||||
@@ -37,7 +33,7 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
|
|||||||
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
|
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
|
||||||
|
|
||||||
|
|
||||||
def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
|
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
# check account initialization
|
# check account initialization
|
||||||
@@ -50,7 +46,7 @@ def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def only_edition_cloud(view: Callable[P, R]):
|
def only_edition_cloud[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if dify_config.EDITION != "CLOUD":
|
if dify_config.EDITION != "CLOUD":
|
||||||
@@ -61,7 +57,7 @@ def only_edition_cloud(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def only_edition_enterprise(view: Callable[P, R]):
|
def only_edition_enterprise[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if not dify_config.ENTERPRISE_ENABLED:
|
if not dify_config.ENTERPRISE_ENABLED:
|
||||||
@@ -72,7 +68,7 @@ def only_edition_enterprise(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def only_edition_self_hosted(view: Callable[P, R]):
|
def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if dify_config.EDITION != "SELF_HOSTED":
|
if dify_config.EDITION != "SELF_HOSTED":
|
||||||
@@ -83,7 +79,7 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
@@ -95,7 +91,7 @@ def cloud_edition_billing_enabled(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_resource_check(resource: str):
|
def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
@@ -137,7 +133,9 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||||
|
resource: str,
|
||||||
|
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
@@ -160,7 +158,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
|||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_rate_limit_check(resource: str):
|
def cloud_edition_billing_rate_limit_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
@@ -196,7 +194,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
def cloud_utm_record(view: Callable[P, R]):
|
def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
@@ -215,7 +213,7 @@ def cloud_utm_record(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def setup_required(view: Callable[P, R]) -> Callable[P, R]:
|
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
# check setup
|
# check setup
|
||||||
@@ -229,7 +227,7 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]:
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def enterprise_license_required(view: Callable[P, R]):
|
def enterprise_license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
settings = FeatureService.get_system_features()
|
settings = FeatureService.get_system_features()
|
||||||
@@ -241,7 +239,7 @@ def enterprise_license_required(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def email_password_login_enabled(view: Callable[P, R]):
|
def email_password_login_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_system_features()
|
features = FeatureService.get_system_features()
|
||||||
@@ -254,7 +252,7 @@ def email_password_login_enabled(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def email_register_enabled(view: Callable[P, R]):
|
def email_register_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_system_features()
|
features = FeatureService.get_system_features()
|
||||||
@@ -267,7 +265,7 @@ def email_register_enabled(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def enable_change_email(view: Callable[P, R]):
|
def enable_change_email[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_system_features()
|
features = FeatureService.get_system_features()
|
||||||
@@ -280,7 +278,7 @@ def enable_change_email(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def is_allow_transfer_owner(view: Callable[P, R]):
|
def is_allow_transfer_owner[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
from libs.workspace_permission import check_workspace_owner_transfer_permission
|
from libs.workspace_permission import check_workspace_owner_transfer_permission
|
||||||
@@ -293,7 +291,7 @@ def is_allow_transfer_owner(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
|
def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
@@ -305,7 +303,7 @@ def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def edit_permission_required(f: Callable[P, R]):
|
def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
@@ -323,7 +321,7 @@ def edit_permission_required(f: Callable[P, R]):
|
|||||||
return decorated_function
|
return decorated_function
|
||||||
|
|
||||||
|
|
||||||
def is_admin_or_owner_required(f: Callable[P, R]):
|
def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
@@ -339,7 +337,7 @@ def is_admin_or_owner_required(f: Callable[P, R]):
|
|||||||
return decorated_function
|
return decorated_function
|
||||||
|
|
||||||
|
|
||||||
def annotation_import_rate_limit(view: Callable[P, R]):
|
def annotation_import_rate_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
"""
|
"""
|
||||||
Rate limiting decorator for annotation import operations.
|
Rate limiting decorator for annotation import operations.
|
||||||
|
|
||||||
@@ -388,7 +386,7 @@ def annotation_import_rate_limit(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def annotation_import_concurrency_limit(view: Callable[P, R]):
|
def annotation_import_concurrency_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
"""
|
"""
|
||||||
Concurrency control decorator for annotation import operations.
|
Concurrency control decorator for annotation import operations.
|
||||||
|
|
||||||
@@ -455,7 +453,7 @@ def _decrypt_field(field_name: str, error_class: type[Exception], error_message:
|
|||||||
payload[field_name] = decoded_value
|
payload[field_name] = decoded_value
|
||||||
|
|
||||||
|
|
||||||
def decrypt_password_field(view: Callable[P, R]):
|
def decrypt_password_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
"""
|
"""
|
||||||
Decorator to decrypt password field in request payload.
|
Decorator to decrypt password field in request payload.
|
||||||
|
|
||||||
@@ -477,7 +475,7 @@ def decrypt_password_field(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def decrypt_code_field(view: Callable[P, R]):
|
def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
"""
|
"""
|
||||||
Decorator to decrypt verification code field in request payload.
|
Decorator to decrypt verification code field in request payload.
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar
|
|
||||||
|
|
||||||
from flask import current_app, request
|
from flask import current_app, request
|
||||||
from flask_login import user_logged_in
|
from flask_login import user_logged_in
|
||||||
@@ -13,9 +12,6 @@ from libs.login import current_user
|
|||||||
from models.account import Tenant
|
from models.account import Tenant
|
||||||
from models.model import DefaultEndUserSessionID, EndUser
|
from models.model import DefaultEndUserSessionID, EndUser
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
|
||||||
class TenantUserPayload(BaseModel):
|
class TenantUserPayload(BaseModel):
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
@@ -65,9 +61,9 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
|||||||
return user_model
|
return user_model
|
||||||
|
|
||||||
|
|
||||||
def get_user_tenant(view_func: Callable[P, R]):
|
def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
|
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
|
||||||
|
|
||||||
user_id = payload.user_id
|
user_id = payload.user_id
|
||||||
@@ -97,10 +93,14 @@ def get_user_tenant(view_func: Callable[P, R]):
|
|||||||
return decorated_view
|
return decorated_view
|
||||||
|
|
||||||
|
|
||||||
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
|
def plugin_data[**P, R](
|
||||||
def decorator(view_func: Callable[P, R]):
|
view: Callable[P, R] | None = None,
|
||||||
|
*,
|
||||||
|
payload_type: type[BaseModel],
|
||||||
|
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
|
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -3,10 +3,7 @@ from collections.abc import Callable
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
from hmac import new as hmac_new
|
from hmac import new as hmac_new
|
||||||
from typing import ParamSpec, TypeVar
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
from flask import abort, request
|
from flask import abort, request
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@@ -14,9 +11,9 @@ from extensions.ext_database import db
|
|||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
||||||
|
|
||||||
def billing_inner_api_only(view: Callable[P, R]):
|
def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
if not dify_config.INNER_API:
|
if not dify_config.INNER_API:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
@@ -30,9 +27,9 @@ def billing_inner_api_only(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def enterprise_inner_api_only(view: Callable[P, R]):
|
def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
if not dify_config.INNER_API:
|
if not dify_config.INNER_API:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
@@ -46,9 +43,9 @@ def enterprise_inner_api_only(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def enterprise_inner_api_user_auth(view: Callable[P, R]):
|
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
if not dify_config.INNER_API:
|
if not dify_config.INNER_API:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
@@ -82,9 +79,9 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def plugin_inner_api_only(view: Callable[P, R]):
|
def plugin_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
if not dify_config.PLUGIN_DAEMON_KEY:
|
if not dify_config.PLUGIN_DAEMON_KEY:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import time
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Concatenate, ParamSpec, TypeVar, cast, overload
|
from typing import Any, cast, overload
|
||||||
|
|
||||||
from flask import current_app, request
|
from flask import current_app, request
|
||||||
from flask_login import user_logged_in
|
from flask_login import user_logged_in
|
||||||
@@ -23,10 +23,6 @@ from services.api_token_service import ApiTokenCache, fetch_token_with_single_fl
|
|||||||
from services.end_user_service import EndUserService
|
from services.end_user_service import EndUserService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -46,16 +42,16 @@ class FetchUserArg(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ...
|
def validate_app_token[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def validate_app_token(
|
def validate_app_token[**P, R](
|
||||||
view: None = None, *, fetch_user_arg: FetchUserArg | None = None
|
view: None = None, *, fetch_user_arg: FetchUserArg | None = None
|
||||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||||
|
|
||||||
|
|
||||||
def validate_app_token(
|
def validate_app_token[**P, R](
|
||||||
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
|
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
|
||||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||||
@@ -136,7 +132,10 @@ def validate_app_token(
|
|||||||
return decorator(view)
|
return decorator(view)
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
def cloud_edition_billing_resource_check[**P, R](
|
||||||
|
resource: str,
|
||||||
|
api_token_type: str,
|
||||||
|
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view: Callable[P, R]):
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
api_token = validate_and_get_api_token(api_token_type)
|
api_token = validate_and_get_api_token(api_token_type)
|
||||||
@@ -166,7 +165,10 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
|||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||||
|
resource: str,
|
||||||
|
api_token_type: str,
|
||||||
|
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
@@ -188,7 +190,10 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
|
|||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
def cloud_edition_billing_rate_limit_check[**P, R](
|
||||||
|
resource: str,
|
||||||
|
api_token_type: str,
|
||||||
|
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
@@ -225,20 +230,12 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
|||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ...
|
|
||||||
|
|
||||||
|
|
||||||
def validate_dataset_token(
|
def validate_dataset_token(
|
||||||
view: Callable[Concatenate[T, P], R] | None = None,
|
view: Callable[..., Any] | None = None,
|
||||||
) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]:
|
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||||
def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]:
|
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
def decorated(*args: Any, **kwargs: Any) -> Any:
|
||||||
api_token = validate_and_get_api_token("dataset")
|
api_token = validate_and_get_api_token("dataset")
|
||||||
|
|
||||||
# get url path dataset_id from positional args or kwargs
|
# get url path dataset_id from positional args or kwargs
|
||||||
@@ -308,7 +305,10 @@ def validate_dataset_token(
|
|||||||
raise Unauthorized("Tenant owner account does not exist.")
|
raise Unauthorized("Tenant owner account does not exist.")
|
||||||
else:
|
else:
|
||||||
raise Unauthorized("Tenant does not exist.")
|
raise Unauthorized("Tenant does not exist.")
|
||||||
return view_func(api_token.tenant_id, *args, **kwargs) # type: ignore[arg-type]
|
if args and isinstance(args[0], Resource):
|
||||||
|
return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs)
|
||||||
|
|
||||||
|
return view_func(api_token.tenant_id, *args, **kwargs)
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Concatenate, ParamSpec, TypeVar
|
from typing import Concatenate
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@@ -20,14 +20,13 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
|
|||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
from services.webapp_auth_service import WebAppAuthService
|
from services.webapp_auth_service import WebAppAuthService
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
def validate_jwt_token[**P, R](
|
||||||
def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None = None):
|
view: Callable[Concatenate[App, EndUser, P], R] | None = None,
|
||||||
def decorator(view: Callable[Concatenate[App, EndUser, P], R]):
|
) -> Callable[P, R] | Callable[[Callable[Concatenate[App, EndUser, P], R]], Callable[P, R]]:
|
||||||
|
def decorator(view: Callable[Concatenate[App, EndUser, P], R]) -> Callable[P, R]:
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
app_model, end_user = decode_jwt_token()
|
app_model, end_user = decode_jwt_token()
|
||||||
return view(app_model, end_user, *args, **kwargs)
|
return view(app_model, end_user, *args, **kwargs)
|
||||||
|
|
||||||
@@ -38,7 +37,7 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None =
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
|
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) -> tuple[App, EndUser]:
|
||||||
system_features = FeatureService.get_system_features()
|
system_features = FeatureService.get_system_features()
|
||||||
if not app_code:
|
if not app_code:
|
||||||
app_code = str(request.headers.get(HEADER_NAME_APP_CODE))
|
app_code = str(request.headers.get(HEADER_NAME_APP_CODE))
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
@@ -68,7 +68,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
@@ -81,7 +81,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
@@ -94,7 +94,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
@@ -106,7 +106,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
@@ -239,7 +239,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
*,
|
*,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
@@ -271,9 +271,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
user: Account | EndUser,
|
user: Account | EndUser,
|
||||||
args: Mapping,
|
args: Mapping[str, Any],
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
@@ -359,7 +359,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
user: Account | EndUser,
|
user: Account | EndUser,
|
||||||
args: LoopNodeRunPayload,
|
args: LoopNodeRunPayload,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
@@ -439,7 +439,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
@@ -451,7 +451,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
pause_state_config: PauseStateLayerConfig | None = None,
|
||||||
graph_runtime_state: GraphRuntimeState | None = None,
|
graph_runtime_state: GraphRuntimeState | None = None,
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
@@ -653,10 +653,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation: ConversationSnapshot,
|
conversation: ConversationSnapshot,
|
||||||
message: MessageSnapshot,
|
message: MessageSnapshot,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
|
||||||
"""
|
"""
|
||||||
Handle response.
|
Handle response.
|
||||||
:param application_generate_entity: application generate entity
|
:param application_generate_entity: application generate entity
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Literal, Union, overload
|
from typing import Any, Literal, overload
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
@@ -37,7 +37,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: Literal[False],
|
streaming: Literal[False],
|
||||||
@@ -48,7 +48,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: Literal[True],
|
streaming: Literal[True],
|
||||||
@@ -59,21 +59,21 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool,
|
streaming: bool,
|
||||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
|
) -> Mapping | Generator[Mapping | str, None, None]: ...
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
|
) -> Mapping | Generator[Mapping | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Literal, Union, overload
|
from typing import Any, Literal, overload
|
||||||
|
|
||||||
from flask import Flask, copy_current_request_context, current_app
|
from flask import Flask, copy_current_request_context, current_app
|
||||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
@@ -36,7 +36,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: Literal[True],
|
streaming: Literal[True],
|
||||||
@@ -46,7 +46,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: Literal[False],
|
streaming: Literal[False],
|
||||||
@@ -56,20 +56,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool,
|
streaming: bool,
|
||||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Literal, Union, overload
|
from typing import Any, Literal, overload
|
||||||
|
|
||||||
from flask import Flask, copy_current_request_context, current_app
|
from flask import Flask, copy_current_request_context, current_app
|
||||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
@@ -36,7 +36,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: Literal[True],
|
streaming: Literal[True],
|
||||||
@@ -46,7 +46,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: Literal[False],
|
streaming: Literal[False],
|
||||||
@@ -56,20 +56,20 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: ...
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
@@ -244,10 +244,10 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
|
) -> Mapping | Generator[Mapping | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Literal, Union, cast, overload
|
from typing import Any, Literal, cast, overload
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
@@ -62,7 +62,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
*,
|
*,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: Literal[True],
|
streaming: Literal[True],
|
||||||
@@ -77,7 +77,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
*,
|
*,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: Literal[False],
|
streaming: Literal[False],
|
||||||
@@ -92,28 +92,28 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
*,
|
*,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool,
|
streaming: bool,
|
||||||
call_depth: int,
|
call_depth: int,
|
||||||
workflow_thread_pool_id: str | None,
|
workflow_thread_pool_id: str | None,
|
||||||
is_retry: bool = False,
|
is_retry: bool = False,
|
||||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
) -> Mapping[str, Any] | Generator[Mapping | str, None, None]: ...
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
call_depth: int = 0,
|
call_depth: int = 0,
|
||||||
workflow_thread_pool_id: str | None = None,
|
workflow_thread_pool_id: str | None = None,
|
||||||
is_retry: bool = False,
|
is_retry: bool = False,
|
||||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
|
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None:
|
||||||
# Add null check for dataset
|
# Add null check for dataset
|
||||||
|
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
@@ -278,7 +278,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
context: contextvars.Context,
|
context: contextvars.Context,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
workflow_id: str,
|
workflow_id: str,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
application_generate_entity: RagPipelineGenerateEntity,
|
application_generate_entity: RagPipelineGenerateEntity,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
@@ -286,7 +286,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||||
workflow_thread_pool_id: str | None = None,
|
workflow_thread_pool_id: str | None = None,
|
||||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
@@ -624,10 +624,10 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
application_generate_entity: RagPipelineGenerateEntity,
|
application_generate_entity: RagPipelineGenerateEntity,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
|
||||||
"""
|
"""
|
||||||
Handle response.
|
Handle response.
|
||||||
:param application_generate_entity: application generate entity
|
:param application_generate_entity: application generate entity
|
||||||
@@ -668,7 +668,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
datasource_info: Mapping[str, Any],
|
datasource_info: Mapping[str, Any],
|
||||||
created_from: str,
|
created_from: str,
|
||||||
position: int,
|
position: int,
|
||||||
account: Union[Account, EndUser],
|
account: Account | EndUser,
|
||||||
batch: str,
|
batch: str,
|
||||||
document_form: str,
|
document_form: str,
|
||||||
):
|
):
|
||||||
@@ -715,7 +715,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
start_node_id: str,
|
start_node_id: str,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
) -> list[Mapping[str, Any]]:
|
) -> list[Mapping[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Format datasource info list.
|
Format datasource info list.
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from graphon.graph_engine.layers import GraphEngineLayer
|
from graphon.graph_engine.layers import GraphEngineLayer
|
||||||
@@ -64,7 +64,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
*,
|
*,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: Literal[True],
|
streaming: Literal[True],
|
||||||
@@ -82,7 +82,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
*,
|
*,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: Literal[False],
|
streaming: Literal[False],
|
||||||
@@ -100,7 +100,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
*,
|
*,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool,
|
streaming: bool,
|
||||||
@@ -110,14 +110,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
root_node_id: str | None = None,
|
root_node_id: str | None = None,
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
pause_state_config: PauseStateLayerConfig | None = None,
|
||||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
@@ -127,7 +127,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
root_node_id: str | None = None,
|
root_node_id: str | None = None,
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
pause_state_config: PauseStateLayerConfig | None = None,
|
||||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
|
||||||
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
|
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
|
||||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||||
|
|
||||||
@@ -237,7 +237,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
*,
|
*,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
application_generate_entity: WorkflowAppGenerateEntity,
|
application_generate_entity: WorkflowAppGenerateEntity,
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
@@ -245,7 +245,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
pause_state_config: PauseStateLayerConfig | None = None,
|
||||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||||
"""
|
"""
|
||||||
Resume a paused workflow execution using the persisted runtime state.
|
Resume a paused workflow execution using the persisted runtime state.
|
||||||
"""
|
"""
|
||||||
@@ -269,7 +269,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
*,
|
*,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
application_generate_entity: WorkflowAppGenerateEntity,
|
application_generate_entity: WorkflowAppGenerateEntity,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
@@ -280,7 +280,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
graph_runtime_state: GraphRuntimeState | None = None,
|
graph_runtime_state: GraphRuntimeState | None = None,
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
pause_state_config: PauseStateLayerConfig | None = None,
|
||||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
@@ -609,10 +609,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
application_generate_entity: WorkflowAppGenerateEntity,
|
application_generate_entity: WorkflowAppGenerateEntity,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
|
||||||
"""
|
"""
|
||||||
Handle response.
|
Handle response.
|
||||||
:param application_generate_entity: application generate entity
|
:param application_generate_entity: application generate entity
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Annotated, Literal, Self, TypeAlias
|
from typing import Annotated, Literal, Self
|
||||||
|
|
||||||
from graphon.graph_engine.layers import GraphEngineLayer
|
from graphon.graph_engine.layers import GraphEngineLayer
|
||||||
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
|
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
|
||||||
@@ -27,7 +27,7 @@ class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
|
|||||||
entity: AdvancedChatAppGenerateEntity
|
entity: AdvancedChatAppGenerateEntity
|
||||||
|
|
||||||
|
|
||||||
_GenerateEntityUnion: TypeAlias = Annotated[
|
type _GenerateEntityUnion = Annotated[
|
||||||
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
|
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import logging
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Union, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from graphon.file import FileTransferMethod
|
from graphon.file import FileTransferMethod
|
||||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
@@ -72,14 +72,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_task_state: EasyUITaskState
|
_task_state: EasyUITaskState
|
||||||
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
_application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity
|
||||||
_precomputed_event_type: StreamEvent | None = None
|
_precomputed_event_type: StreamEvent | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
application_generate_entity: Union[
|
application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity,
|
||||||
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
|
|
||||||
],
|
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
@@ -117,11 +115,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
|
|
||||||
def process(
|
def process(
|
||||||
self,
|
self,
|
||||||
) -> Union[
|
) -> (
|
||||||
ChatbotAppBlockingResponse,
|
ChatbotAppBlockingResponse
|
||||||
CompletionAppBlockingResponse,
|
| CompletionAppBlockingResponse
|
||||||
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
|
| Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]
|
||||||
]:
|
):
|
||||||
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||||
# start generate conversation name thread
|
# start generate conversation name thread
|
||||||
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||||
@@ -136,7 +134,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
|
|
||||||
def _to_blocking_response(
|
def _to_blocking_response(
|
||||||
self, generator: Generator[StreamResponse, None, None]
|
self, generator: Generator[StreamResponse, None, None]
|
||||||
) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
|
) -> ChatbotAppBlockingResponse | CompletionAppBlockingResponse:
|
||||||
"""
|
"""
|
||||||
Process blocking response.
|
Process blocking response.
|
||||||
:return:
|
:return:
|
||||||
@@ -148,7 +146,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
|
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
|
||||||
if self._task_state.metadata:
|
if self._task_state.metadata:
|
||||||
extras["metadata"] = self._task_state.metadata.model_dump()
|
extras["metadata"] = self._task_state.metadata.model_dump()
|
||||||
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
|
response: ChatbotAppBlockingResponse | CompletionAppBlockingResponse
|
||||||
if self._conversation_mode == AppMode.COMPLETION:
|
if self._conversation_mode == AppMode.COMPLETION:
|
||||||
response = CompletionAppBlockingResponse(
|
response = CompletionAppBlockingResponse(
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@@ -183,7 +181,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
|
|
||||||
def _to_stream_response(
|
def _to_stream_response(
|
||||||
self, generator: Generator[StreamResponse, None, None]
|
self, generator: Generator[StreamResponse, None, None]
|
||||||
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
|
) -> Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]:
|
||||||
"""
|
"""
|
||||||
To stream response.
|
To stream response.
|
||||||
:return:
|
:return:
|
||||||
|
|||||||
@@ -5,14 +5,13 @@ This layer centralizes model-quota deduction outside node implementations.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, cast, final
|
from typing import TYPE_CHECKING, cast, final, override
|
||||||
|
|
||||||
from graphon.enums import BuiltinNodeTypes
|
from graphon.enums import BuiltinNodeTypes
|
||||||
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
|
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
|
||||||
from graphon.graph_engine.layers import GraphEngineLayer
|
from graphon.graph_engine.layers import GraphEngineLayer
|
||||||
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
|
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
|
||||||
from graphon.nodes.base.node import Node
|
from graphon.nodes.base.node import Node
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||||
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ associates with the node span.
|
|||||||
import logging
|
import logging
|
||||||
from contextvars import Token
|
from contextvars import Token
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import cast, final
|
from typing import cast, final, override
|
||||||
|
|
||||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||||
from graphon.graph_engine.layers import GraphEngineLayer
|
from graphon.graph_engine.layers import GraphEngineLayer
|
||||||
@@ -18,7 +18,6 @@ from graphon.graph_events import GraphNodeEventBase
|
|||||||
from graphon.nodes.base.node import Node
|
from graphon.nodes.base.node import Node
|
||||||
from opentelemetry import context as context_api
|
from opentelemetry import context as context_api
|
||||||
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
|
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.otel.parser import (
|
from extensions.otel.parser import (
|
||||||
|
|||||||
@@ -44,7 +44,8 @@ class HumanInputContent(BaseModel):
|
|||||||
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
|
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
|
||||||
|
|
||||||
|
|
||||||
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
|
# Keep a runtime alias here: callers and tests expect identity with HumanInputContent.
|
||||||
|
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent # noqa: UP040
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ExecutionExtraContentDomainModel",
|
"ExecutionExtraContentDomainModel",
|
||||||
|
|||||||
@@ -2,12 +2,13 @@ import importlib.util
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import AnyStr
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
|
def import_module_from_source[T: (str, bytes)](
|
||||||
|
*, module_name: str, py_file_path: T, use_lazy_loader: bool = False
|
||||||
|
) -> ModuleType:
|
||||||
"""
|
"""
|
||||||
Importing a module from the source file directly
|
Importing a module from the source file directly
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import os
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import TypeVar
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.tools.utils.yaml_utils import load_yaml_file_cached
|
from core.tools.utils.yaml_utils import load_yaml_file_cached
|
||||||
@@ -65,10 +64,7 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str])
|
|||||||
return position_map
|
return position_map
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
def is_filtered[T](
|
||||||
|
|
||||||
|
|
||||||
def is_filtered(
|
|
||||||
include_set: set[str],
|
include_set: set[str],
|
||||||
exclude_set: set[str],
|
exclude_set: set[str],
|
||||||
data: T,
|
data: T,
|
||||||
@@ -97,11 +93,11 @@ def is_filtered(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def sort_by_position_map(
|
def sort_by_position_map[T](
|
||||||
position_map: dict[str, int],
|
position_map: dict[str, int],
|
||||||
data: list[T],
|
data: list[T],
|
||||||
name_func: Callable[[T], str],
|
name_func: Callable[[T], str],
|
||||||
):
|
) -> list[T]:
|
||||||
"""
|
"""
|
||||||
Sort the objects by the position map.
|
Sort the objects by the position map.
|
||||||
If the name of the object is not in the position map, it will be put at the end.
|
If the name of the object is not in the position map, it will be put at the end.
|
||||||
@@ -116,11 +112,11 @@ def sort_by_position_map(
|
|||||||
return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf")))
|
return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf")))
|
||||||
|
|
||||||
|
|
||||||
def sort_to_dict_by_position_map(
|
def sort_to_dict_by_position_map[T](
|
||||||
position_map: dict[str, int],
|
position_map: dict[str, int],
|
||||||
data: list[T],
|
data: list[T],
|
||||||
name_func: Callable[[T], str],
|
name_func: Callable[[T], str],
|
||||||
):
|
) -> OrderedDict[str, T]:
|
||||||
"""
|
"""
|
||||||
Sort the objects into a ordered dict by the position map.
|
Sort the objects into a ordered dict by the position map.
|
||||||
If the name of the object is not in the position map, it will be put at the end.
|
If the name of the object is not in the position map, it will be put at the end.
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Proxy requests to avoid SSRF
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, TypeAlias
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import TypeAdapter, ValidationError
|
from pydantic import TypeAdapter, ValidationError
|
||||||
@@ -20,8 +20,8 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
|||||||
BACKOFF_FACTOR = 0.5
|
BACKOFF_FACTOR = 0.5
|
||||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||||
|
|
||||||
Headers: TypeAlias = dict[str, str]
|
type Headers = dict[str, str]
|
||||||
_HEADERS_ADAPTER = TypeAdapter(Headers)
|
_HEADERS_ADAPTER: TypeAdapter[Headers] = TypeAdapter(Headers)
|
||||||
|
|
||||||
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
|
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
|
||||||
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
|
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import queue
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, TypeAlias, final
|
from typing import Any, final
|
||||||
from urllib.parse import urljoin, urlparse
|
from urllib.parse import urljoin, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -33,9 +33,9 @@ class _StatusError:
|
|||||||
|
|
||||||
|
|
||||||
# Type aliases for better readability
|
# Type aliases for better readability
|
||||||
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
type ReadQueue = queue.Queue[SessionMessage | Exception | None]
|
||||||
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
type WriteQueue = queue.Queue[SessionMessage | Exception | None]
|
||||||
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
|
type StatusQueue = queue.Queue[_StatusReady | _StatusError]
|
||||||
|
|
||||||
|
|
||||||
class SSETransport:
|
class SSETransport:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any, Generic, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -9,13 +9,12 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
|
|||||||
|
|
||||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
||||||
|
|
||||||
|
|
||||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||||
LifespanContextT = TypeVar("LifespanContextT")
|
LifespanContextT = TypeVar("LifespanContextT")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RequestContext(Generic[SessionT, LifespanContextT]):
|
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
|
||||||
request_id: RequestId
|
request_id: RequestId
|
||||||
meta: RequestParams.Meta | None
|
meta: RequestParams.Meta | None
|
||||||
session: SessionT
|
session: SessionT
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from collections.abc import Callable
|
|||||||
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
|
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any, Generic, Self, TypeVar
|
from typing import Any, Self, cast
|
||||||
|
|
||||||
from httpx import HTTPStatusError
|
from httpx import HTTPStatusError
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -34,16 +34,10 @@ from core.mcp.types import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
|
|
||||||
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
|
|
||||||
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
|
|
||||||
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
|
|
||||||
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
|
|
||||||
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
|
|
||||||
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
|
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
|
||||||
|
|
||||||
|
|
||||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResultT: ClientResult | ServerResult]:
|
||||||
"""Handles responding to MCP requests and manages request lifecycle.
|
"""Handles responding to MCP requests and manages request lifecycle.
|
||||||
|
|
||||||
This class MUST be used as a context manager to ensure proper cleanup and
|
This class MUST be used as a context manager to ensure proper cleanup and
|
||||||
@@ -60,7 +54,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
request: ReceiveRequestT
|
request: ReceiveRequestT
|
||||||
_session: Any
|
_session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
|
||||||
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
|
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -68,7 +62,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
request_id: RequestId,
|
request_id: RequestId,
|
||||||
request_meta: RequestParams.Meta | None,
|
request_meta: RequestParams.Meta | None,
|
||||||
request: ReceiveRequestT,
|
request: ReceiveRequestT,
|
||||||
session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
|
session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
|
||||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||||
):
|
):
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
@@ -111,7 +105,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
|
|
||||||
self.completed = True
|
self.completed = True
|
||||||
|
|
||||||
self._session._send_response(request_id=self.request_id, response=response)
|
self._session.send_response(request_id=self.request_id, response=response)
|
||||||
|
|
||||||
def cancel(self):
|
def cancel(self):
|
||||||
"""Cancel this request and mark it as completed."""
|
"""Cancel this request and mark it as completed."""
|
||||||
@@ -120,21 +114,19 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
|
|
||||||
self.completed = True # Mark as completed so it's removed from in_flight
|
self.completed = True # Mark as completed so it's removed from in_flight
|
||||||
# Send an error response to indicate cancellation
|
# Send an error response to indicate cancellation
|
||||||
self._session._send_response(
|
self._session.send_response(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
response=ErrorData(code=0, message="Request cancelled", data=None),
|
response=ErrorData(code=0, message="Request cancelled", data=None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseSession(
|
class BaseSession[
|
||||||
Generic[
|
SendRequestT: ClientRequest | ServerRequest,
|
||||||
SendRequestT,
|
SendNotificationT: ClientNotification | ServerNotification,
|
||||||
SendNotificationT,
|
SendResultT: ClientResult | ServerResult,
|
||||||
SendResultT,
|
ReceiveRequestT: ClientRequest | ServerRequest,
|
||||||
ReceiveRequestT,
|
ReceiveNotificationT: ClientNotification | ServerNotification,
|
||||||
ReceiveNotificationT,
|
]:
|
||||||
],
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Implements an MCP "session" on top of read/write streams, including features
|
Implements an MCP "session" on top of read/write streams, including features
|
||||||
like request/response linking, notifications, and progress.
|
like request/response linking, notifications, and progress.
|
||||||
@@ -204,13 +196,13 @@ class BaseSession(
|
|||||||
# The receiver thread should have already exited due to the None message in the queue
|
# The receiver thread should have already exited due to the None message in the queue
|
||||||
self._executor.shutdown(wait=False)
|
self._executor.shutdown(wait=False)
|
||||||
|
|
||||||
def send_request(
|
def send_request[T: BaseModel](
|
||||||
self,
|
self,
|
||||||
request: SendRequestT,
|
request: SendRequestT,
|
||||||
result_type: type[ReceiveResultT],
|
result_type: type[T],
|
||||||
request_read_timeout_seconds: timedelta | None = None,
|
request_read_timeout_seconds: timedelta | None = None,
|
||||||
metadata: MessageMetadata | None = None,
|
metadata: MessageMetadata | None = None,
|
||||||
) -> ReceiveResultT:
|
) -> T:
|
||||||
"""
|
"""
|
||||||
Sends a request and wait for a response. Raises an McpError if the
|
Sends a request and wait for a response. Raises an McpError if the
|
||||||
response contains an error. If a request read timeout is provided, it
|
response contains an error. If a request read timeout is provided, it
|
||||||
@@ -299,7 +291,7 @@ class BaseSession(
|
|||||||
)
|
)
|
||||||
self._write_stream.put(session_message)
|
self._write_stream.put(session_message)
|
||||||
|
|
||||||
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
|
def send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
|
||||||
if isinstance(response, ErrorData):
|
if isinstance(response, ErrorData):
|
||||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
||||||
@@ -346,6 +338,7 @@ class BaseSession(
|
|||||||
validated_request = self._receive_request_type.model_validate(
|
validated_request = self._receive_request_type.model_validate(
|
||||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||||
)
|
)
|
||||||
|
validated_request = cast(ReceiveRequestT, validated_request)
|
||||||
|
|
||||||
responder = RequestResponder[ReceiveRequestT, SendResultT](
|
responder = RequestResponder[ReceiveRequestT, SendResultT](
|
||||||
request_id=message.message.root.id,
|
request_id=message.message.root.id,
|
||||||
@@ -366,6 +359,7 @@ class BaseSession(
|
|||||||
notification = self._receive_notification_type.model_validate(
|
notification = self._receive_notification_type.model_validate(
|
||||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||||
)
|
)
|
||||||
|
notification = cast(ReceiveNotificationT, notification)
|
||||||
# Handle cancellation notifications
|
# Handle cancellation notifications
|
||||||
if isinstance(notification.root, CancelledNotification):
|
if isinstance(notification.root, CancelledNotification):
|
||||||
cancelled_id = notification.root.params.requestId
|
cancelled_id = notification.root.params.requestId
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
|
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
|
||||||
from pydantic.networks import AnyUrl, UrlConstraints
|
from pydantic.networks import AnyUrl, UrlConstraints
|
||||||
@@ -31,7 +31,7 @@ ProgressToken = str | int
|
|||||||
Cursor = str
|
Cursor = str
|
||||||
Role = Literal["user", "assistant"]
|
Role = Literal["user", "assistant"]
|
||||||
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
|
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
|
||||||
AnyFunction: TypeAlias = Callable[..., Any]
|
type AnyFunction = Callable[..., Any]
|
||||||
|
|
||||||
|
|
||||||
class RequestParams(BaseModel):
|
class RequestParams(BaseModel):
|
||||||
@@ -68,12 +68,7 @@ class NotificationParams(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
|
class Request[RequestParamsT: RequestParams | dict[str, Any] | None, MethodT: str](BaseModel):
|
||||||
NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None)
|
|
||||||
MethodT = TypeVar("MethodT", bound=str)
|
|
||||||
|
|
||||||
|
|
||||||
class Request(BaseModel, Generic[RequestParamsT, MethodT]):
|
|
||||||
"""Base class for JSON-RPC requests."""
|
"""Base class for JSON-RPC requests."""
|
||||||
|
|
||||||
method: MethodT
|
method: MethodT
|
||||||
@@ -81,14 +76,14 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
|
class PaginatedRequest[T: str](Request[PaginatedRequestParams | None, T]):
|
||||||
"""Base class for paginated requests,
|
"""Base class for paginated requests,
|
||||||
matching the schema's PaginatedRequest interface."""
|
matching the schema's PaginatedRequest interface."""
|
||||||
|
|
||||||
params: PaginatedRequestParams | None = None
|
params: PaginatedRequestParams | None = None
|
||||||
|
|
||||||
|
|
||||||
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
|
class Notification[NotificationParamsT: NotificationParams | dict[str, Any] | None, MethodT: str](BaseModel):
|
||||||
"""Base class for JSON-RPC notifications."""
|
"""Base class for JSON-RPC notifications."""
|
||||||
|
|
||||||
method: MethodT
|
method: MethodT
|
||||||
@@ -736,7 +731,7 @@ class ResourceLink(Resource):
|
|||||||
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
|
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
|
||||||
"""A content block that can be used in prompts and tool results."""
|
"""A content block that can be used in prompts and tool results."""
|
||||||
|
|
||||||
Content: TypeAlias = ContentBlock
|
type Content = ContentBlock
|
||||||
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""
|
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import queue
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, TypedDict
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from cachetools import LRUCache
|
from cachetools import LRUCache
|
||||||
@@ -14,7 +14,6 @@ from flask import current_app
|
|||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
||||||
from core.ops.entities.config_entity import (
|
from core.ops.entities.config_entity import (
|
||||||
@@ -464,7 +463,7 @@ class OpsTraceManager:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_ops_trace_instance(
|
def get_ops_trace_instance(
|
||||||
cls,
|
cls,
|
||||||
app_id: Union[UUID, str] | None = None,
|
app_id: UUID | str | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get ops trace through model config
|
Get ops trace through model config
|
||||||
@@ -717,7 +716,7 @@ class TraceTask:
|
|||||||
self,
|
self,
|
||||||
trace_type: Any,
|
trace_type: Any,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
workflow_execution: Optional["WorkflowExecution"] = None,
|
workflow_execution: "WorkflowExecution | None" = None,
|
||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
timer: Any | None = None,
|
timer: Any | None = None,
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Generic, TypeVar
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -19,9 +18,6 @@ class BaseBackwardsInvocation:
|
|||||||
yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode()
|
yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode()
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=dict | Mapping | str | bool | int | BaseModel)
|
class BaseBackwardsInvocationResponse[T: dict | Mapping | str | bool | int | BaseModel](BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class BaseBackwardsInvocationResponse(BaseModel, Generic[T]):
|
|
||||||
data: T | None = None
|
data: T | None = None
|
||||||
error: str = ""
|
error: str = ""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import enum
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any, Generic, TypeVar
|
from typing import Any
|
||||||
|
|
||||||
from graphon.model_runtime.entities.model_entities import AIModelEntity
|
from graphon.model_runtime.entities.model_entities import AIModelEntity
|
||||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||||
@@ -19,10 +19,8 @@ from core.tools.entities.common_entities import I18nObject
|
|||||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
||||||
from core.trigger.entities.entities import TriggerProviderEntity
|
from core.trigger.entities.entities import TriggerProviderEntity
|
||||||
|
|
||||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
|
||||||
|
|
||||||
|
class PluginDaemonBasicResponse[T: BaseModel | dict | list | bool | str](BaseModel):
|
||||||
class PluginDaemonBasicResponse(BaseModel, Generic[T]):
|
|
||||||
"""
|
"""
|
||||||
Basic response from plugin daemon.
|
Basic response from plugin daemon.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable, Generator
|
||||||
from typing import Any, TypeVar, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from graphon.model_runtime.errors.invoke import (
|
from graphon.model_runtime.errors.invoke import (
|
||||||
@@ -51,8 +51,6 @@ elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
|
|||||||
else:
|
else:
|
||||||
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
|
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
|
||||||
|
|
||||||
T = TypeVar("T", bound=(BaseModel | dict[str, Any] | list[Any] | bool | str))
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_httpx_client: httpx.Client = get_pooled_http_client(
|
_httpx_client: httpx.Client = get_pooled_http_client(
|
||||||
@@ -191,7 +189,7 @@ class BasePluginClient:
|
|||||||
logger.exception("Stream request to Plugin Daemon Service failed")
|
logger.exception("Stream request to Plugin Daemon Service failed")
|
||||||
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
|
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
|
||||||
|
|
||||||
def _stream_request_with_model(
|
def _stream_request_with_model[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
path: str,
|
path: str,
|
||||||
@@ -207,7 +205,7 @@ class BasePluginClient:
|
|||||||
for line in self._stream_request(method, path, params, headers, data, files):
|
for line in self._stream_request(method, path, params, headers, data, files):
|
||||||
yield type_(**json.loads(line)) # type: ignore
|
yield type_(**json.loads(line)) # type: ignore
|
||||||
|
|
||||||
def _request_with_model(
|
def _request_with_model[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
path: str,
|
path: str,
|
||||||
@@ -223,7 +221,7 @@ class BasePluginClient:
|
|||||||
response = self._request(method, path, headers, data, params, files)
|
response = self._request(method, path, headers, data, params, files)
|
||||||
return type_(**response.json()) # type: ignore[return-value]
|
return type_(**response.json()) # type: ignore[return-value]
|
||||||
|
|
||||||
def _request_with_plugin_daemon_response(
|
def _request_with_plugin_daemon_response[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
path: str,
|
path: str,
|
||||||
@@ -278,7 +276,7 @@ class BasePluginClient:
|
|||||||
|
|
||||||
return rep.data
|
return rep.data
|
||||||
|
|
||||||
def _request_with_plugin_daemon_response_stream(
|
def _request_with_plugin_daemon_response_stream[T: BaseModel | dict[str, Any] | list[Any] | bool | str](
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
path: str,
|
path: str,
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TypeVar, Union
|
|
||||||
|
|
||||||
from core.agent.entities import AgentInvokeMessage
|
from core.agent.entities import AgentInvokeMessage
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
|
||||||
MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage])
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FileChunk:
|
class FileChunk:
|
||||||
@@ -22,11 +19,11 @@ class FileChunk:
|
|||||||
self.data = bytearray(self.total_length)
|
self.data = bytearray(self.total_length)
|
||||||
|
|
||||||
|
|
||||||
def merge_blob_chunks(
|
def merge_blob_chunks[T: ToolInvokeMessage | AgentInvokeMessage](
|
||||||
response: Generator[MessageType, None, None],
|
response: Generator[T, None, None],
|
||||||
max_file_size: int = 30 * 1024 * 1024,
|
max_file_size: int = 30 * 1024 * 1024,
|
||||||
max_chunk_size: int = 8192,
|
max_chunk_size: int = 8192,
|
||||||
) -> Generator[MessageType, None, None]:
|
) -> Generator[T, None, None]:
|
||||||
"""
|
"""
|
||||||
Merge streaming blob chunks into complete blob messages.
|
Merge streaming blob chunks into complete blob messages.
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
from graphon.model_runtime.entities.model_entities import ModelType
|
from graphon.model_runtime.entities.model_entities import ModelType
|
||||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
from core.rag.data_post_processor.reorder import ReorderRunner
|
from core.rag.data_post_processor.reorder import ReorderRunner
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Any, NotRequired
|
from typing import Any, NotRequired, TypedDict
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from graphon.model_runtime.entities.model_entities import ModelType
|
from graphon.model_runtime.entities.model_entities import ModelType
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session, load_only
|
from sqlalchemy.orm import Session, load_only
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.db.session_factory import session_factory
|
from core.db.session_factory import session_factory
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Concatenate, ParamSpec, TypeVar
|
from typing import Any, Concatenate
|
||||||
|
|
||||||
from mo_vector.client import MoVectorClient # type: ignore
|
from mo_vector.client import MoVectorClient # type: ignore
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
@@ -20,15 +20,12 @@ from models.dataset import Dataset
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
T = TypeVar("T", bound="MatrixoneVector")
|
def ensure_client[T: MatrixoneVector, **P, R](
|
||||||
|
func: Callable[Concatenate[T, P], R],
|
||||||
|
) -> Callable[Concatenate[T, P], R]:
|
||||||
def ensure_client(func: Callable[Concatenate[T, P], R]):
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
|
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
if self.client is None:
|
if self.client is None:
|
||||||
self.client = self._get_client(None, False)
|
self.client = self._get_client(None, False)
|
||||||
return func(self, *args, **kwargs)
|
return func(self, *args, **kwargs)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Iterable, Sequence
|
from collections.abc import Generator, Iterable, Sequence
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from typing import TYPE_CHECKING, Any, Union
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import qdrant_client
|
import qdrant_client
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
@@ -36,8 +36,8 @@ if TYPE_CHECKING:
|
|||||||
from qdrant_client.conversions import common_types
|
from qdrant_client.conversions import common_types
|
||||||
from qdrant_client.http import models as rest
|
from qdrant_client.http import models as rest
|
||||||
|
|
||||||
DictFilter = dict[str, Union[str, int, bool, dict, list]]
|
type DictFilter = dict[str, str | int | bool | dict | list]
|
||||||
MetadataFilter = Union[DictFilter, common_types.Filter]
|
type MetadataFilter = DictFilter | common_types.Filter
|
||||||
|
|
||||||
|
|
||||||
class PathQdrantParams(BaseModel):
|
class PathQdrantParams(BaseModel):
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Iterable, Sequence
|
from collections.abc import Generator, Iterable, Sequence
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from typing import TYPE_CHECKING, Any, Union
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import qdrant_client
|
import qdrant_client
|
||||||
@@ -40,8 +40,8 @@ if TYPE_CHECKING:
|
|||||||
from qdrant_client.conversions import common_types
|
from qdrant_client.conversions import common_types
|
||||||
from qdrant_client.http import models as rest
|
from qdrant_client.http import models as rest
|
||||||
|
|
||||||
DictFilter = dict[str, Union[str, int, bool, dict, list]]
|
type DictFilter = dict[str, str | int | bool | dict | list]
|
||||||
MetadataFilter = Union[DictFilter, common_types.Filter]
|
type MetadataFilter = DictFilter | common_types.Filter
|
||||||
|
|
||||||
|
|
||||||
class TidbOnQdrantConfig(BaseModel):
|
class TidbOnQdrantConfig(BaseModel):
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from models.dataset import DocumentSegment
|
from models.dataset import DocumentSegment
|
||||||
|
|
||||||
|
|||||||
@@ -12,11 +12,11 @@ import mimetypes
|
|||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from io import BufferedReader, BytesIO
|
from io import BufferedReader, BytesIO
|
||||||
from pathlib import Path, PurePath
|
from pathlib import Path, PurePath
|
||||||
from typing import Any, Union
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, model_validator
|
from pydantic import BaseModel, ConfigDict, model_validator
|
||||||
|
|
||||||
PathLike = Union[str, PurePath]
|
type PathLike = str | PurePath
|
||||||
|
|
||||||
|
|
||||||
class Blob(BaseModel):
|
class Blob(BaseModel):
|
||||||
@@ -29,7 +29,7 @@ class Blob(BaseModel):
|
|||||||
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
|
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: Union[bytes, str, None] = None # Raw data
|
data: bytes | str | None = None # Raw data
|
||||||
mimetype: str | None = None # Not to be confused with a file extension
|
mimetype: str | None = None # Not to be confused with a file extension
|
||||||
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
|
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
|
||||||
# Location where the original content was found
|
# Location where the original content was found
|
||||||
@@ -75,7 +75,7 @@ class Blob(BaseModel):
|
|||||||
raise ValueError(f"Unable to get bytes for blob {self}")
|
raise ValueError(f"Unable to get bytes for blob {self}")
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]:
|
def as_bytes_io(self) -> Generator[BytesIO | BufferedReader, None, None]:
|
||||||
"""Read data as a byte stream."""
|
"""Read data as a byte stream."""
|
||||||
if isinstance(self.data, bytes):
|
if isinstance(self.data, bytes):
|
||||||
yield BytesIO(self.data)
|
yield BytesIO(self.data)
|
||||||
@@ -117,7 +117,7 @@ class Blob(BaseModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_data(
|
def from_data(
|
||||||
cls,
|
cls,
|
||||||
data: Union[str, bytes],
|
data: str | bytes,
|
||||||
*,
|
*,
|
||||||
encoding: str = "utf-8",
|
encoding: str = "utf-8",
|
||||||
mime_type: str | None = None,
|
mime_type: str | None = None,
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import Any, NotRequired, cast
|
from typing import Any, NotRequired, TypedDict, cast
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any, Union
|
from typing import Any, TypedDict
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from httpx import Response
|
from httpx import Response
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from core.rag.extractor.watercrawl.exceptions import (
|
from core.rag.extractor.watercrawl.exceptions import (
|
||||||
WaterCrawlAuthenticationError,
|
WaterCrawlAuthenticationError,
|
||||||
@@ -142,7 +141,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
|
|||||||
|
|
||||||
def create_crawl_request(
|
def create_crawl_request(
|
||||||
self,
|
self,
|
||||||
url: Union[list, str] | None = None,
|
url: list | str | None = None,
|
||||||
spider_options: SpiderOptions | None = None,
|
spider_options: SpiderOptions | None = None,
|
||||||
page_options: PageOptions | None = None,
|
page_options: PageOptions | None = None,
|
||||||
plugin_options: dict[str, Any] | None = None,
|
plugin_options: dict[str, Any] | None = None,
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from core.rag.extractor.watercrawl.client import PageOptions, SpiderOptions, WaterCrawlAPIClient
|
from core.rag.extractor.watercrawl.client import PageOptions, SpiderOptions, WaterCrawlAPIClient
|
||||||
|
|
||||||
|
|||||||
@@ -7,12 +7,11 @@ import os
|
|||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import TYPE_CHECKING, Any, NotRequired, Optional
|
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
|
||||||
from urllib.parse import unquote, urlparse
|
from urllib.parse import unquote, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.entities.knowledge_entities import PreviewDetail
|
from core.entities.knowledge_entities import PreviewDetail
|
||||||
@@ -118,11 +117,12 @@ class BaseIndexProcessor(ABC):
|
|||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
chunk_overlap: int,
|
chunk_overlap: int,
|
||||||
separator: str,
|
separator: str,
|
||||||
embedding_model_instance: Optional["ModelInstance"],
|
embedding_model_instance: "ModelInstance | None",
|
||||||
) -> TextSplitter:
|
) -> TextSplitter:
|
||||||
"""
|
"""
|
||||||
Get the NodeParser object according to the processing rule.
|
Get the NodeParser object according to the processing rule.
|
||||||
"""
|
"""
|
||||||
|
character_splitter: TextSplitter
|
||||||
if processing_rule_mode in ["custom", "hierarchical"]:
|
if processing_rule_mode in ["custom", "hierarchical"]:
|
||||||
# The user-defined segmentation rule
|
# The user-defined segmentation rule
|
||||||
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||||
@@ -148,7 +148,7 @@ class BaseIndexProcessor(ABC):
|
|||||||
embedding_model_instance=embedding_model_instance,
|
embedding_model_instance=embedding_model_instance,
|
||||||
)
|
)
|
||||||
|
|
||||||
return character_splitter # type: ignore
|
return character_splitter
|
||||||
|
|
||||||
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
|
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,19 +4,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import codecs
|
import codecs
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from collections.abc import Collection
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
|
from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
|
||||||
|
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.rag.splitter.text_splitter import (
|
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
|
||||||
TS,
|
|
||||||
Collection,
|
|
||||||
Literal,
|
|
||||||
RecursiveCharacterTextSplitter,
|
|
||||||
Set,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||||
@@ -25,13 +19,13 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_encoder(
|
def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
|
||||||
cls: type[TS],
|
cls: type[T],
|
||||||
embedding_model_instance: ModelInstance | None,
|
embedding_model_instance: ModelInstance | None,
|
||||||
allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037
|
allowed_special: Literal["all"] | set[str] = set(),
|
||||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
|
disallowed_special: Literal["all"] | Collection[str] = "all",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
) -> T:
|
||||||
def _token_encoder(texts: list[str]) -> list[int]:
|
def _token_encoder(texts: list[str]) -> list[int]:
|
||||||
if not texts:
|
if not texts:
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -6,19 +6,12 @@ import re
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable, Collection, Iterable, Sequence, Set
|
from collections.abc import Callable, Collection, Iterable, Sequence, Set
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import Any, Literal
|
||||||
Any,
|
|
||||||
Literal,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from core.rag.models.document import BaseDocumentTransformer, Document
|
from core.rag.models.document import BaseDocumentTransformer, Document
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TS = TypeVar("TS", bound="TextSplitter")
|
|
||||||
|
|
||||||
|
|
||||||
def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
|
def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
|
||||||
# Now that we have the separator, split the text
|
# Now that we have the separator, split the text
|
||||||
@@ -194,8 +187,8 @@ class TokenTextSplitter(TextSplitter):
|
|||||||
self,
|
self,
|
||||||
encoding_name: str = "gpt2",
|
encoding_name: str = "gpt2",
|
||||||
model_name: str | None = None,
|
model_name: str | None = None,
|
||||||
allowed_special: Union[Literal["all"], Set[str]] = set(),
|
allowed_special: Literal["all"] | Set[str] = set(),
|
||||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
disallowed_special: Literal["all"] | Collection[str] = "all",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""Create a new TextSplitter."""
|
"""Create a new TextSplitter."""
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ providing improved performance by offloading database operations to background w
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from graphon.entities import WorkflowExecution
|
from graphon.entities import WorkflowExecution
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
@@ -47,7 +46,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
session_factory: sessionmaker | Engine,
|
session_factory: sessionmaker | Engine,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
app_id: str | None,
|
app_id: str | None,
|
||||||
triggered_from: WorkflowRunTriggeredFrom | None,
|
triggered_from: WorkflowRunTriggeredFrom | None,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ providing improved performance by offloading database operations to background w
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from graphon.entities import WorkflowNodeExecution
|
from graphon.entities import WorkflowNodeExecution
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
@@ -54,7 +53,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
session_factory: sessionmaker | Engine,
|
session_factory: sessionmaker | Engine,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
app_id: str | None,
|
app_id: str | None,
|
||||||
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
|
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ allowing users to configure different repository backends through string paths.
|
|||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Literal, Protocol, Union
|
from typing import Literal, Protocol
|
||||||
|
|
||||||
from graphon.entities import WorkflowExecution, WorkflowNodeExecution
|
from graphon.entities import WorkflowExecution, WorkflowNodeExecution
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
@@ -61,8 +61,8 @@ class DifyCoreRepositoryFactory:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def create_workflow_execution_repository(
|
def create_workflow_execution_repository(
|
||||||
cls,
|
cls,
|
||||||
session_factory: Union[sessionmaker, Engine],
|
session_factory: sessionmaker | Engine,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
app_id: str,
|
app_id: str,
|
||||||
triggered_from: WorkflowRunTriggeredFrom,
|
triggered_from: WorkflowRunTriggeredFrom,
|
||||||
) -> WorkflowExecutionRepository:
|
) -> WorkflowExecutionRepository:
|
||||||
@@ -97,8 +97,8 @@ class DifyCoreRepositoryFactory:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def create_workflow_node_execution_repository(
|
def create_workflow_node_execution_repository(
|
||||||
cls,
|
cls,
|
||||||
session_factory: Union[sessionmaker, Engine],
|
session_factory: sessionmaker | Engine,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
app_id: str,
|
app_id: str,
|
||||||
triggered_from: WorkflowNodeExecutionTriggeredFrom,
|
triggered_from: WorkflowNodeExecutionTriggeredFrom,
|
||||||
) -> WorkflowNodeExecutionRepository:
|
) -> WorkflowNodeExecutionRepository:
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ SQLAlchemy implementation of the WorkflowExecutionRepository.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from graphon.entities import WorkflowExecution
|
from graphon.entities import WorkflowExecution
|
||||||
from graphon.enums import WorkflowExecutionStatus, WorkflowType
|
from graphon.enums import WorkflowExecutionStatus, WorkflowType
|
||||||
@@ -40,7 +39,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
session_factory: sessionmaker | Engine,
|
session_factory: sessionmaker | Engine,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
app_id: str | None,
|
app_id: str | None,
|
||||||
triggered_from: WorkflowRunTriggeredFrom | None,
|
triggered_from: WorkflowRunTriggeredFrom | None,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable, Mapping, Sequence
|
from collections.abc import Callable, Mapping, Sequence
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Any, TypeVar, Union
|
from typing import Any
|
||||||
|
|
||||||
import psycopg2.errors
|
import psycopg2.errors
|
||||||
from graphon.entities import WorkflowNodeExecution
|
from graphon.entities import WorkflowNodeExecution
|
||||||
@@ -63,7 +63,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
session_factory: sessionmaker | Engine,
|
session_factory: sessionmaker | Engine,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
app_id: str | None,
|
app_id: str | None,
|
||||||
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
|
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
|
||||||
):
|
):
|
||||||
@@ -551,10 +551,7 @@ def _deterministic_json_dump(value: Mapping[str, Any]) -> str:
|
|||||||
return json.dumps(value, sort_keys=True)
|
return json.dumps(value, sort_keys=True)
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
def _find_first[T](seq: Sequence[T], pred: Callable[[T], bool]) -> T | None:
|
||||||
|
|
||||||
|
|
||||||
def _find_first(seq: Sequence[_T], pred: Callable[[_T], bool]) -> _T | None:
|
|
||||||
filtered = [i for i in seq if pred(i)]
|
filtered = [i for i in seq if pred(i)]
|
||||||
if filtered:
|
if filtered:
|
||||||
return filtered[0]
|
return filtered[0]
|
||||||
|
|||||||
@@ -3,15 +3,15 @@ import re
|
|||||||
import threading
|
import threading
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Union
|
from typing import Any
|
||||||
|
|
||||||
from core.schemas.registry import SchemaRegistry
|
from core.schemas.registry import SchemaRegistry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Type aliases for better clarity
|
# Type aliases for better clarity
|
||||||
SchemaType = Union[dict[str, Any], list[Any], str, int, float, bool, None]
|
type SchemaType = dict[str, Any] | list[Any] | str | int | float | bool | None
|
||||||
SchemaDict = dict[str, Any]
|
type SchemaDict = dict[str, Any]
|
||||||
|
|
||||||
# Pre-compiled pattern for better performance
|
# Pre-compiled pattern for better performance
|
||||||
_DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$")
|
_DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$")
|
||||||
@@ -54,7 +54,7 @@ class QueueItem:
|
|||||||
|
|
||||||
current: Any
|
current: Any
|
||||||
parent: Any | None
|
parent: Any | None
|
||||||
key: Union[str, int] | None
|
key: str | int | None
|
||||||
depth: int
|
depth: int
|
||||||
ref_path: set[str]
|
ref_path: set[str]
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from typing import TypeVar
|
|
||||||
|
|
||||||
from redis import RedisError
|
from redis import RedisError
|
||||||
|
|
||||||
@@ -11,8 +10,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
TRIGGER_DEBUG_EVENT_TTL = 300
|
TRIGGER_DEBUG_EVENT_TTL = 300
|
||||||
|
|
||||||
TTriggerDebugEvent = TypeVar("TTriggerDebugEvent", bound="BaseDebugEvent")
|
|
||||||
|
|
||||||
|
|
||||||
class TriggerDebugEventBus:
|
class TriggerDebugEventBus:
|
||||||
"""
|
"""
|
||||||
@@ -81,15 +78,15 @@ class TriggerDebugEventBus:
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def poll(
|
def poll[T: BaseDebugEvent](
|
||||||
cls,
|
cls,
|
||||||
event_type: type[TTriggerDebugEvent],
|
event_type: type[T],
|
||||||
pool_key: str,
|
pool_key: str,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
app_id: str,
|
app_id: str,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
) -> TTriggerDebugEvent | None:
|
) -> T | None:
|
||||||
"""
|
"""
|
||||||
Poll for an event or register to the waiting pool.
|
Poll for an event or register to the waiting pool.
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import importlib
|
|||||||
import pkgutil
|
import pkgutil
|
||||||
from collections.abc import Callable, Iterator, Mapping, MutableMapping
|
from collections.abc import Callable, Iterator, Mapping, MutableMapping
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
|
from typing import TYPE_CHECKING, Any, cast, final, override
|
||||||
|
|
||||||
from graphon.entities.base_node_data import BaseNodeData
|
from graphon.entities.base_node_data import BaseNodeData
|
||||||
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||||
@@ -22,7 +22,6 @@ from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeDat
|
|||||||
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
|
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||||
@@ -192,7 +191,7 @@ class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Nod
|
|||||||
NODE_TYPE_CLASSES_MAPPING: MutableMapping[NodeType, Mapping[str, type[Node]]] = _LazyNodeTypeClassesMapping()
|
NODE_TYPE_CLASSES_MAPPING: MutableMapping[NodeType, Mapping[str, type[Node]]] = _LazyNodeTypeClassesMapping()
|
||||||
|
|
||||||
|
|
||||||
LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
|
type LLMCompatibleNodeData = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
|
||||||
|
|
||||||
|
|
||||||
def fetch_memory(
|
def fetch_memory(
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import ssl
|
import ssl
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union
|
from typing import TYPE_CHECKING, Any, Union
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
from redis import RedisError
|
from redis import RedisError
|
||||||
@@ -297,12 +297,7 @@ def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
|
|||||||
return RedisBroadcastChannel(_pubsub_redis_client)
|
return RedisBroadcastChannel(_pubsub_redis_client)
|
||||||
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
def redis_fallback[T](default_return: T | None = None): # type: ignore
|
||||||
R = TypeVar("R")
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
def redis_fallback(default_return: T | None = None): # type: ignore
|
|
||||||
"""
|
"""
|
||||||
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
|
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
|
||||||
|
|
||||||
@@ -310,9 +305,9 @@ def redis_fallback(default_return: T | None = None): # type: ignore
|
|||||||
default_return: The value to return when a Redis operation fails. Defaults to None.
|
default_return: The value to return when a Redis operation fails. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[P, R]):
|
def decorator[**P, R](func: Callable[P, R]) -> Callable[P, R | T | None]:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def wrapper(*args: P.args, **kwargs: P.kwargs):
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | T | None:
|
||||||
try:
|
try:
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
except RedisError as e:
|
except RedisError as e:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from graphon.entities import WorkflowExecution
|
from graphon.entities import WorkflowExecution
|
||||||
from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||||
@@ -27,7 +26,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
session_factory: sessionmaker | Engine,
|
session_factory: sessionmaker | Engine,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
app_id: str | None,
|
app_id: str | None,
|
||||||
triggered_from: WorkflowRunTriggeredFrom | None,
|
triggered_from: WorkflowRunTriggeredFrom | None,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Union
|
from typing import Any
|
||||||
|
|
||||||
from graphon.entities import WorkflowNodeExecution
|
from graphon.entities import WorkflowNodeExecution
|
||||||
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||||
@@ -109,7 +109,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
session_factory: sessionmaker | Engine,
|
session_factory: sessionmaker | Engine,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
app_id: str | None,
|
app_id: str | None,
|
||||||
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
|
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import functools
|
import functools
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import ParamSpec, TypeVar, cast
|
from typing import cast
|
||||||
|
|
||||||
from opentelemetry.trace import get_tracer
|
from opentelemetry.trace import get_tracer
|
||||||
|
|
||||||
@@ -8,9 +8,6 @@ from configs import dify_config
|
|||||||
from extensions.otel.decorators.handler import SpanHandler
|
from extensions.otel.decorators.handler import SpanHandler
|
||||||
from extensions.otel.runtime import is_instrument_flag_enabled
|
from extensions.otel.runtime import is_instrument_flag_enabled
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
_HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()}
|
_HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()}
|
||||||
|
|
||||||
|
|
||||||
@@ -21,7 +18,7 @@ def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler:
|
|||||||
return _HANDLER_INSTANCES[handler_class]
|
return _HANDLER_INSTANCES[handler_class]
|
||||||
|
|
||||||
|
|
||||||
def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
"""
|
"""
|
||||||
Decorator that traces a function with an OpenTelemetry span.
|
Decorator that traces a function with an OpenTelemetry span.
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
import inspect
|
import inspect
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Callable, Mapping
|
||||||
from typing import Any, TypeVar
|
from typing import Any
|
||||||
|
|
||||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||||
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
|
||||||
class SpanHandler:
|
class SpanHandler:
|
||||||
"""
|
"""
|
||||||
@@ -31,9 +29,9 @@ class SpanHandler:
|
|||||||
"""
|
"""
|
||||||
return f"{wrapped.__module__}.{wrapped.__qualname__}"
|
return f"{wrapped.__module__}.{wrapped.__qualname__}"
|
||||||
|
|
||||||
def _extract_arguments(
|
def _extract_arguments[T](
|
||||||
self,
|
self,
|
||||||
wrapped: Callable[..., R],
|
wrapped: Callable[..., T],
|
||||||
args: tuple[object, ...],
|
args: tuple[object, ...],
|
||||||
kwargs: Mapping[str, object],
|
kwargs: Mapping[str, object],
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
@@ -61,13 +59,13 @@ class SpanHandler:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def wrapper(
|
def wrapper[T](
|
||||||
self,
|
self,
|
||||||
tracer: Any,
|
tracer: Any,
|
||||||
wrapped: Callable[..., R],
|
wrapped: Callable[..., T],
|
||||||
args: tuple[object, ...],
|
args: tuple[object, ...],
|
||||||
kwargs: Mapping[str, object],
|
kwargs: Mapping[str, object],
|
||||||
) -> R:
|
) -> T:
|
||||||
"""
|
"""
|
||||||
Fully control the wrapper behavior.
|
Fully control the wrapper behavior.
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Callable, Mapping
|
||||||
from typing import Any, TypeVar
|
from typing import Any
|
||||||
|
|
||||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||||
from opentelemetry.util.types import AttributeValue
|
from opentelemetry.util.types import AttributeValue
|
||||||
@@ -12,19 +12,16 @@ from models.model import Account
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
|
||||||
class AppGenerateHandler(SpanHandler):
|
class AppGenerateHandler(SpanHandler):
|
||||||
"""Span handler for ``AppGenerateService.generate``."""
|
"""Span handler for ``AppGenerateService.generate``."""
|
||||||
|
|
||||||
def wrapper(
|
def wrapper[T](
|
||||||
self,
|
self,
|
||||||
tracer: Any,
|
tracer: Any,
|
||||||
wrapped: Callable[..., R],
|
wrapped: Callable[..., T],
|
||||||
args: tuple[object, ...],
|
args: tuple[object, ...],
|
||||||
kwargs: Mapping[str, object],
|
kwargs: Mapping[str, object],
|
||||||
) -> R:
|
) -> T:
|
||||||
try:
|
try:
|
||||||
arguments = self._extract_arguments(wrapped, args, kwargs)
|
arguments = self._extract_arguments(wrapped, args, kwargs)
|
||||||
if not arguments:
|
if not arguments:
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, TypeAlias
|
from typing import Any
|
||||||
|
|
||||||
from graphon.file import File
|
from graphon.file import File
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||||
|
|
||||||
JSONValue: TypeAlias = Any
|
type JSONValue = Any
|
||||||
|
|
||||||
|
|
||||||
class ResponseModel(BaseModel):
|
class ResponseModel(BaseModel):
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TypeAlias
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from graphon.file import File
|
from graphon.file import File
|
||||||
@@ -10,7 +9,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|||||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
||||||
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
|
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
|
||||||
|
|
||||||
JSONValueType: TypeAlias = JSONValue
|
type JSONValueType = JSONValue
|
||||||
|
|
||||||
|
|
||||||
class ResponseModel(BaseModel):
|
class ResponseModel(BaseModel):
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
import contextvars
|
import contextvars
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, TypeVar
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from flask import Flask, g
|
from flask import Flask, g
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from models import Account, EndUser
|
from models import Account, EndUser
|
||||||
|
|
||||||
|
|||||||
@@ -42,13 +42,7 @@ def current_account_with_tenant():
|
|||||||
return user, user.current_tenant_id
|
return user, user.current_tenant_id
|
||||||
|
|
||||||
|
|
||||||
from typing import ParamSpec, TypeVar
|
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
|
||||||
def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
|
|
||||||
"""
|
"""
|
||||||
If you decorate a view with this, it will ensure that the current user is
|
If you decorate a view with this, it will ensure that the current user is
|
||||||
logged in and authenticated before calling the actual view. (If they are
|
logged in and authenticated before calling the actual view. (If they are
|
||||||
|
|||||||
@@ -1,26 +1,20 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import NotRequired
|
from typing import NotRequired, TypedDict
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import TypeAdapter, ValidationError
|
from pydantic import TypeAdapter, ValidationError
|
||||||
|
|
||||||
from core.helper.http_client_pooling import get_pooled_http_client
|
from core.helper.http_client_pooling import get_pooled_http_client
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
|
||||||
from typing import TypedDict
|
|
||||||
else:
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
JsonObject = dict[str, object]
|
type JsonObject = dict[str, object]
|
||||||
JsonObjectList = list[JsonObject]
|
type JsonObjectList = list[JsonObject]
|
||||||
|
|
||||||
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
|
JSON_OBJECT_ADAPTER: TypeAdapter[JsonObject] = TypeAdapter(JsonObject)
|
||||||
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
|
JSON_OBJECT_LIST_ADAPTER: TypeAdapter[JsonObjectList] = TypeAdapter(JsonObjectList)
|
||||||
|
|
||||||
# Reuse a pooled httpx.Client for OAuth flows (public endpoints, no SSRF proxy).
|
# Reuse a pooled httpx.Client for OAuth flows (public endpoints, no SSRF proxy).
|
||||||
_http_client: httpx.Client = get_pooled_http_client(
|
_http_client: httpx.Client = get_pooled_http_client(
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import sys
|
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, TypedDict
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
@@ -12,11 +11,6 @@ from extensions.ext_database import db
|
|||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from models.source import DataSourceOauthBinding
|
from models.source import DataSourceOauthBinding
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
|
||||||
from typing import TypedDict
|
|
||||||
else:
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
|
|
||||||
class NotionPageSummary(TypedDict):
|
class NotionPageSummary(TypedDict):
|
||||||
page_id: str
|
page_id: str
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
|||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
|
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
@@ -19,7 +19,6 @@ from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
|||||||
from graphon.file import helpers as file_helpers
|
from graphon.file import helpers as file_helpers
|
||||||
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import enum
|
import enum
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Generic, TypeVar
|
from typing import Any
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator
|
from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator
|
||||||
@@ -110,17 +110,14 @@ class AdjustedJSON(TypeDecorator[dict | list | None]):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
_E = TypeVar("_E", bound=enum.StrEnum)
|
class EnumText[T: enum.StrEnum](TypeDecorator[T | None]):
|
||||||
|
|
||||||
|
|
||||||
class EnumText(TypeDecorator[_E | None], Generic[_E]):
|
|
||||||
impl = VARCHAR
|
impl = VARCHAR
|
||||||
cache_ok = True
|
cache_ok = True
|
||||||
|
|
||||||
_length: int
|
_length: int
|
||||||
_enum_class: type[_E]
|
_enum_class: type[T]
|
||||||
|
|
||||||
def __init__(self, enum_class: type[_E], length: int | None = None):
|
def __init__(self, enum_class: type[T], length: int | None = None):
|
||||||
self._enum_class = enum_class
|
self._enum_class = enum_class
|
||||||
max_enum_value_len = max(len(e.value) for e in enum_class)
|
max_enum_value_len = max(len(e.value) for e in enum_class)
|
||||||
if length is not None:
|
if length is not None:
|
||||||
@@ -131,25 +128,25 @@ class EnumText(TypeDecorator[_E | None], Generic[_E]):
|
|||||||
# leave some rooms for future longer enum values.
|
# leave some rooms for future longer enum values.
|
||||||
self._length = max(max_enum_value_len, 20)
|
self._length = max(max_enum_value_len, 20)
|
||||||
|
|
||||||
def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None:
|
def process_bind_param(self, value: T | str | None, dialect: Dialect) -> str | None:
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
if isinstance(value, self._enum_class):
|
if isinstance(value, self._enum_class):
|
||||||
return value.value
|
return value.value
|
||||||
# Since _E is bound to StrEnum which inherits from str, at this point value must be str
|
# Since T is bound to StrEnum which inherits from str, at this point value must be str
|
||||||
self._enum_class(value)
|
self._enum_class(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||||
return dialect.type_descriptor(VARCHAR(self._length))
|
return dialect.type_descriptor(VARCHAR(self._length))
|
||||||
|
|
||||||
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
|
def process_result_value(self, value: str | None, dialect: Dialect) -> T | None:
|
||||||
if value is None or value == "":
|
if value is None or value == "":
|
||||||
return None
|
return None
|
||||||
# Type annotation guarantees value is str at this point
|
# Type annotation guarantees value is str at this point
|
||||||
return self._enum_class(value)
|
return self._enum_class(value)
|
||||||
|
|
||||||
def compare_values(self, x: _E | None, y: _E | None) -> bool:
|
def compare_values(self, x: T | None, y: T | None) -> bool:
|
||||||
if x is None or y is None:
|
if x is None or y is None:
|
||||||
return x is y
|
return x is y
|
||||||
return x == y
|
return x == y
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dify-api"
|
name = "dify-api"
|
||||||
version = "1.13.3"
|
version = "1.13.3"
|
||||||
requires-python = ">=3.11,<3.13"
|
requires-python = "~=3.12.0"
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aliyun-log-python-sdk~=0.9.37",
|
"aliyun-log-python-sdk~=0.9.37",
|
||||||
@@ -232,5 +232,5 @@ vdb = [
|
|||||||
project-includes = ["."]
|
project-includes = ["."]
|
||||||
project-excludes = [".venv", "migrations/"]
|
project-excludes = [".venv", "migrations/"]
|
||||||
python-platform = "linux"
|
python-platform = "linux"
|
||||||
python-version = "3.11.0"
|
python-version = "3.12.0"
|
||||||
infer-with-first-use = false
|
infer-with-first-use = false
|
||||||
|
|||||||
@@ -50,6 +50,6 @@
|
|||||||
"reportUntypedFunctionDecorator": "hint",
|
"reportUntypedFunctionDecorator": "hint",
|
||||||
"reportUnnecessaryTypeIgnoreComment": "hint",
|
"reportUnnecessaryTypeIgnoreComment": "hint",
|
||||||
"reportAttributeAccessIssue": "hint",
|
"reportAttributeAccessIssue": "hint",
|
||||||
"pythonVersion": "3.11",
|
"pythonVersion": "3.12",
|
||||||
"pythonPlatform": "All"
|
"pythonPlatform": "All"
|
||||||
}
|
}
|
||||||
@@ -5,12 +5,11 @@ import secrets
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import Any, cast
|
from typing import Any, TypedDict, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, TypeAdapter
|
from pydantic import BaseModel, TypeAdapter
|
||||||
from sqlalchemy import func, select
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
|
|
||||||
class InvitationData(TypedDict):
|
class InvitationData(TypedDict):
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
|
|
||||||
class AuthCredentials(TypedDict):
|
class AuthCredentials(TypedDict):
|
||||||
|
|||||||
@@ -2,13 +2,12 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Literal
|
from typing import Literal, TypedDict
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
||||||
from typing_extensions import TypedDict
|
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
from core.helper.http_client_pooling import get_pooled_http_client
|
from core.helper.http_client_pooling import get_pooled_http_client
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any, Union
|
from typing import Any
|
||||||
|
|
||||||
from graphon.variables.types import SegmentType
|
from graphon.variables.types import SegmentType
|
||||||
from sqlalchemy import asc, desc, func, or_, select
|
from sqlalchemy import asc, desc, func, or_, select
|
||||||
@@ -37,7 +37,7 @@ class ConversationService:
|
|||||||
*,
|
*,
|
||||||
session: Session,
|
session: Session,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser] | None,
|
user: Account | EndUser | None,
|
||||||
last_id: str | None,
|
last_id: str | None,
|
||||||
limit: int,
|
limit: int,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
@@ -119,7 +119,7 @@ class ConversationService:
|
|||||||
cls,
|
cls,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
user: Union[Account, EndUser] | None,
|
user: Account | EndUser | None,
|
||||||
name: str | None,
|
name: str | None,
|
||||||
auto_generate: bool,
|
auto_generate: bool,
|
||||||
):
|
):
|
||||||
@@ -159,7 +159,7 @@ class ConversationService:
|
|||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
def get_conversation(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
|
||||||
conversation = db.session.scalar(
|
conversation = db.session.scalar(
|
||||||
select(Conversation)
|
select(Conversation)
|
||||||
.where(
|
.where(
|
||||||
@@ -179,7 +179,7 @@ class ConversationService:
|
|||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
def delete(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
|
||||||
"""
|
"""
|
||||||
Delete a conversation only if it belongs to the given user and app context.
|
Delete a conversation only if it belongs to the given user and app context.
|
||||||
|
|
||||||
@@ -209,7 +209,7 @@ class ConversationService:
|
|||||||
cls,
|
cls,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
user: Union[Account, EndUser] | None,
|
user: Account | EndUser | None,
|
||||||
limit: int,
|
limit: int,
|
||||||
last_id: str | None,
|
last_id: str | None,
|
||||||
variable_name: str | None = None,
|
variable_name: str | None = None,
|
||||||
@@ -278,7 +278,7 @@ class ConversationService:
|
|||||||
app_model: App,
|
app_model: App,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
variable_id: str,
|
variable_id: str,
|
||||||
user: Union[Account, EndUser] | None,
|
user: Account | EndUser | None,
|
||||||
new_value: Any,
|
new_value: Any,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from graphon.model_runtime.entities.model_entities import ModelType
|
from graphon.model_runtime.entities.model_entities import ModelType
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
@@ -57,7 +56,7 @@ class MessageService:
|
|||||||
def pagination_by_first_id(
|
def pagination_by_first_id(
|
||||||
cls,
|
cls,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser] | None,
|
user: Account | EndUser | None,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
first_id: str | None,
|
first_id: str | None,
|
||||||
limit: int,
|
limit: int,
|
||||||
@@ -117,7 +116,7 @@ class MessageService:
|
|||||||
def pagination_by_last_id(
|
def pagination_by_last_id(
|
||||||
cls,
|
cls,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser] | None,
|
user: Account | EndUser | None,
|
||||||
last_id: str | None,
|
last_id: str | None,
|
||||||
limit: int,
|
limit: int,
|
||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
@@ -170,7 +169,7 @@ class MessageService:
|
|||||||
*,
|
*,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
user: Union[Account, EndUser] | None,
|
user: Account | EndUser | None,
|
||||||
rating: FeedbackRating | None,
|
rating: FeedbackRating | None,
|
||||||
content: str | None,
|
content: str | None,
|
||||||
):
|
):
|
||||||
@@ -221,7 +220,7 @@ class MessageService:
|
|||||||
return [record.to_dict() for record in feedbacks]
|
return [record.to_dict() for record in feedbacks]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
|
def get_message(cls, app_model: App, user: Account | EndUser | None, message_id: str):
|
||||||
message = db.session.scalar(
|
message = db.session.scalar(
|
||||||
select(Message)
|
select(Message)
|
||||||
.where(
|
.where(
|
||||||
@@ -241,7 +240,7 @@ class MessageService:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_suggested_questions_after_answer(
|
def get_suggested_questions_after_answer(
|
||||||
cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str, invoke_from: InvokeFrom
|
cls, app_model: App, user: Account | EndUser | None, message_id: str, invoke_from: InvokeFrom
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
if not user:
|
if not user:
|
||||||
raise ValueError("user cannot be None")
|
raise ValueError("user cannot be None")
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import time
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import click
|
import click
|
||||||
@@ -14,7 +14,6 @@ import tqdm
|
|||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from core.agent.entities import AgentToolEntity
|
from core.agent.entities import AgentToolEntity
|
||||||
from core.helper import marketplace
|
from core.helper import marketplace
|
||||||
|
|||||||
@@ -13,13 +13,12 @@ from collections.abc import Callable
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, cast
|
from typing import Any, TypedDict, cast
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
from sqlalchemy.engine import CursorResult
|
from sqlalchemy.engine import CursorResult
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
|
|
||||||
class _TableInfo(TypedDict, total=False):
|
class _TableInfo(TypedDict, total=False):
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
from typing import Union
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@@ -14,7 +12,7 @@ from services.message_service import MessageService
|
|||||||
class SavedMessageService:
|
class SavedMessageService:
|
||||||
@classmethod
|
@classmethod
|
||||||
def pagination_by_last_id(
|
def pagination_by_last_id(
|
||||||
cls, app_model: App, user: Union[Account, EndUser] | None, last_id: str | None, limit: int
|
cls, app_model: App, user: Account | EndUser | None, last_id: str | None, limit: int
|
||||||
) -> InfiniteScrollPagination:
|
) -> InfiniteScrollPagination:
|
||||||
if not user:
|
if not user:
|
||||||
raise ValueError("User is required")
|
raise ValueError("User is required")
|
||||||
@@ -34,7 +32,7 @@ class SavedMessageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
|
def save(cls, app_model: App, user: Account | EndUser | None, message_id: str):
|
||||||
if not user:
|
if not user:
|
||||||
return
|
return
|
||||||
saved_message = db.session.scalar(
|
saved_message = db.session.scalar(
|
||||||
@@ -64,7 +62,7 @@ class SavedMessageService:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
|
def delete(cls, app_model: App, user: Account | EndUser | None, message_id: str):
|
||||||
if not user:
|
if not user:
|
||||||
return
|
return
|
||||||
saved_message = db.session.scalar(
|
saved_message = db.session.scalar(
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, cast
|
from typing import Any, TypedDict, cast
|
||||||
|
|
||||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from httpx import get
|
from httpx import get
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from core.entities.provider_entities import ProviderConfig
|
from core.entities.provider_entities import ProviderConfig
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Generic, TypeAlias, TypeVar, overload
|
from typing import Any, overload
|
||||||
|
|
||||||
from graphon.file import File
|
from graphon.file import File
|
||||||
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
|
from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable
|
||||||
@@ -43,12 +43,9 @@ class _PCKeys:
|
|||||||
CHILD_CONTENTS = "child_contents"
|
CHILD_CONTENTS = "child_contents"
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class _PartResult(Generic[_T]):
|
class _PartResult[T]:
|
||||||
value: _T
|
value: T
|
||||||
value_size: int
|
value_size: int
|
||||||
truncated: bool
|
truncated: bool
|
||||||
|
|
||||||
@@ -61,7 +58,7 @@ class UnknownTypeError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
JSONTypes: TypeAlias = int | float | str | list[object] | dict[str, object] | None | bool
|
type JSONTypes = int | float | str | list[object] | dict[str, object] | None | bool
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
from typing import Union
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -20,7 +18,7 @@ class WebConversationService:
|
|||||||
*,
|
*,
|
||||||
session: Session,
|
session: Session,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser] | None,
|
user: Account | EndUser | None,
|
||||||
last_id: str | None,
|
last_id: str | None,
|
||||||
limit: int,
|
limit: int,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
@@ -61,7 +59,7 @@ class WebConversationService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
def pin(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
|
||||||
if not user:
|
if not user:
|
||||||
return
|
return
|
||||||
pinned_conversation = db.session.scalar(
|
pinned_conversation = db.session.scalar(
|
||||||
@@ -93,7 +91,7 @@ class WebConversationService:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
def unpin(cls, app_model: App, conversation_id: str, user: Account | EndUser | None):
|
||||||
if not user:
|
if not user:
|
||||||
return
|
return
|
||||||
pinned_conversation = db.session.scalar(
|
pinned_conversation = db.session.scalar(
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from graphon.file import FileUploadConfig
|
from graphon.file import FileUploadConfig
|
||||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||||
@@ -7,7 +7,6 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
|
|||||||
from graphon.nodes import BuiltinNodeTypes
|
from graphon.nodes import BuiltinNodeTypes
|
||||||
from graphon.variables.input_entities import VariableEntity
|
from graphon.variables.input_entities import VariableEntity
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from core.app.app_config.entities import (
|
from core.app.app_config.entities import (
|
||||||
DatasetEntity,
|
DatasetEntity,
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from graphon.enums import WorkflowExecutionStatus
|
from graphon.enums import WorkflowExecutionStatus
|
||||||
from sqlalchemy import and_, func, or_, select
|
from sqlalchemy import and_, func, or_, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
|
from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
|
||||||
from models.enums import AppTriggerType, CreatorUserRole
|
from models.enums import AppTriggerType, CreatorUserRole
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Annotated, Any, TypeAlias, Union
|
from typing import Annotated, Any
|
||||||
|
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from flask import current_app, json
|
from flask import current_app, json
|
||||||
@@ -68,7 +68,7 @@ def _get_user_type_descriminator(value: Any):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
User: TypeAlias = Annotated[
|
type User = Annotated[
|
||||||
(Annotated[_Account, Tag(_UserType.ACCOUNT)] | Annotated[_EndUser, Tag(_UserType.END_USER)]),
|
(Annotated[_Account, Tag(_UserType.ACCOUNT)] | Annotated[_EndUser, Tag(_UserType.END_USER)]),
|
||||||
Discriminator(_get_user_type_descriminator),
|
Discriminator(_get_user_type_descriminator),
|
||||||
]
|
]
|
||||||
@@ -93,7 +93,7 @@ class AppExecutionParams(BaseModel):
|
|||||||
cls,
|
cls,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import os
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Protocol, TypeVar
|
from typing import Protocol
|
||||||
|
|
||||||
import psycopg2
|
import psycopg2
|
||||||
import pytest
|
import pytest
|
||||||
@@ -48,11 +48,8 @@ class _CloserProtocol(Protocol):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
_Closer = TypeVar("_Closer", bound=_CloserProtocol)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _auto_close(closer: _Closer) -> Generator[_Closer, None, None]:
|
def _auto_close[T: _CloserProtocol](closer: T) -> Generator[T, None, None]:
|
||||||
yield closer
|
yield closer
|
||||||
closer.close()
|
closer.close()
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Iterable
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any, NamedTuple, TypeVar
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
@@ -58,10 +58,7 @@ class _ColumnTest(_Base):
|
|||||||
long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False)
|
long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
def _first[T](it: Iterable[T]) -> T:
|
||||||
|
|
||||||
|
|
||||||
def _first(it: Iterable[_T]) -> _T:
|
|
||||||
ls = list(it)
|
ls = list(it)
|
||||||
if not ls:
|
if not ls:
|
||||||
raise ValueError("List is empty")
|
raise ValueError("List is empty")
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ and data_source_detail_dict for all data_source_type values, including "local_fi
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Generic, Literal, NotRequired, TypedDict, TypeVar, Union
|
from typing import Literal, NotRequired, TypedDict
|
||||||
|
|
||||||
from models.dataset import Document
|
from models.dataset import Document
|
||||||
|
|
||||||
@@ -31,12 +31,10 @@ class WebsiteCrawlInfo(TypedDict):
|
|||||||
job_id: str
|
job_id: str
|
||||||
|
|
||||||
|
|
||||||
RawInfo = Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo]
|
type RawInfo = LocalFileInfo | UploadFileInfo | NotionImportInfo | WebsiteCrawlInfo
|
||||||
T_type = TypeVar("T_type", bound=str)
|
|
||||||
T_info = TypeVar("T_info", bound=Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo])
|
|
||||||
|
|
||||||
|
|
||||||
class Case(TypedDict, Generic[T_type, T_info]):
|
class Case[T_type: str, T_info: RawInfo](TypedDict):
|
||||||
data_source_type: T_type
|
data_source_type: T_type
|
||||||
data_source_info: str
|
data_source_info: str
|
||||||
expected_raw: T_info
|
expected_raw: T_info
|
||||||
@@ -47,7 +45,7 @@ UploadFileCase = Case[Literal["upload_file"], UploadFileInfo]
|
|||||||
NotionImportCase = Case[Literal["notion_import"], NotionImportInfo]
|
NotionImportCase = Case[Literal["notion_import"], NotionImportInfo]
|
||||||
WebsiteCrawlCase = Case[Literal["website_crawl"], WebsiteCrawlInfo]
|
WebsiteCrawlCase = Case[Literal["website_crawl"], WebsiteCrawlInfo]
|
||||||
|
|
||||||
AnyCase = Union[LocalFileCase, UploadFileCase, NotionImportCase, WebsiteCrawlCase]
|
type AnyCase = LocalFileCase | UploadFileCase | NotionImportCase | WebsiteCrawlCase
|
||||||
|
|
||||||
|
|
||||||
case_1: LocalFileCase = {
|
case_1: LocalFileCase = {
|
||||||
|
|||||||
@@ -77,7 +77,6 @@ class TestAuthType:
|
|||||||
|
|
||||||
def test_auth_type_immutability(self):
|
def test_auth_type_immutability(self):
|
||||||
"""Test that enum values cannot be modified"""
|
"""Test that enum values cannot be modified"""
|
||||||
# In Python 3.11+, enum members are read-only
|
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises(AttributeError):
|
||||||
AuthType.FIRECRAWL = "modified"
|
AuthType.FIRECRAWL = "modified"
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
import operator
|
import operator
|
||||||
from typing import TypeVar
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -16,10 +15,8 @@ from core.tools.entities.tool_entities import (
|
|||||||
ToolInvokeMessage,
|
ToolInvokeMessage,
|
||||||
)
|
)
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
|
||||||
|
|
||||||
|
def _get_message_by_type[T](msgs: list[ToolInvokeMessage], msg_type: type[T]) -> ToolInvokeMessage | None:
|
||||||
def _get_message_by_type(msgs: list[ToolInvokeMessage], msg_type: type[_T]) -> ToolInvokeMessage | None:
|
|
||||||
return next((i for i in msgs if isinstance(i.message, msg_type)), None)
|
return next((i for i in msgs if isinstance(i.message, msg_type)), None)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
892
api/uv.lock
generated
892
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user