refactor(api): tighten phase 1 shared type contracts (#33453)

This commit is contained in:
盐粒 Yanli
2026-03-17 17:50:51 +08:00
committed by GitHub
parent a592c53573
commit a717519822
14 changed files with 313 additions and 196 deletions

View File

@@ -19,7 +19,7 @@ from sqlalchemy import (
orm,
select,
)
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
@@ -33,7 +33,7 @@ from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
from dify_graph.file.constants import maybe_file_object
from dify_graph.file.models import File
from dify_graph.variables import utils as variable_utils
from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable
from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now
@@ -59,6 +59,9 @@ from .types import EnumText, LongText, StringUUID
logger = logging.getLogger(__name__)
SerializedWorkflowValue = dict[str, Any]
SerializedWorkflowVariables = dict[str, SerializedWorkflowValue]
class WorkflowContentDict(TypedDict):
graph: Mapping[str, Any]
@@ -405,7 +408,7 @@ class Workflow(Base): # bug
def rag_pipeline_user_input_form(self) -> list:
# get user_input_form from start node
variables: list[Any] = self.rag_pipeline_variables
variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables
return variables
@@ -448,17 +451,13 @@ class Workflow(Base): # bug
def environment_variables(
self,
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
# Use workflow.tenant_id to avoid relying on request user in background threads
tenant_id = self.tenant_id
if not tenant_id:
return []
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}"))
results = [
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
]
@@ -536,11 +535,7 @@ class Workflow(Base): # bug
@property
def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}"))
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
return results
@@ -552,19 +547,20 @@ class Workflow(Base): # bug
)
@property
def rag_pipeline_variables(self) -> list[dict]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._rag_pipeline_variables is None:
self._rag_pipeline_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
results = list(variables_dict.values())
return results
def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]:
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}"))
return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()]
@rag_pipeline_variables.setter
def rag_pipeline_variables(self, values: list[dict]) -> None:
def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None:
self._rag_pipeline_variables = json.dumps(
{item["variable"]: item for item in values},
{
rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json")
for rag_pipeline_variable in (
item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item)
for item in values
)
},
ensure_ascii=False,
)
@@ -802,44 +798,36 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
__tablename__ = "workflow_node_executions"
@declared_attr.directive
@classmethod
def __table_args__(cls) -> Any:
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
"workflow_node_execution_node_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_id",
),
Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
Index(
# The first argument is the index name,
# which we leave as `None`` to allow auto-generation by the ORM.
None,
cls.tenant_id,
cls.workflow_id,
cls.node_id,
# MyPy may flag the following line because it doesn't recognize that
# the `declared_attr` decorator passes the receiving class as the first
# argument to this method, allowing us to reference class attributes.
cls.created_at.desc(),
),
)
__table_args__ = (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
"workflow_node_execution_node_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_id",
),
Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
Index(
None,
"tenant_id",
"workflow_id",
"node_id",
sa.desc("created_at"),
),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID)