mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 18:49:29 +08:00
refactor(api): tighten phase 1 shared type contracts (#33453)
This commit is contained in:
@@ -19,7 +19,7 @@ from sqlalchemy import (
|
||||
orm,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
@@ -33,7 +33,7 @@ from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.file.constants import maybe_file_object
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.variables import utils as variable_utils
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
|
||||
from extensions.ext_storage import Storage
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -59,6 +59,9 @@ from .types import EnumText, LongText, StringUUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SerializedWorkflowValue = dict[str, Any]
|
||||
SerializedWorkflowVariables = dict[str, SerializedWorkflowValue]
|
||||
|
||||
|
||||
class WorkflowContentDict(TypedDict):
|
||||
graph: Mapping[str, Any]
|
||||
@@ -405,7 +408,7 @@ class Workflow(Base): # bug
|
||||
|
||||
def rag_pipeline_user_input_form(self) -> list:
|
||||
# get user_input_form from start node
|
||||
variables: list[Any] = self.rag_pipeline_variables
|
||||
variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables
|
||||
|
||||
return variables
|
||||
|
||||
@@ -448,17 +451,13 @@ class Workflow(Base): # bug
|
||||
def environment_variables(
|
||||
self,
|
||||
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||
# TODO: find some way to init `self._environment_variables` when instance created.
|
||||
if self._environment_variables is None:
|
||||
self._environment_variables = "{}"
|
||||
|
||||
# Use workflow.tenant_id to avoid relying on request user in background threads
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
if not tenant_id:
|
||||
return []
|
||||
|
||||
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
|
||||
environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}"))
|
||||
results = [
|
||||
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
|
||||
]
|
||||
@@ -536,11 +535,7 @@ class Workflow(Base): # bug
|
||||
|
||||
@property
|
||||
def conversation_variables(self) -> Sequence[VariableBase]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._conversation_variables is None:
|
||||
self._conversation_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}"))
|
||||
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
return results
|
||||
|
||||
@@ -552,19 +547,20 @@ class Workflow(Base): # bug
|
||||
)
|
||||
|
||||
@property
|
||||
def rag_pipeline_variables(self) -> list[dict]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._rag_pipeline_variables is None:
|
||||
self._rag_pipeline_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
|
||||
results = list(variables_dict.values())
|
||||
return results
|
||||
def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]:
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}"))
|
||||
return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()]
|
||||
|
||||
@rag_pipeline_variables.setter
|
||||
def rag_pipeline_variables(self, values: list[dict]) -> None:
|
||||
def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None:
|
||||
self._rag_pipeline_variables = json.dumps(
|
||||
{item["variable"]: item for item in values},
|
||||
{
|
||||
rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json")
|
||||
for rag_pipeline_variable in (
|
||||
item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item)
|
||||
for item in values
|
||||
)
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -802,44 +798,36 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
|
||||
__tablename__ = "workflow_node_executions"
|
||||
|
||||
@declared_attr.directive
|
||||
@classmethod
|
||||
def __table_args__(cls) -> Any:
|
||||
return (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
"workflow_node_execution_workflow_run_id_idx",
|
||||
"workflow_run_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_node_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
Index(
|
||||
# The first argument is the index name,
|
||||
# which we leave as `None`` to allow auto-generation by the ORM.
|
||||
None,
|
||||
cls.tenant_id,
|
||||
cls.workflow_id,
|
||||
cls.node_id,
|
||||
# MyPy may flag the following line because it doesn't recognize that
|
||||
# the `declared_attr` decorator passes the receiving class as the first
|
||||
# argument to this method, allowing us to reference class attributes.
|
||||
cls.created_at.desc(),
|
||||
),
|
||||
)
|
||||
__table_args__ = (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
"workflow_node_execution_workflow_run_id_idx",
|
||||
"workflow_run_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_node_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
Index(
|
||||
None,
|
||||
"tenant_id",
|
||||
"workflow_id",
|
||||
"node_id",
|
||||
sa.desc("created_at"),
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
|
||||
Reference in New Issue
Block a user