refactor(api): add TypedDict definitions to models/model.py (#32925)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
statxc
2026-03-06 01:42:54 +02:00
committed by GitHub
parent 6bd1be9e16
commit 741d48560d
23 changed files with 453 additions and 142 deletions

View File

@@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
from uuid import uuid4
import sqlalchemy as sa
@@ -15,6 +15,7 @@ from flask import request
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import TypedDict
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
@@ -36,6 +37,259 @@ if TYPE_CHECKING:
from .workflow import Workflow
# --- TypedDict definitions for structured dict return types ---
class EnabledConfig(TypedDict):
enabled: bool
class EmbeddingModelInfo(TypedDict):
embedding_provider_name: str
embedding_model_name: str
class AnnotationReplyDisabledConfig(TypedDict):
enabled: Literal[False]
class AnnotationReplyEnabledConfig(TypedDict):
id: str
enabled: Literal[True]
score_threshold: float
embedding_model: EmbeddingModelInfo
AnnotationReplyConfig = AnnotationReplyEnabledConfig | AnnotationReplyDisabledConfig
class SensitiveWordAvoidanceConfig(TypedDict):
enabled: bool
type: str
config: dict[str, Any]
class AgentToolConfig(TypedDict):
provider_type: str
provider_id: str
tool_name: str
tool_parameters: dict[str, Any]
plugin_unique_identifier: NotRequired[str | None]
credential_id: NotRequired[str | None]
class AgentModeConfig(TypedDict):
enabled: bool
strategy: str | None
tools: list[AgentToolConfig | dict[str, Any]]
prompt: str | None
class ImageUploadConfig(TypedDict):
enabled: bool
number_limits: int
detail: str
transfer_methods: list[str]
class FileUploadConfig(TypedDict):
image: ImageUploadConfig
class DeletedToolInfo(TypedDict):
type: str
tool_name: str
provider_id: str
class ExternalDataToolConfig(TypedDict):
enabled: bool
variable: str
type: str
config: dict[str, Any]
class UserInputFormItemConfig(TypedDict):
variable: str
label: str
description: NotRequired[str]
required: NotRequired[bool]
max_length: NotRequired[int]
options: NotRequired[list[str]]
default: NotRequired[str]
type: NotRequired[str]
config: NotRequired[dict[str, Any]]
# Each item is a single-key dict, e.g. {"text-input": UserInputFormItemConfig}
UserInputFormItem = dict[str, UserInputFormItemConfig]
class DatasetConfigs(TypedDict):
retrieval_model: str
datasets: NotRequired[dict[str, Any]]
top_k: NotRequired[int]
score_threshold: NotRequired[float]
score_threshold_enabled: NotRequired[bool]
reranking_model: NotRequired[dict[str, Any] | None]
weights: NotRequired[dict[str, Any] | None]
reranking_enabled: NotRequired[bool]
reranking_mode: NotRequired[str]
metadata_filtering_mode: NotRequired[str]
metadata_model_config: NotRequired[dict[str, Any] | None]
metadata_filtering_conditions: NotRequired[dict[str, Any] | None]
class ChatPromptMessage(TypedDict):
text: str
role: str
class ChatPromptConfig(TypedDict, total=False):
prompt: list[ChatPromptMessage]
class CompletionPromptText(TypedDict):
text: str
class ConversationHistoriesRole(TypedDict):
user_prefix: str
assistant_prefix: str
class CompletionPromptConfig(TypedDict):
prompt: CompletionPromptText
conversation_histories_role: NotRequired[ConversationHistoriesRole]
class ModelConfig(TypedDict):
provider: str
name: str
mode: str
completion_params: NotRequired[dict[str, Any]]
class AppModelConfigDict(TypedDict):
opening_statement: str | None
suggested_questions: list[str]
suggested_questions_after_answer: EnabledConfig
speech_to_text: EnabledConfig
text_to_speech: EnabledConfig
retriever_resource: EnabledConfig
annotation_reply: AnnotationReplyConfig
more_like_this: EnabledConfig
sensitive_word_avoidance: SensitiveWordAvoidanceConfig
external_data_tools: list[ExternalDataToolConfig]
model: ModelConfig
user_input_form: list[UserInputFormItem]
dataset_query_variable: str | None
pre_prompt: str | None
agent_mode: AgentModeConfig
prompt_type: str
chat_prompt_config: ChatPromptConfig
completion_prompt_config: CompletionPromptConfig
dataset_configs: DatasetConfigs
file_upload: FileUploadConfig
# Added dynamically in Conversation.model_config
model_id: NotRequired[str | None]
provider: NotRequired[str | None]
class ConversationDict(TypedDict):
id: str
app_id: str
app_model_config_id: str | None
model_provider: str | None
override_model_configs: str | None
model_id: str | None
mode: str
name: str
summary: str | None
inputs: dict[str, Any]
introduction: str | None
system_instruction: str | None
system_instruction_tokens: int
status: str
invoke_from: str | None
from_source: str
from_end_user_id: str | None
from_account_id: str | None
read_at: datetime | None
read_account_id: str | None
dialogue_count: int
created_at: datetime
updated_at: datetime
class MessageDict(TypedDict):
id: str
app_id: str
conversation_id: str
model_id: str | None
inputs: dict[str, Any]
query: str
total_price: Decimal | None
message: dict[str, Any]
answer: str
status: str
error: str | None
message_metadata: dict[str, Any]
from_source: str
from_end_user_id: str | None
from_account_id: str | None
created_at: str
updated_at: str
agent_based: bool
workflow_run_id: str | None
class MessageFeedbackDict(TypedDict):
id: str
app_id: str
conversation_id: str
message_id: str
rating: str
content: str | None
from_source: str
from_end_user_id: str | None
from_account_id: str | None
created_at: str
updated_at: str
class MessageFileInfo(TypedDict, total=False):
belongs_to: str | None
upload_file_id: str | None
id: str
tenant_id: str
type: str
transfer_method: str
remote_url: str | None
related_id: str | None
filename: str | None
extension: str | None
mime_type: str | None
size: int
dify_model_identity: str
url: str | None
class ExtraContentDict(TypedDict, total=False):
type: str
workflow_run_id: str
class TraceAppConfigDict(TypedDict):
id: str
app_id: str
tracing_provider: str | None
tracing_config: dict[str, Any]
is_active: bool
created_at: str | None
updated_at: str | None
class DifySetup(TypeBase):
__tablename__ = "dify_setups"
__table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
@@ -176,7 +430,7 @@ class App(Base):
return str(self.mode)
@property
def deleted_tools(self) -> list[dict[str, str]]:
def deleted_tools(self) -> list[DeletedToolInfo]:
from core.tools.tool_manager import ToolManager, ToolProviderType
from services.plugin.plugin_service import PluginService
@@ -257,7 +511,7 @@ class App(Base):
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
}
deleted_tools: list[dict[str, str]] = []
deleted_tools: list[DeletedToolInfo] = []
for tool in tools:
keys = list(tool.keys())
@@ -364,35 +618,38 @@ class AppModelConfig(TypeBase):
return app
@property
def model_dict(self) -> dict[str, Any]:
return json.loads(self.model) if self.model else {}
def model_dict(self) -> ModelConfig:
return cast(ModelConfig, json.loads(self.model) if self.model else {})
@property
def suggested_questions_list(self) -> list[str]:
return json.loads(self.suggested_questions) if self.suggested_questions else []
@property
def suggested_questions_after_answer_dict(self) -> dict[str, Any]:
return (
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
return cast(
EnabledConfig,
json.loads(self.suggested_questions_after_answer)
if self.suggested_questions_after_answer
else {"enabled": False}
else {"enabled": False},
)
@property
def speech_to_text_dict(self) -> dict[str, Any]:
return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}
def speech_to_text_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
@property
def text_to_speech_dict(self) -> dict[str, Any]:
return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}
def text_to_speech_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
@property
def retriever_resource_dict(self) -> dict[str, Any]:
return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
def retriever_resource_dict(self) -> EnabledConfig:
return cast(
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
)
@property
def annotation_reply_dict(self) -> dict[str, Any]:
def annotation_reply_dict(self) -> AnnotationReplyConfig:
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
)
@@ -415,56 +672,62 @@ class AppModelConfig(TypeBase):
return {"enabled": False}
@property
def more_like_this_dict(self) -> dict[str, Any]:
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
def more_like_this_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
@property
def sensitive_word_avoidance_dict(self) -> dict[str, Any]:
return (
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
return cast(
SensitiveWordAvoidanceConfig,
json.loads(self.sensitive_word_avoidance)
if self.sensitive_word_avoidance
else {"enabled": False, "type": "", "configs": []}
else {"enabled": False, "type": "", "config": {}},
)
@property
def external_data_tools_list(self) -> list[dict[str, Any]]:
def external_data_tools_list(self) -> list[ExternalDataToolConfig]:
return json.loads(self.external_data_tools) if self.external_data_tools else []
@property
def user_input_form_list(self) -> list[dict[str, Any]]:
def user_input_form_list(self) -> list[UserInputFormItem]:
return json.loads(self.user_input_form) if self.user_input_form else []
@property
def agent_mode_dict(self) -> dict[str, Any]:
return (
def agent_mode_dict(self) -> AgentModeConfig:
return cast(
AgentModeConfig,
json.loads(self.agent_mode)
if self.agent_mode
else {"enabled": False, "strategy": None, "tools": [], "prompt": None}
else {"enabled": False, "strategy": None, "tools": [], "prompt": None},
)
@property
def chat_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
def chat_prompt_config_dict(self) -> ChatPromptConfig:
return cast(ChatPromptConfig, json.loads(self.chat_prompt_config) if self.chat_prompt_config else {})
@property
def completion_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
def completion_prompt_config_dict(self) -> CompletionPromptConfig:
return cast(
CompletionPromptConfig,
json.loads(self.completion_prompt_config) if self.completion_prompt_config else {},
)
@property
def dataset_configs_dict(self) -> dict[str, Any]:
def dataset_configs_dict(self) -> DatasetConfigs:
if self.dataset_configs:
dataset_configs: dict[str, Any] = json.loads(self.dataset_configs)
dataset_configs = json.loads(self.dataset_configs)
if "retrieval_model" not in dataset_configs:
return {"retrieval_model": "single"}
else:
return dataset_configs
return cast(DatasetConfigs, dataset_configs)
return {
"retrieval_model": "multiple",
}
@property
def file_upload_dict(self) -> dict[str, Any]:
return (
def file_upload_dict(self) -> FileUploadConfig:
return cast(
FileUploadConfig,
json.loads(self.file_upload)
if self.file_upload
else {
@@ -474,10 +737,10 @@ class AppModelConfig(TypeBase):
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
}
},
)
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> AppModelConfigDict:
return {
"opening_statement": self.opening_statement,
"suggested_questions": self.suggested_questions_list,
@@ -501,36 +764,42 @@ class AppModelConfig(TypeBase):
"file_upload": self.file_upload_dict,
}
def from_model_config_dict(self, model_config: Mapping[str, Any]):
def from_model_config_dict(self, model_config: AppModelConfigDict):
self.opening_statement = model_config.get("opening_statement")
self.suggested_questions = (
json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None
json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None
)
self.suggested_questions_after_answer = (
json.dumps(model_config["suggested_questions_after_answer"])
json.dumps(model_config.get("suggested_questions_after_answer"))
if model_config.get("suggested_questions_after_answer")
else None
)
self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None
self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None
self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None
self.speech_to_text = (
json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None
)
self.text_to_speech = (
json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None
)
self.more_like_this = (
json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None
)
self.sensitive_word_avoidance = (
json.dumps(model_config["sensitive_word_avoidance"])
json.dumps(model_config.get("sensitive_word_avoidance"))
if model_config.get("sensitive_word_avoidance")
else None
)
self.external_data_tools = (
json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None
json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None
)
self.model = json.dumps(model_config["model"]) if model_config.get("model") else None
self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None
self.user_input_form = (
json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None
json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None
)
self.dataset_query_variable = model_config.get("dataset_query_variable")
self.pre_prompt = model_config["pre_prompt"]
self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None
self.pre_prompt = model_config.get("pre_prompt")
self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None
self.retriever_resource = (
json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
)
self.prompt_type = model_config.get("prompt_type", "simple")
self.chat_prompt_config = (
@@ -823,24 +1092,26 @@ class Conversation(Base):
self._inputs = inputs
@property
def model_config(self):
model_config = {}
def model_config(self) -> AppModelConfigDict:
model_config = cast(AppModelConfigDict, {})
app_model_config: AppModelConfig | None = None
if self.mode == AppMode.ADVANCED_CHAT:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
model_config = override_model_configs
model_config = cast(AppModelConfigDict, override_model_configs)
else:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
if "model" in override_model_configs:
# where is app_id?
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs)
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(
cast(AppModelConfigDict, override_model_configs)
)
model_config = app_model_config.to_dict()
else:
model_config["configs"] = override_model_configs
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
else:
app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
@@ -1015,7 +1286,7 @@ class Conversation(Base):
def in_debug_mode(self) -> bool:
return self.override_model_configs is not None
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> ConversationDict:
return {
"id": self.id,
"app_id": self.app_id,
@@ -1295,7 +1566,7 @@ class Message(Base):
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
@property
def message_files(self) -> list[dict[str, Any]]:
def message_files(self) -> list[MessageFileInfo]:
from factories import file_factory
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
@@ -1350,10 +1621,13 @@ class Message(Base):
)
files.append(file)
result: list[dict[str, Any]] = [
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
for (file, message_file) in zip(files, message_files)
]
result = cast(
list[MessageFileInfo],
[
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
for (file, message_file) in zip(files, message_files)
],
)
db.session.commit()
return result
@@ -1363,7 +1637,7 @@ class Message(Base):
self._extra_contents = list(contents)
@property
def extra_contents(self) -> list[dict[str, Any]]:
def extra_contents(self) -> list[ExtraContentDict]:
return getattr(self, "_extra_contents", [])
@property
@@ -1379,7 +1653,7 @@ class Message(Base):
return None
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> MessageDict:
return {
"id": self.id,
"app_id": self.app_id,
@@ -1403,7 +1677,7 @@ class Message(Base):
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Message:
def from_dict(cls, data: MessageDict) -> Message:
return cls(
id=data["id"],
app_id=data["app_id"],
@@ -1463,7 +1737,7 @@ class MessageFeedback(TypeBase):
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> MessageFeedbackDict:
return {
"id": str(self.id),
"app_id": str(self.app_id),
@@ -1726,8 +2000,8 @@ class AppMCPServer(TypeBase):
return result
@property
def parameters_dict(self) -> dict[str, Any]:
return cast(dict[str, Any], json.loads(self.parameters))
def parameters_dict(self) -> dict[str, str]:
return cast(dict[str, str], json.loads(self.parameters))
class Site(Base):
@@ -2167,7 +2441,7 @@ class TraceAppConfig(TypeBase):
def tracing_config_str(self) -> str:
return json.dumps(self.tracing_config_dict)
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> TraceAppConfigDict:
return {
"id": self.id,
"app_id": self.app_id,