mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 10:12:43 +08:00
refactor(api): tighten phase 1 shared type contracts (#33453)
This commit is contained in:
@@ -78,7 +78,7 @@ class UserProfile(TypedDict):
|
|||||||
nickname: NotRequired[str]
|
nickname: NotRequired[str]
|
||||||
```
|
```
|
||||||
|
|
||||||
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
|
- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Literal, Protocol
|
from typing import Literal, Protocol, cast
|
||||||
from urllib.parse import quote_plus, urlunparse
|
from urllib.parse import quote_plus, urlunparse
|
||||||
|
|
||||||
from pydantic import AliasChoices, Field
|
from pydantic import AliasChoices, Field
|
||||||
@@ -12,16 +12,13 @@ class RedisConfigDefaults(Protocol):
|
|||||||
REDIS_PASSWORD: str | None
|
REDIS_PASSWORD: str | None
|
||||||
REDIS_DB: int
|
REDIS_DB: int
|
||||||
REDIS_USE_SSL: bool
|
REDIS_USE_SSL: bool
|
||||||
REDIS_USE_SENTINEL: bool | None
|
|
||||||
REDIS_USE_CLUSTERS: bool
|
|
||||||
|
|
||||||
|
|
||||||
class RedisConfigDefaultsMixin:
|
def _redis_defaults(config: object) -> RedisConfigDefaults:
|
||||||
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
|
return cast(RedisConfigDefaults, config)
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
class RedisPubSubConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
Configuration settings for event transport between API and workers.
|
Configuration settings for event transport between API and workers.
|
||||||
|
|
||||||
@@ -74,7 +71,7 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _build_default_pubsub_url(self) -> str:
|
def _build_default_pubsub_url(self) -> str:
|
||||||
defaults = self._redis_defaults()
|
defaults = _redis_defaults(self)
|
||||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||||
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
|
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
|
||||||
|
|
||||||
@@ -91,11 +88,9 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
|||||||
if userinfo:
|
if userinfo:
|
||||||
userinfo = f"{userinfo}@"
|
userinfo = f"{userinfo}@"
|
||||||
|
|
||||||
host = defaults.REDIS_HOST
|
|
||||||
port = defaults.REDIS_PORT
|
|
||||||
db = defaults.REDIS_DB
|
db = defaults.REDIS_DB
|
||||||
|
|
||||||
netloc = f"{userinfo}{host}:{port}"
|
netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}"
|
||||||
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
|
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
from dify_graph.file.models import File
|
from dify_graph.file.models import File
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
pass
|
from dify_graph.variables.segments import Segment
|
||||||
|
|
||||||
|
|
||||||
class ArrayValidation(StrEnum):
|
class ArrayValidation(StrEnum):
|
||||||
@@ -219,7 +219,7 @@ class SegmentType(StrEnum):
|
|||||||
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_zero_value(t: SegmentType):
|
def get_zero_value(t: SegmentType) -> Segment:
|
||||||
# Lazy import to avoid circular dependency
|
# Lazy import to avoid circular dependency
|
||||||
from factories import variable_factory
|
from factories import variable_factory
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Protocol, cast
|
||||||
|
|
||||||
from fastopenapi.routers import FlaskRouter
|
from fastopenapi.routers import FlaskRouter
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
|
|
||||||
@@ -9,6 +11,10 @@ from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
|
|||||||
DOCS_PREFIX = "/fastopenapi"
|
DOCS_PREFIX = "/fastopenapi"
|
||||||
|
|
||||||
|
|
||||||
|
class SupportsIncludeRouter(Protocol):
|
||||||
|
def include_router(self, router: object, *, prefix: str = "") -> None: ...
|
||||||
|
|
||||||
|
|
||||||
def init_app(app: DifyApp) -> None:
|
def init_app(app: DifyApp) -> None:
|
||||||
docs_enabled = dify_config.SWAGGER_UI_ENABLED
|
docs_enabled = dify_config.SWAGGER_UI_ENABLED
|
||||||
docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
|
docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
|
||||||
@@ -36,7 +42,7 @@ def init_app(app: DifyApp) -> None:
|
|||||||
_ = remote_files
|
_ = remote_files
|
||||||
_ = setup
|
_ = setup
|
||||||
|
|
||||||
router.include_router(console_router, prefix="/console/api")
|
cast(SupportsIncludeRouter, router).include_router(console_router, prefix="/console/api")
|
||||||
CORS(
|
CORS(
|
||||||
app,
|
app,
|
||||||
resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class TypeMismatchError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
# Define the constant
|
# Define the constant
|
||||||
SEGMENT_TO_VARIABLE_MAP = {
|
SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = {
|
||||||
ArrayAnySegment: ArrayAnyVariable,
|
ArrayAnySegment: ArrayAnyVariable,
|
||||||
ArrayBooleanSegment: ArrayBooleanVariable,
|
ArrayBooleanSegment: ArrayBooleanVariable,
|
||||||
ArrayFileSegment: ArrayFileVariable,
|
ArrayFileSegment: ArrayFileVariable,
|
||||||
@@ -296,13 +296,11 @@ def segment_to_variable(
|
|||||||
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
|
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
|
||||||
|
|
||||||
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
|
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
|
||||||
return cast(
|
return variable_class(
|
||||||
VariableBase,
|
|
||||||
variable_class(
|
|
||||||
id=id,
|
id=id,
|
||||||
name=name,
|
name=name,
|
||||||
description=description,
|
description=description,
|
||||||
|
value_type=segment.value_type,
|
||||||
value=segment.value,
|
value=segment.value,
|
||||||
selector=list(selector),
|
selector=list(selector),
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -32,6 +32,11 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_with_request_context(response: object) -> Any:
|
||||||
|
"""Bridge Flask's loosely-typed streaming helper without leaking casts into callers."""
|
||||||
|
return cast(Any, stream_with_context)(response)
|
||||||
|
|
||||||
|
|
||||||
def escape_like_pattern(pattern: str) -> str:
|
def escape_like_pattern(pattern: str) -> str:
|
||||||
"""
|
"""
|
||||||
Escape special characters in a string for safe use in SQL LIKE patterns.
|
Escape special characters in a string for safe use in SQL LIKE patterns.
|
||||||
@@ -286,22 +291,32 @@ def generate_text_hash(text: str) -> str:
|
|||||||
return sha256(hash_text.encode()).hexdigest()
|
return sha256(hash_text.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
def compact_generate_response(
|
||||||
if isinstance(response, dict):
|
response: Mapping[str, Any] | Generator[str, None, None] | RateLimitGenerator,
|
||||||
|
) -> Response:
|
||||||
|
if isinstance(response, Mapping):
|
||||||
return Response(
|
return Response(
|
||||||
response=json.dumps(jsonable_encoder(response)),
|
response=json.dumps(jsonable_encoder(response)),
|
||||||
status=200,
|
status=200,
|
||||||
content_type="application/json; charset=utf-8",
|
content_type="application/json; charset=utf-8",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
stream_response = response
|
||||||
|
|
||||||
def generate() -> Generator:
|
def generate() -> Generator[str, None, None]:
|
||||||
yield from response
|
yield from stream_response
|
||||||
|
|
||||||
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
|
return Response(
|
||||||
|
_stream_with_request_context(generate()),
|
||||||
|
status=200,
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
def length_prefixed_response(
|
||||||
|
magic_number: int,
|
||||||
|
response: Mapping[str, Any] | BaseModel | Generator[str | bytes, None, None] | RateLimitGenerator,
|
||||||
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
This function is used to return a response with a length prefix.
|
This function is used to return a response with a length prefix.
|
||||||
Magic number is a one byte number that indicates the type of the response.
|
Magic number is a one byte number that indicates the type of the response.
|
||||||
@@ -332,7 +347,7 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
|
|||||||
# | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
|
# | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
|
||||||
return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response
|
return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response
|
||||||
|
|
||||||
if isinstance(response, dict):
|
if isinstance(response, Mapping):
|
||||||
return Response(
|
return Response(
|
||||||
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
|
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
|
||||||
status=200,
|
status=200,
|
||||||
@@ -345,14 +360,20 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
|
|||||||
mimetype="application/json",
|
mimetype="application/json",
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate() -> Generator:
|
stream_response = response
|
||||||
for chunk in response:
|
|
||||||
|
def generate() -> Generator[bytes, None, None]:
|
||||||
|
for chunk in stream_response:
|
||||||
if isinstance(chunk, str):
|
if isinstance(chunk, str):
|
||||||
yield pack_response_with_length_prefix(chunk.encode("utf-8"))
|
yield pack_response_with_length_prefix(chunk.encode("utf-8"))
|
||||||
else:
|
else:
|
||||||
yield pack_response_with_length_prefix(chunk)
|
yield pack_response_with_length_prefix(chunk)
|
||||||
|
|
||||||
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
|
return Response(
|
||||||
|
_stream_with_request_context(generate()),
|
||||||
|
status=200,
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TokenManager:
|
class TokenManager:
|
||||||
|
|||||||
@@ -77,12 +77,14 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
|
|||||||
@wraps(func)
|
@wraps(func)
|
||||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
|
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
|
||||||
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
||||||
pass
|
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||||
elif current_user is not None and not current_user.is_authenticated:
|
|
||||||
|
user = _get_user()
|
||||||
|
if user is None or not user.is_authenticated:
|
||||||
return current_app.login_manager.unauthorized() # type: ignore
|
return current_app.login_manager.unauthorized() # type: ignore
|
||||||
# we put csrf validation here for less conflicts
|
# we put csrf validation here for less conflicts
|
||||||
# TODO: maybe find a better place for it.
|
# TODO: maybe find a better place for it.
|
||||||
check_csrf_token(request, current_user.id)
|
check_csrf_token(request, user.id)
|
||||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||||
|
|
||||||
return decorated_view
|
return decorated_view
|
||||||
|
|||||||
@@ -7,9 +7,10 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def cached_import(module_path: str, class_name: str):
|
def cached_import(module_path: str, class_name: str) -> Any:
|
||||||
"""
|
"""
|
||||||
Import a module and return the named attribute/class from it, with caching.
|
Import a module and return the named attribute/class from it, with caching.
|
||||||
|
|
||||||
@@ -20,16 +21,14 @@ def cached_import(module_path: str, class_name: str):
|
|||||||
Returns:
|
Returns:
|
||||||
The imported attribute/class
|
The imported attribute/class
|
||||||
"""
|
"""
|
||||||
if not (
|
module = sys.modules.get(module_path)
|
||||||
(module := sys.modules.get(module_path))
|
spec = getattr(module, "__spec__", None) if module is not None else None
|
||||||
and (spec := getattr(module, "__spec__", None))
|
if module is None or getattr(spec, "_initializing", False):
|
||||||
and getattr(spec, "_initializing", False) is False
|
|
||||||
):
|
|
||||||
module = import_module(module_path)
|
module = import_module(module_path)
|
||||||
return getattr(module, class_name)
|
return getattr(module, class_name)
|
||||||
|
|
||||||
|
|
||||||
def import_string(dotted_path: str):
|
def import_string(dotted_path: str) -> Any:
|
||||||
"""
|
"""
|
||||||
Import a dotted module path and return the attribute/class designated by
|
Import a dotted module path and return the attribute/class designated by
|
||||||
the last name in the path. Raise ImportError if the import failed.
|
the last name in the path. Raise ImportError if the import failed.
|
||||||
|
|||||||
@@ -1,7 +1,48 @@
|
|||||||
|
import sys
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import NotRequired
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 12):
|
||||||
|
from typing import TypedDict
|
||||||
|
else:
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
JsonObject = dict[str, object]
|
||||||
|
JsonObjectList = list[JsonObject]
|
||||||
|
|
||||||
|
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
|
||||||
|
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
|
||||||
|
|
||||||
|
|
||||||
|
class AccessTokenResponse(TypedDict, total=False):
|
||||||
|
access_token: str
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubEmailRecord(TypedDict, total=False):
|
||||||
|
email: str
|
||||||
|
primary: bool
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubRawUserInfo(TypedDict):
|
||||||
|
id: int | str
|
||||||
|
login: str
|
||||||
|
name: NotRequired[str]
|
||||||
|
email: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleRawUserInfo(TypedDict):
|
||||||
|
sub: str
|
||||||
|
email: str
|
||||||
|
|
||||||
|
|
||||||
|
ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse)
|
||||||
|
GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)
|
||||||
|
GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord])
|
||||||
|
GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -11,26 +52,38 @@ class OAuthUserInfo:
|
|||||||
email: str
|
email: str
|
||||||
|
|
||||||
|
|
||||||
|
def _json_object(response: httpx.Response) -> JsonObject:
|
||||||
|
return JSON_OBJECT_ADAPTER.validate_python(response.json())
|
||||||
|
|
||||||
|
|
||||||
|
def _json_list(response: httpx.Response) -> JsonObjectList:
|
||||||
|
return JSON_OBJECT_LIST_ADAPTER.validate_python(response.json())
|
||||||
|
|
||||||
|
|
||||||
class OAuth:
|
class OAuth:
|
||||||
|
client_id: str
|
||||||
|
client_secret: str
|
||||||
|
redirect_uri: str
|
||||||
|
|
||||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||||
self.client_id = client_id
|
self.client_id = client_id
|
||||||
self.client_secret = client_secret
|
self.client_secret = client_secret
|
||||||
self.redirect_uri = redirect_uri
|
self.redirect_uri = redirect_uri
|
||||||
|
|
||||||
def get_authorization_url(self):
|
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_access_token(self, code: str):
|
def get_access_token(self, code: str) -> str:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_raw_user_info(self, token: str):
|
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_user_info(self, token: str) -> OAuthUserInfo:
|
def get_user_info(self, token: str) -> OAuthUserInfo:
|
||||||
raw_info = self.get_raw_user_info(token)
|
raw_info = self.get_raw_user_info(token)
|
||||||
return self._transform_user_info(raw_info)
|
return self._transform_user_info(raw_info)
|
||||||
|
|
||||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
@@ -40,7 +93,7 @@ class GitHubOAuth(OAuth):
|
|||||||
_USER_INFO_URL = "https://api.github.com/user"
|
_USER_INFO_URL = "https://api.github.com/user"
|
||||||
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
|
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
|
||||||
|
|
||||||
def get_authorization_url(self, invite_token: str | None = None):
|
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||||
params = {
|
params = {
|
||||||
"client_id": self.client_id,
|
"client_id": self.client_id,
|
||||||
"redirect_uri": self.redirect_uri,
|
"redirect_uri": self.redirect_uri,
|
||||||
@@ -50,7 +103,7 @@ class GitHubOAuth(OAuth):
|
|||||||
params["state"] = invite_token
|
params["state"] = invite_token
|
||||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||||
|
|
||||||
def get_access_token(self, code: str):
|
def get_access_token(self, code: str) -> str:
|
||||||
data = {
|
data = {
|
||||||
"client_id": self.client_id,
|
"client_id": self.client_id,
|
||||||
"client_secret": self.client_secret,
|
"client_secret": self.client_secret,
|
||||||
@@ -60,7 +113,7 @@ class GitHubOAuth(OAuth):
|
|||||||
headers = {"Accept": "application/json"}
|
headers = {"Accept": "application/json"}
|
||||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||||
|
|
||||||
response_json = response.json()
|
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
|
||||||
access_token = response_json.get("access_token")
|
access_token = response_json.get("access_token")
|
||||||
|
|
||||||
if not access_token:
|
if not access_token:
|
||||||
@@ -68,23 +121,24 @@ class GitHubOAuth(OAuth):
|
|||||||
|
|
||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
def get_raw_user_info(self, token: str):
|
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||||
headers = {"Authorization": f"token {token}"}
|
headers = {"Authorization": f"token {token}"}
|
||||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
user_info = response.json()
|
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
|
||||||
|
|
||||||
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
|
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
|
||||||
email_info = email_response.json()
|
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
|
||||||
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
|
primary_email = next((email for email in email_info if email.get("primary") is True), None)
|
||||||
|
|
||||||
return {**user_info, "email": primary_email.get("email", "")}
|
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
|
||||||
|
|
||||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||||
email = raw_info.get("email")
|
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||||
|
email = payload.get("email")
|
||||||
if not email:
|
if not email:
|
||||||
email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
|
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
|
||||||
return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)
|
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
|
||||||
|
|
||||||
|
|
||||||
class GoogleOAuth(OAuth):
|
class GoogleOAuth(OAuth):
|
||||||
@@ -92,7 +146,7 @@ class GoogleOAuth(OAuth):
|
|||||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||||
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||||
|
|
||||||
def get_authorization_url(self, invite_token: str | None = None):
|
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||||
params = {
|
params = {
|
||||||
"client_id": self.client_id,
|
"client_id": self.client_id,
|
||||||
"response_type": "code",
|
"response_type": "code",
|
||||||
@@ -103,7 +157,7 @@ class GoogleOAuth(OAuth):
|
|||||||
params["state"] = invite_token
|
params["state"] = invite_token
|
||||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||||
|
|
||||||
def get_access_token(self, code: str):
|
def get_access_token(self, code: str) -> str:
|
||||||
data = {
|
data = {
|
||||||
"client_id": self.client_id,
|
"client_id": self.client_id,
|
||||||
"client_secret": self.client_secret,
|
"client_secret": self.client_secret,
|
||||||
@@ -114,7 +168,7 @@ class GoogleOAuth(OAuth):
|
|||||||
headers = {"Accept": "application/json"}
|
headers = {"Accept": "application/json"}
|
||||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||||
|
|
||||||
response_json = response.json()
|
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
|
||||||
access_token = response_json.get("access_token")
|
access_token = response_json.get("access_token")
|
||||||
|
|
||||||
if not access_token:
|
if not access_token:
|
||||||
@@ -122,11 +176,12 @@ class GoogleOAuth(OAuth):
|
|||||||
|
|
||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
def get_raw_user_info(self, token: str):
|
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return _json_object(response)
|
||||||
|
|
||||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||||
return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])
|
payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||||
|
return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"])
|
||||||
|
|||||||
@@ -1,25 +1,57 @@
|
|||||||
|
import sys
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from models.source import DataSourceOauthBinding
|
from models.source import DataSourceOauthBinding
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 12):
|
||||||
|
from typing import TypedDict
|
||||||
|
else:
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class NotionPageSummary(TypedDict):
|
||||||
|
page_id: str
|
||||||
|
page_name: str
|
||||||
|
page_icon: dict[str, str] | None
|
||||||
|
parent_id: str
|
||||||
|
type: Literal["page", "database"]
|
||||||
|
|
||||||
|
|
||||||
|
class NotionSourceInfo(TypedDict):
|
||||||
|
workspace_name: str | None
|
||||||
|
workspace_icon: str | None
|
||||||
|
workspace_id: str | None
|
||||||
|
pages: list[NotionPageSummary]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
SOURCE_INFO_STORAGE_ADAPTER = TypeAdapter(dict[str, object])
|
||||||
|
NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo)
|
||||||
|
NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary)
|
||||||
|
|
||||||
|
|
||||||
class OAuthDataSource:
|
class OAuthDataSource:
|
||||||
|
client_id: str
|
||||||
|
client_secret: str
|
||||||
|
redirect_uri: str
|
||||||
|
|
||||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||||
self.client_id = client_id
|
self.client_id = client_id
|
||||||
self.client_secret = client_secret
|
self.client_secret = client_secret
|
||||||
self.redirect_uri = redirect_uri
|
self.redirect_uri = redirect_uri
|
||||||
|
|
||||||
def get_authorization_url(self):
|
def get_authorization_url(self) -> str:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_access_token(self, code: str):
|
def get_access_token(self, code: str) -> None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
@@ -30,7 +62,7 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
|
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
|
||||||
_NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
|
_NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
|
||||||
|
|
||||||
def get_authorization_url(self):
|
def get_authorization_url(self) -> str:
|
||||||
params = {
|
params = {
|
||||||
"client_id": self.client_id,
|
"client_id": self.client_id,
|
||||||
"response_type": "code",
|
"response_type": "code",
|
||||||
@@ -39,7 +71,7 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
}
|
}
|
||||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||||
|
|
||||||
def get_access_token(self, code: str):
|
def get_access_token(self, code: str) -> None:
|
||||||
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
|
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
|
||||||
headers = {"Accept": "application/json"}
|
headers = {"Accept": "application/json"}
|
||||||
auth = (self.client_id, self.client_secret)
|
auth = (self.client_id, self.client_secret)
|
||||||
@@ -54,13 +86,12 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
workspace_id = response_json.get("workspace_id")
|
workspace_id = response_json.get("workspace_id")
|
||||||
# get all authorized pages
|
# get all authorized pages
|
||||||
pages = self.get_authorized_pages(access_token)
|
pages = self.get_authorized_pages(access_token)
|
||||||
source_info = {
|
source_info = self._build_source_info(
|
||||||
"workspace_name": workspace_name,
|
workspace_name=workspace_name,
|
||||||
"workspace_icon": workspace_icon,
|
workspace_icon=workspace_icon,
|
||||||
"workspace_id": workspace_id,
|
workspace_id=workspace_id,
|
||||||
"pages": pages,
|
pages=pages,
|
||||||
"total": len(pages),
|
)
|
||||||
}
|
|
||||||
# save data source binding
|
# save data source binding
|
||||||
data_source_binding = db.session.scalar(
|
data_source_binding = db.session.scalar(
|
||||||
select(DataSourceOauthBinding).where(
|
select(DataSourceOauthBinding).where(
|
||||||
@@ -70,7 +101,7 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
if data_source_binding:
|
if data_source_binding:
|
||||||
data_source_binding.source_info = source_info
|
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||||
data_source_binding.disabled = False
|
data_source_binding.disabled = False
|
||||||
data_source_binding.updated_at = naive_utc_now()
|
data_source_binding.updated_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@@ -78,25 +109,24 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
new_data_source_binding = DataSourceOauthBinding(
|
new_data_source_binding = DataSourceOauthBinding(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
source_info=source_info,
|
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||||
provider="notion",
|
provider="notion",
|
||||||
)
|
)
|
||||||
db.session.add(new_data_source_binding)
|
db.session.add(new_data_source_binding)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def save_internal_access_token(self, access_token: str):
|
def save_internal_access_token(self, access_token: str) -> None:
|
||||||
workspace_name = self.notion_workspace_name(access_token)
|
workspace_name = self.notion_workspace_name(access_token)
|
||||||
workspace_icon = None
|
workspace_icon = None
|
||||||
workspace_id = current_user.current_tenant_id
|
workspace_id = current_user.current_tenant_id
|
||||||
# get all authorized pages
|
# get all authorized pages
|
||||||
pages = self.get_authorized_pages(access_token)
|
pages = self.get_authorized_pages(access_token)
|
||||||
source_info = {
|
source_info = self._build_source_info(
|
||||||
"workspace_name": workspace_name,
|
workspace_name=workspace_name,
|
||||||
"workspace_icon": workspace_icon,
|
workspace_icon=workspace_icon,
|
||||||
"workspace_id": workspace_id,
|
workspace_id=workspace_id,
|
||||||
"pages": pages,
|
pages=pages,
|
||||||
"total": len(pages),
|
)
|
||||||
}
|
|
||||||
# save data source binding
|
# save data source binding
|
||||||
data_source_binding = db.session.scalar(
|
data_source_binding = db.session.scalar(
|
||||||
select(DataSourceOauthBinding).where(
|
select(DataSourceOauthBinding).where(
|
||||||
@@ -106,7 +136,7 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
if data_source_binding:
|
if data_source_binding:
|
||||||
data_source_binding.source_info = source_info
|
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||||
data_source_binding.disabled = False
|
data_source_binding.disabled = False
|
||||||
data_source_binding.updated_at = naive_utc_now()
|
data_source_binding.updated_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@@ -114,13 +144,13 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
new_data_source_binding = DataSourceOauthBinding(
|
new_data_source_binding = DataSourceOauthBinding(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
source_info=source_info,
|
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||||
provider="notion",
|
provider="notion",
|
||||||
)
|
)
|
||||||
db.session.add(new_data_source_binding)
|
db.session.add(new_data_source_binding)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def sync_data_source(self, binding_id: str):
|
def sync_data_source(self, binding_id: str) -> None:
|
||||||
# save data source binding
|
# save data source binding
|
||||||
data_source_binding = db.session.scalar(
|
data_source_binding = db.session.scalar(
|
||||||
select(DataSourceOauthBinding).where(
|
select(DataSourceOauthBinding).where(
|
||||||
@@ -134,23 +164,22 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
if data_source_binding:
|
if data_source_binding:
|
||||||
# get all authorized pages
|
# get all authorized pages
|
||||||
pages = self.get_authorized_pages(data_source_binding.access_token)
|
pages = self.get_authorized_pages(data_source_binding.access_token)
|
||||||
source_info = data_source_binding.source_info
|
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
|
||||||
new_source_info = {
|
new_source_info = self._build_source_info(
|
||||||
"workspace_name": source_info["workspace_name"],
|
workspace_name=source_info["workspace_name"],
|
||||||
"workspace_icon": source_info["workspace_icon"],
|
workspace_icon=source_info["workspace_icon"],
|
||||||
"workspace_id": source_info["workspace_id"],
|
workspace_id=source_info["workspace_id"],
|
||||||
"pages": pages,
|
pages=pages,
|
||||||
"total": len(pages),
|
)
|
||||||
}
|
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
|
||||||
data_source_binding.source_info = new_source_info
|
|
||||||
data_source_binding.disabled = False
|
data_source_binding.disabled = False
|
||||||
data_source_binding.updated_at = naive_utc_now()
|
data_source_binding.updated_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
else:
|
else:
|
||||||
raise ValueError("Data source binding not found")
|
raise ValueError("Data source binding not found")
|
||||||
|
|
||||||
def get_authorized_pages(self, access_token: str):
|
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
|
||||||
pages = []
|
pages: list[NotionPageSummary] = []
|
||||||
page_results = self.notion_page_search(access_token)
|
page_results = self.notion_page_search(access_token)
|
||||||
database_results = self.notion_database_search(access_token)
|
database_results = self.notion_database_search(access_token)
|
||||||
# get page detail
|
# get page detail
|
||||||
@@ -187,7 +216,7 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
"parent_id": parent_id,
|
"parent_id": parent_id,
|
||||||
"type": "page",
|
"type": "page",
|
||||||
}
|
}
|
||||||
pages.append(page)
|
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
|
||||||
# get database detail
|
# get database detail
|
||||||
for database_result in database_results:
|
for database_result in database_results:
|
||||||
page_id = database_result["id"]
|
page_id = database_result["id"]
|
||||||
@@ -220,11 +249,11 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
"parent_id": parent_id,
|
"parent_id": parent_id,
|
||||||
"type": "database",
|
"type": "database",
|
||||||
}
|
}
|
||||||
pages.append(page)
|
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
|
||||||
return pages
|
return pages
|
||||||
|
|
||||||
def notion_page_search(self, access_token: str):
|
def notion_page_search(self, access_token: str) -> list[dict[str, Any]]:
|
||||||
results = []
|
results: list[dict[str, Any]] = []
|
||||||
next_cursor = None
|
next_cursor = None
|
||||||
has_more = True
|
has_more = True
|
||||||
|
|
||||||
@@ -249,7 +278,7 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def notion_block_parent_page_id(self, access_token: str, block_id: str):
|
def notion_block_parent_page_id(self, access_token: str, block_id: str) -> str:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {access_token}",
|
"Authorization": f"Bearer {access_token}",
|
||||||
"Notion-Version": "2022-06-28",
|
"Notion-Version": "2022-06-28",
|
||||||
@@ -265,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
return self.notion_block_parent_page_id(access_token, parent[parent_type])
|
return self.notion_block_parent_page_id(access_token, parent[parent_type])
|
||||||
return parent[parent_type]
|
return parent[parent_type]
|
||||||
|
|
||||||
def notion_workspace_name(self, access_token: str):
|
def notion_workspace_name(self, access_token: str) -> str:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {access_token}",
|
"Authorization": f"Bearer {access_token}",
|
||||||
"Notion-Version": "2022-06-28",
|
"Notion-Version": "2022-06-28",
|
||||||
@@ -279,8 +308,8 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
return user_info["workspace_name"]
|
return user_info["workspace_name"]
|
||||||
return "workspace"
|
return "workspace"
|
||||||
|
|
||||||
def notion_database_search(self, access_token: str):
|
def notion_database_search(self, access_token: str) -> list[dict[str, Any]]:
|
||||||
results = []
|
results: list[dict[str, Any]] = []
|
||||||
next_cursor = None
|
next_cursor = None
|
||||||
has_more = True
|
has_more = True
|
||||||
|
|
||||||
@@ -303,3 +332,19 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
next_cursor = response_json.get("next_cursor", None)
|
next_cursor = response_json.get("next_cursor", None)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_source_info(
|
||||||
|
*,
|
||||||
|
workspace_name: str | None,
|
||||||
|
workspace_icon: str | None,
|
||||||
|
workspace_id: str | None,
|
||||||
|
pages: list[NotionPageSummary],
|
||||||
|
) -> NotionSourceInfo:
|
||||||
|
return {
|
||||||
|
"workspace_name": workspace_name,
|
||||||
|
"workspace_icon": workspace_icon,
|
||||||
|
"workspace_id": workspace_id,
|
||||||
|
"pages": pages,
|
||||||
|
"total": len(pages),
|
||||||
|
}
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr
|
|||||||
from .model import Account
|
from .model import Account
|
||||||
from .types import EnumText, LongText, StringUUID
|
from .types import EnumText, LongText, StringUUID
|
||||||
|
|
||||||
|
TriggerJsonObject = dict[str, object]
|
||||||
|
TriggerCredentials = dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
class WorkflowTriggerLogDict(TypedDict):
|
class WorkflowTriggerLogDict(TypedDict):
|
||||||
id: str
|
id: str
|
||||||
@@ -89,10 +92,14 @@ class TriggerSubscription(TypeBase):
|
|||||||
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
|
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
|
||||||
)
|
)
|
||||||
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
|
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
|
||||||
parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
|
parameters: Mapped[TriggerJsonObject] = mapped_column(
|
||||||
properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
|
sa.JSON, nullable=False, comment="Subscription parameters JSON"
|
||||||
|
)
|
||||||
|
properties: Mapped[TriggerJsonObject] = mapped_column(
|
||||||
|
sa.JSON, nullable=False, comment="Subscription properties JSON"
|
||||||
|
)
|
||||||
|
|
||||||
credentials: Mapped[dict[str, Any]] = mapped_column(
|
credentials: Mapped[TriggerCredentials] = mapped_column(
|
||||||
sa.JSON, nullable=False, comment="Subscription credentials JSON"
|
sa.JSON, nullable=False, comment="Subscription credentials JSON"
|
||||||
)
|
)
|
||||||
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
|
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
|
||||||
@@ -200,8 +207,8 @@ class TriggerOAuthTenantClient(TypeBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def oauth_params(self) -> Mapping[str, Any]:
|
def oauth_params(self) -> Mapping[str, object]:
|
||||||
return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
|
return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}"))
|
||||||
|
|
||||||
|
|
||||||
class WorkflowTriggerLog(TypeBase):
|
class WorkflowTriggerLog(TypeBase):
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from sqlalchemy import (
|
|||||||
orm,
|
orm,
|
||||||
select,
|
select,
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||||
@@ -33,7 +33,7 @@ from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
|
|||||||
from dify_graph.file.constants import maybe_file_object
|
from dify_graph.file.constants import maybe_file_object
|
||||||
from dify_graph.file.models import File
|
from dify_graph.file.models import File
|
||||||
from dify_graph.variables import utils as variable_utils
|
from dify_graph.variables import utils as variable_utils
|
||||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable
|
from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
|
||||||
from extensions.ext_storage import Storage
|
from extensions.ext_storage import Storage
|
||||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
@@ -59,6 +59,9 @@ from .types import EnumText, LongText, StringUUID
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SerializedWorkflowValue = dict[str, Any]
|
||||||
|
SerializedWorkflowVariables = dict[str, SerializedWorkflowValue]
|
||||||
|
|
||||||
|
|
||||||
class WorkflowContentDict(TypedDict):
|
class WorkflowContentDict(TypedDict):
|
||||||
graph: Mapping[str, Any]
|
graph: Mapping[str, Any]
|
||||||
@@ -405,7 +408,7 @@ class Workflow(Base): # bug
|
|||||||
|
|
||||||
def rag_pipeline_user_input_form(self) -> list:
|
def rag_pipeline_user_input_form(self) -> list:
|
||||||
# get user_input_form from start node
|
# get user_input_form from start node
|
||||||
variables: list[Any] = self.rag_pipeline_variables
|
variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables
|
||||||
|
|
||||||
return variables
|
return variables
|
||||||
|
|
||||||
@@ -448,17 +451,13 @@ class Workflow(Base): # bug
|
|||||||
def environment_variables(
|
def environment_variables(
|
||||||
self,
|
self,
|
||||||
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||||
# TODO: find some way to init `self._environment_variables` when instance created.
|
|
||||||
if self._environment_variables is None:
|
|
||||||
self._environment_variables = "{}"
|
|
||||||
|
|
||||||
# Use workflow.tenant_id to avoid relying on request user in background threads
|
# Use workflow.tenant_id to avoid relying on request user in background threads
|
||||||
tenant_id = self.tenant_id
|
tenant_id = self.tenant_id
|
||||||
|
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
|
environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}"))
|
||||||
results = [
|
results = [
|
||||||
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
|
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
|
||||||
]
|
]
|
||||||
@@ -536,11 +535,7 @@ class Workflow(Base): # bug
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def conversation_variables(self) -> Sequence[VariableBase]:
|
def conversation_variables(self) -> Sequence[VariableBase]:
|
||||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}"))
|
||||||
if self._conversation_variables is None:
|
|
||||||
self._conversation_variables = "{}"
|
|
||||||
|
|
||||||
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
|
|
||||||
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
|
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -552,19 +547,20 @@ class Workflow(Base): # bug
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rag_pipeline_variables(self) -> list[dict]:
|
def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]:
|
||||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}"))
|
||||||
if self._rag_pipeline_variables is None:
|
return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()]
|
||||||
self._rag_pipeline_variables = "{}"
|
|
||||||
|
|
||||||
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
|
|
||||||
results = list(variables_dict.values())
|
|
||||||
return results
|
|
||||||
|
|
||||||
@rag_pipeline_variables.setter
|
@rag_pipeline_variables.setter
|
||||||
def rag_pipeline_variables(self, values: list[dict]) -> None:
|
def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None:
|
||||||
self._rag_pipeline_variables = json.dumps(
|
self._rag_pipeline_variables = json.dumps(
|
||||||
{item["variable"]: item for item in values},
|
{
|
||||||
|
rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json")
|
||||||
|
for rag_pipeline_variable in (
|
||||||
|
item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item)
|
||||||
|
for item in values
|
||||||
|
)
|
||||||
|
},
|
||||||
ensure_ascii=False,
|
ensure_ascii=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -802,10 +798,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
|||||||
|
|
||||||
__tablename__ = "workflow_node_executions"
|
__tablename__ = "workflow_node_executions"
|
||||||
|
|
||||||
@declared_attr.directive
|
__table_args__ = (
|
||||||
@classmethod
|
|
||||||
def __table_args__(cls) -> Any:
|
|
||||||
return (
|
|
||||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||||
Index(
|
Index(
|
||||||
"workflow_node_execution_workflow_run_id_idx",
|
"workflow_node_execution_workflow_run_id_idx",
|
||||||
@@ -828,16 +821,11 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
|||||||
"node_execution_id",
|
"node_execution_id",
|
||||||
),
|
),
|
||||||
Index(
|
Index(
|
||||||
# The first argument is the index name,
|
|
||||||
# which we leave as `None`` to allow auto-generation by the ORM.
|
|
||||||
None,
|
None,
|
||||||
cls.tenant_id,
|
"tenant_id",
|
||||||
cls.workflow_id,
|
"workflow_id",
|
||||||
cls.node_id,
|
"node_id",
|
||||||
# MyPy may flag the following line because it doesn't recognize that
|
sa.desc("created_at"),
|
||||||
# the `declared_attr` decorator passes the receiving class as the first
|
|
||||||
# argument to this method, allowing us to reference class attributes.
|
|
||||||
cls.created_at.desc(),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
configs/middleware/cache/redis_pubsub_config.py
|
|
||||||
controllers/console/app/annotation.py
|
controllers/console/app/annotation.py
|
||||||
controllers/console/app/app.py
|
controllers/console/app/app.py
|
||||||
controllers/console/app/app_import.py
|
controllers/console/app/app_import.py
|
||||||
@@ -138,8 +137,6 @@ dify_graph/nodes/trigger_webhook/node.py
|
|||||||
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
|
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
|
||||||
dify_graph/nodes/variable_assigner/v1/node.py
|
dify_graph/nodes/variable_assigner/v1/node.py
|
||||||
dify_graph/nodes/variable_assigner/v2/node.py
|
dify_graph/nodes/variable_assigner/v2/node.py
|
||||||
dify_graph/variables/types.py
|
|
||||||
extensions/ext_fastopenapi.py
|
|
||||||
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
||||||
extensions/otel/instrumentation.py
|
extensions/otel/instrumentation.py
|
||||||
extensions/otel/runtime.py
|
extensions/otel/runtime.py
|
||||||
@@ -156,19 +153,7 @@ extensions/storage/oracle_oci_storage.py
|
|||||||
extensions/storage/supabase_storage.py
|
extensions/storage/supabase_storage.py
|
||||||
extensions/storage/tencent_cos_storage.py
|
extensions/storage/tencent_cos_storage.py
|
||||||
extensions/storage/volcengine_tos_storage.py
|
extensions/storage/volcengine_tos_storage.py
|
||||||
factories/variable_factory.py
|
|
||||||
libs/external_api.py
|
|
||||||
libs/gmpy2_pkcs10aep_cipher.py
|
libs/gmpy2_pkcs10aep_cipher.py
|
||||||
libs/helper.py
|
|
||||||
libs/login.py
|
|
||||||
libs/module_loading.py
|
|
||||||
libs/oauth.py
|
|
||||||
libs/oauth_data_source.py
|
|
||||||
models/trigger.py
|
|
||||||
models/workflow.py
|
|
||||||
repositories/sqlalchemy_api_workflow_node_execution_repository.py
|
|
||||||
repositories/sqlalchemy_api_workflow_run_repository.py
|
|
||||||
repositories/sqlalchemy_execution_extra_content_repository.py
|
|
||||||
schedule/queue_monitor_task.py
|
schedule/queue_monitor_task.py
|
||||||
services/account_service.py
|
services/account_service.py
|
||||||
services/audio_service.py
|
services/audio_service.py
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
|||||||
import json
|
import json
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import cast
|
from typing import Protocol, cast
|
||||||
|
|
||||||
from sqlalchemy import asc, delete, desc, func, select
|
from sqlalchemy import asc, delete, desc, func, select
|
||||||
from sqlalchemy.engine import CursorResult
|
from sqlalchemy.engine import CursorResult
|
||||||
@@ -22,6 +22,20 @@ from repositories.api_workflow_node_execution_repository import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _WorkflowNodeExecutionSnapshotRow(Protocol):
|
||||||
|
id: str
|
||||||
|
node_execution_id: str | None
|
||||||
|
node_id: str
|
||||||
|
node_type: str
|
||||||
|
title: str
|
||||||
|
index: int
|
||||||
|
status: WorkflowNodeExecutionStatus
|
||||||
|
elapsed_time: float | None
|
||||||
|
created_at: datetime
|
||||||
|
finished_at: datetime | None
|
||||||
|
execution_metadata: str | None
|
||||||
|
|
||||||
|
|
||||||
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
|
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
|
||||||
"""
|
"""
|
||||||
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
|
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
|
||||||
@@ -40,6 +54,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
|||||||
- Thread-safe database operations using session-per-request pattern
|
- Thread-safe database operations using session-per-request pattern
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_session_maker: sessionmaker[Session]
|
||||||
|
|
||||||
def __init__(self, session_maker: sessionmaker[Session]):
|
def __init__(self, session_maker: sessionmaker[Session]):
|
||||||
"""
|
"""
|
||||||
Initialize the repository with a sessionmaker.
|
Initialize the repository with a sessionmaker.
|
||||||
@@ -156,12 +172,12 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
|||||||
)
|
)
|
||||||
|
|
||||||
with self._session_maker() as session:
|
with self._session_maker() as session:
|
||||||
rows = session.execute(stmt).all()
|
rows = cast(Sequence[_WorkflowNodeExecutionSnapshotRow], session.execute(stmt).all())
|
||||||
|
|
||||||
return [self._row_to_snapshot(row) for row in rows]
|
return [self._row_to_snapshot(row) for row in rows]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
|
def _row_to_snapshot(row: _WorkflowNodeExecutionSnapshotRow) -> WorkflowNodeExecutionSnapshot:
|
||||||
metadata: dict[str, object] = {}
|
metadata: dict[str, object] = {}
|
||||||
execution_metadata = getattr(row, "execution_metadata", None)
|
execution_metadata = getattr(row, "execution_metadata", None)
|
||||||
if execution_metadata:
|
if execution_metadata:
|
||||||
|
|||||||
Reference in New Issue
Block a user