refactor(variables): clarify base vs union type naming (#30634)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
-LAN-
2026-01-13 22:39:34 +08:00
committed by GitHub
parent 91da784f84
commit 206706987d
22 changed files with 124 additions and 125 deletions

View File

@@ -1,11 +1,9 @@
from __future__ import annotations
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4
import sqlalchemy as sa
@@ -46,7 +44,7 @@ if TYPE_CHECKING:
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
from core.variables import SecretVariable, Segment, SegmentType, Variable
from core.variables import SecretVariable, Segment, SegmentType, VariableBase
from factories import variable_factory
from libs import helper
@@ -69,7 +67,7 @@ class WorkflowType(StrEnum):
RAG_PIPELINE = "rag-pipeline"
@classmethod
def value_of(cls, value: str) -> WorkflowType:
def value_of(cls, value: str) -> "WorkflowType":
"""
Get value of given mode.
@@ -82,7 +80,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}")
@classmethod
def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType:
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
"""
Get workflow type from app mode.
@@ -178,12 +176,12 @@ class Workflow(Base): # bug
graph: str,
features: str,
created_by: str,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list[dict],
marked_name: str = "",
marked_comment: str = "",
) -> Workflow:
) -> "Workflow":
workflow = Workflow()
workflow.id = str(uuid4())
workflow.tenant_id = tenant_id
@@ -447,7 +445,7 @@ class Workflow(Base): # bug
# decrypt secret variables value
def decrypt_func(
var: Variable,
var: VariableBase,
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
@@ -463,7 +461,7 @@ class Workflow(Base): # bug
return decrypted_results
@environment_variables.setter
def environment_variables(self, value: Sequence[Variable]):
def environment_variables(self, value: Sequence[VariableBase]):
if not value:
self._environment_variables = "{}"
return
@@ -487,7 +485,7 @@ class Workflow(Base): # bug
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value
def encrypt_func(var: Variable) -> Variable:
def encrypt_func(var: VariableBase) -> VariableBase:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else:
@@ -517,7 +515,7 @@ class Workflow(Base): # bug
return result
@property
def conversation_variables(self) -> Sequence[Variable]:
def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
@@ -527,7 +525,7 @@ class Workflow(Base): # bug
return results
@conversation_variables.setter
def conversation_variables(self, value: Sequence[Variable]):
def conversation_variables(self, value: Sequence[VariableBase]):
self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value},
ensure_ascii=False,
@@ -622,7 +620,7 @@ class WorkflowRun(Base):
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
pause: Mapped[WorkflowPause | None] = orm.relationship(
pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
"WorkflowPause",
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
uselist=False,
@@ -692,7 +690,7 @@ class WorkflowRun(Base):
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> WorkflowRun:
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@@ -844,7 +842,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship(
offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
"WorkflowNodeExecutionOffload",
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
uselist=True,
@@ -854,13 +852,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@staticmethod
def preload_offload_data(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
@staticmethod
def preload_offload_data_and_files(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
@@ -935,7 +933,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
)
return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None:
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property
@@ -1049,7 +1047,7 @@ class WorkflowNodeExecutionOffload(Base):
back_populates="offload_data",
)
file: Mapped[UploadFile | None] = orm.relationship(
file: Mapped[Optional["UploadFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@@ -1067,7 +1065,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
INSTALLED_APP = "installed-app"
@classmethod
def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom:
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
"""
Get value of given mode.
@@ -1184,7 +1182,7 @@ class ConversationVariable(TypeBase):
)
@classmethod
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable:
def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable":
obj = cls(
id=variable.id,
app_id=app_id,
@@ -1193,7 +1191,7 @@ class ConversationVariable(TypeBase):
)
return obj
def to_variable(self) -> Variable:
def to_variable(self) -> VariableBase:
mapping = json.loads(self.data)
return variable_factory.build_conversation_variable_from_mapping(mapping)
@@ -1337,7 +1335,7 @@ class WorkflowDraftVariable(Base):
)
# Relationship to WorkflowDraftVariableFile
variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship(
variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@@ -1507,7 +1505,7 @@ class WorkflowDraftVariable(Base):
node_execution_id: str | None,
description: str = "",
file_id: str | None = None,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
variable.id = str(uuid4())
variable.created_at = naive_utc_now()
@@ -1530,7 +1528,7 @@ class WorkflowDraftVariable(Base):
name: str,
value: Segment,
description: str = "",
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
@@ -1551,7 +1549,7 @@ class WorkflowDraftVariable(Base):
value: Segment,
node_execution_id: str,
editable: bool = False,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
@@ -1574,7 +1572,7 @@ class WorkflowDraftVariable(Base):
visible: bool = True,
editable: bool = True,
file_id: str | None = None,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=node_id,
@@ -1670,7 +1668,7 @@ class WorkflowDraftVariableFile(Base):
)
# Relationship to UploadFile
upload_file: Mapped[UploadFile] = orm.relationship(
upload_file: Mapped["UploadFile"] = orm.relationship(
foreign_keys=[upload_file_id],
lazy="raise",
uselist=False,
@@ -1737,7 +1735,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
# Relationship to WorkflowRun
workflow_run: Mapped[WorkflowRun] = orm.relationship(
workflow_run: Mapped["WorkflowRun"] = orm.relationship(
foreign_keys=[workflow_run_id],
# require explicit preloading.
lazy="raise",
@@ -1793,7 +1791,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
)
@classmethod
def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason:
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
if isinstance(pause_reason, HumanInputRequired):
return cls(
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id