Merge branch 'main' into feat/remove-unused-is-deleted-from-conversation

This commit is contained in:
-LAN-
2026-03-30 15:54:38 +08:00
committed by GitHub
3294 changed files with 231483 additions and 98607 deletions

View File

@@ -3,16 +3,20 @@ from __future__ import annotations
import json
import re
import uuid
from collections.abc import Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
from uuid import uuid4
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin # type: ignore[import-untyped]
from graphon.enums import WorkflowExecutionStatus
from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from graphon.file import helpers as file_helpers
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import TypedDict
@@ -20,16 +24,33 @@ from typing_extensions import TypedDict
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.tools.signature import sign_tool_file
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from dify_graph.file import helpers as file_helpers
from extensions.storage.storage_type import StorageType
from libs.helper import generate_string # type: ignore[import-not-found]
from libs.uuid_utils import uuidv7
from models.utils.file_input_compat import build_file_from_input_mapping
from .account import Account, Tenant
from .base import Base, TypeBase, gen_uuidv4_string
from .engine import db
from .enums import AppMCPServerStatus, AppStatus, ConversationStatus, CreatorUserRole, MessageStatus
from .enums import (
ApiTokenType,
AppMCPServerStatus,
AppStatus,
BannerStatus,
ConversationFromSource,
ConversationStatus,
CreatorUserRole,
CustomizeTokenStrategy,
FeedbackFromSource,
FeedbackRating,
InvokeFrom,
MessageChainType,
MessageFileBelongsTo,
MessageStatus,
PromptType,
ProviderQuotaType,
TagType,
)
from .provider_ids import GenericProviderID
from .types import EnumText, LongText, StringUUID
@@ -40,6 +61,32 @@ if TYPE_CHECKING:
# --- TypedDict definitions for structured dict return types ---
@lru_cache(maxsize=1)
def _get_file_access_controller():
from core.app.file_access import DatabaseFileAccessController
return DatabaseFileAccessController()
def _resolve_app_tenant_id(app_id: str) -> str:
resolved_tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id))
if not resolved_tenant_id:
raise ValueError(f"Unable to resolve tenant_id for app {app_id}")
return resolved_tenant_id
def _build_app_tenant_resolver(app_id: str, owner_tenant_id: str | None = None) -> Callable[[], str]:
resolved_tenant_id = owner_tenant_id
def resolve_owner_tenant_id() -> str:
nonlocal resolved_tenant_id
if resolved_tenant_id is None:
resolved_tenant_id = _resolve_app_tenant_id(app_id)
return resolved_tenant_id
return resolve_owner_tenant_id
class EnabledConfig(TypedDict):
enabled: bool
@@ -380,13 +427,12 @@ class App(Base):
@property
def site(self) -> Site | None:
site = db.session.query(Site).where(Site.app_id == self.id).first()
return site
return db.session.scalar(select(Site).where(Site.app_id == self.id))
@property
def app_model_config(self) -> AppModelConfig | None:
if self.app_model_config_id:
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id))
return None
@@ -395,7 +441,7 @@ class App(Base):
if self.workflow_id:
from .workflow import Workflow
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
return None
@@ -405,8 +451,7 @@ class App(Base):
@property
def tenant(self) -> Tenant | None:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
@property
def is_agent(self) -> bool:
@@ -546,9 +591,9 @@ class App(Base):
return deleted_tools
@property
def tags(self) -> list[Tag]:
tags = (
db.session.query(Tag)
def tags(self) -> Sequence[Tag]:
tags = db.session.scalars(
select(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
TagBinding.target_id == self.id,
@@ -556,15 +601,14 @@ class App(Base):
Tag.tenant_id == self.tenant_id,
Tag.type == "app",
)
.all()
)
).all()
return tags or []
@property
def author_name(self) -> str | None:
if self.created_by:
account = db.session.query(Account).where(Account.id == self.created_by).first()
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
if account:
return account.name
@@ -575,7 +619,9 @@ class AppModelConfig(TypeBase):
__tablename__ = "app_model_configs"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id"))
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
@@ -605,8 +651,11 @@ class AppModelConfig(TypeBase):
agent_mode: Mapped[str | None] = mapped_column(LongText, default=None)
sensitive_word_avoidance: Mapped[str | None] = mapped_column(LongText, default=None)
retriever_resource: Mapped[str | None] = mapped_column(LongText, default=None)
prompt_type: Mapped[str] = mapped_column(
String(255), nullable=False, server_default=sa.text("'simple'"), default="simple"
prompt_type: Mapped[PromptType] = mapped_column(
EnumText(PromptType, length=255),
nullable=False,
server_default=sa.text("'simple'"),
default=PromptType.SIMPLE,
)
chat_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
completion_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
@@ -616,8 +665,7 @@ class AppModelConfig(TypeBase):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def model_dict(self) -> ModelConfig:
@@ -652,8 +700,8 @@ class AppModelConfig(TypeBase):
@property
def annotation_reply_dict(self) -> AnnotationReplyConfig:
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
)
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
@@ -759,7 +807,7 @@ class AppModelConfig(TypeBase):
"dataset_query_variable": self.dataset_query_variable,
"pre_prompt": self.pre_prompt,
"agent_mode": self.agent_mode_dict,
"prompt_type": self.prompt_type,
"prompt_type": self.prompt_type.value if isinstance(self.prompt_type, PromptType) else self.prompt_type,
"chat_prompt_config": self.chat_prompt_config_dict,
"completion_prompt_config": self.completion_prompt_config_dict,
"dataset_configs": self.dataset_configs_dict,
@@ -803,7 +851,7 @@ class AppModelConfig(TypeBase):
self.retriever_resource = (
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
)
self.prompt_type = model_config.get("prompt_type", "simple")
self.prompt_type = PromptType(model_config.get("prompt_type", "simple"))
self.chat_prompt_config = (
json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None
)
@@ -845,8 +893,7 @@ class RecommendedApp(Base): # bug
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
class InstalledApp(TypeBase):
@@ -873,13 +920,11 @@ class InstalledApp(TypeBase):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def tenant(self) -> Tenant | None:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
class TrialApp(Base):
@@ -899,8 +944,7 @@ class TrialApp(Base):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
class AccountTrialAppRecord(Base):
@@ -919,24 +963,27 @@ class AccountTrialAppRecord(Base):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def user(self) -> Account | None:
user = db.session.query(Account).where(Account.id == self.account_id).first()
return user
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class ExporleBanner(TypeBase):
__tablename__ = "exporle_banners"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv4_string, init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
)
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
link: Mapped[str] = mapped_column(String(255), nullable=False)
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
status: Mapped[str] = mapped_column(
sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
status: Mapped[BannerStatus] = mapped_column(
EnumText(BannerStatus, length=255),
nullable=False,
server_default=sa.text("'enabled'::character varying"),
default=BannerStatus.ENABLED,
)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
@@ -1017,10 +1064,12 @@ class Conversation(Base):
#
# Its value corresponds to the members of `InvokeFrom`.
# (api/core/app/entities/app_invoke_entities.py)
invoke_from = mapped_column(String(255), nullable=True)
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
# ref: ConversationSource.
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
from_source: Mapped[ConversationFromSource] = mapped_column(
EnumText(ConversationFromSource, length=255), nullable=False
)
from_end_user_id = mapped_column(StringUUID)
from_account_id = mapped_column(StringUUID)
read_at = mapped_column(sa.DateTime)
@@ -1039,23 +1088,26 @@ class Conversation(Base):
@property
def inputs(self) -> dict[str, Any]:
inputs = self._inputs.copy()
# Compatibility bridge: stored input payloads may come from before or after the
# graph-layer file refactor. Newer rows may omit `tenant_id`, so keep tenant
# resolution at the SQLAlchemy model boundary instead of pushing ownership back
# into `graphon.file.File`.
tenant_resolver = _build_app_tenant_resolver(
app_id=self.app_id,
owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)),
)
# Convert file mapping to File object
for key, value in inputs.items():
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
from factories import file_factory
if (
isinstance(value, dict)
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
):
value_dict = cast(dict[str, Any], value)
if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
value_dict["tool_file_id"] = value_dict["related_id"]
elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value_dict["upload_file_id"] = value_dict["related_id"]
tenant_id = cast(str, value_dict.get("tenant_id", ""))
inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id)
inputs[key] = build_file_from_input_mapping(
file_mapping=value_dict,
tenant_resolver=tenant_resolver,
)
elif isinstance(value, list):
value_list = cast(list[Any], value)
if all(
@@ -1068,15 +1120,12 @@ class Conversation(Base):
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
item_dict["tool_file_id"] = item_dict["related_id"]
elif item_dict["transfer_method"] in [
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.REMOTE_URL,
]:
item_dict["upload_file_id"] = item_dict["related_id"]
tenant_id = cast(str, item_dict.get("tenant_id", ""))
file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id))
file_list.append(
build_file_from_input_mapping(
file_mapping=item_dict,
tenant_resolver=tenant_resolver,
)
)
inputs[key] = file_list
return inputs
@@ -1115,8 +1164,8 @@ class Conversation(Base):
else:
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
else:
app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
app_model_config = db.session.scalar(
select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)
)
if app_model_config:
model_config = app_model_config.to_dict()
@@ -1139,36 +1188,43 @@ class Conversation(Base):
@property
def annotated(self):
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
return (
db.session.scalar(
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id)
)
or 0
) > 0
@property
def annotation(self):
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1))
@property
def message_count(self):
return db.session.query(Message).where(Message.conversation_id == self.id).count()
return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0
@property
def user_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == FeedbackRating.LIKE,
)
)
.count()
or 0
)
dislike = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == FeedbackRating.DISLIKE,
)
)
.count()
or 0
)
return {"like": like, "dislike": dislike}
@@ -1176,23 +1232,25 @@ class Conversation(Base):
@property
def admin_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == FeedbackRating.LIKE,
)
)
.count()
or 0
)
dislike = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == FeedbackRating.DISLIKE,
)
)
.count()
or 0
)
return {"like": like, "dislike": dislike}
@@ -1254,22 +1312,19 @@ class Conversation(Base):
@property
def first_message(self):
return (
db.session.query(Message)
.where(Message.conversation_id == self.id)
.order_by(Message.created_at.asc())
.first()
return db.session.scalar(
select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc())
)
@property
def app(self) -> App | None:
with Session(db.engine, expire_on_commit=False) as session:
return session.query(App).where(App.id == self.app_id).first()
return session.scalar(select(App).where(App.id == self.app_id))
@property
def from_end_user_session_id(self):
if self.from_end_user_id:
end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id))
if end_user:
return end_user.session_id
@@ -1278,7 +1333,7 @@ class Conversation(Base):
@property
def from_account_name(self) -> str | None:
if self.from_account_id:
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
account = db.session.scalar(select(Account).where(Account.id == self.from_account_id))
if account:
return account.name
@@ -1361,8 +1416,10 @@ class Message(Base):
)
error: Mapped[str | None] = mapped_column(LongText)
message_metadata: Mapped[str | None] = mapped_column(LongText)
invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
from_source: Mapped[ConversationFromSource] = mapped_column(
EnumText(ConversationFromSource, length=255), nullable=False
)
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
@@ -1376,21 +1433,23 @@ class Message(Base):
@property
def inputs(self) -> dict[str, Any]:
inputs = self._inputs.copy()
# Compatibility bridge: message inputs are persisted as JSON and must remain
# readable across file payload shape changes. Do not assume `tenant_id`
# is serialized into each file mapping going forward.
tenant_resolver = _build_app_tenant_resolver(
app_id=self.app_id,
owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)),
)
for key, value in inputs.items():
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
from factories import file_factory
if (
isinstance(value, dict)
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
):
value_dict = cast(dict[str, Any], value)
if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
value_dict["tool_file_id"] = value_dict["related_id"]
elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value_dict["upload_file_id"] = value_dict["related_id"]
tenant_id = cast(str, value_dict.get("tenant_id", ""))
inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id)
inputs[key] = build_file_from_input_mapping(
file_mapping=value_dict,
tenant_resolver=tenant_resolver,
)
elif isinstance(value, list):
value_list = cast(list[Any], value)
if all(
@@ -1403,15 +1462,12 @@ class Message(Base):
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
item_dict["tool_file_id"] = item_dict["related_id"]
elif item_dict["transfer_method"] in [
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.REMOTE_URL,
]:
item_dict["upload_file_id"] = item_dict["related_id"]
tenant_id = cast(str, item_dict.get("tenant_id", ""))
file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id))
file_list.append(
build_file_from_input_mapping(
file_mapping=item_dict,
tenant_resolver=tenant_resolver,
)
)
inputs[key] = file_list
return inputs
@@ -1503,21 +1559,15 @@ class Message(Base):
@property
def user_feedback(self):
feedback = (
db.session.query(MessageFeedback)
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
.first()
return db.session.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
)
return feedback
@property
def admin_feedback(self):
feedback = (
db.session.query(MessageFeedback)
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
.first()
return db.session.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
)
return feedback
@property
def feedbacks(self):
@@ -1526,28 +1576,27 @@ class Message(Base):
@property
def annotation(self):
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id))
return annotation
@property
def annotation_hit_history(self):
annotation_history = (
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
annotation_history = db.session.scalar(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id)
)
if annotation_history:
annotation = (
db.session.query(MessageAnnotation)
.where(MessageAnnotation.id == annotation_history.annotation_id)
.first()
return db.session.scalar(
select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id)
)
return annotation
return None
@property
def app_model_config(self):
conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id))
if conversation:
return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
return db.session.scalar(
select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id)
)
return None
@@ -1560,13 +1609,12 @@ class Message(Base):
return json.loads(self.message_metadata) if self.message_metadata else {}
@property
def agent_thoughts(self) -> list[MessageAgentThought]:
return (
db.session.query(MessageAgentThought)
def agent_thoughts(self) -> Sequence[MessageAgentThought]:
return db.session.scalars(
select(MessageAgentThought)
.where(MessageAgentThought.message_id == self.id)
.order_by(MessageAgentThought.position.asc())
.all()
)
).all()
@property
def retriever_resources(self) -> Any:
@@ -1577,7 +1625,7 @@ class Message(Base):
from factories import file_factory
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
current_app = db.session.query(App).where(App.id == self.app_id).first()
current_app = db.session.scalar(select(App).where(App.id == self.app_id))
if not current_app:
raise ValueError(f"App {self.app_id} not found")
@@ -1594,6 +1642,7 @@ class Message(Base):
"upload_file_id": message_file.upload_file_id,
},
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
elif message_file.transfer_method == FileTransferMethod.REMOTE_URL:
if message_file.url is None:
@@ -1607,6 +1656,7 @@ class Message(Base):
"url": message_file.url,
},
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE:
if message_file.upload_file_id is None:
@@ -1621,6 +1671,7 @@ class Message(Base):
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
else:
raise ValueError(
@@ -1723,8 +1774,8 @@ class MessageFeedback(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
rating: Mapped[str] = mapped_column(String(255), nullable=False)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
rating: Mapped[FeedbackRating] = mapped_column(EnumText(FeedbackRating, length=255), nullable=False)
from_source: Mapped[FeedbackFromSource] = mapped_column(EnumText(FeedbackFromSource, length=255), nullable=False)
content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
@@ -1741,8 +1792,7 @@ class MessageFeedback(TypeBase):
@property
def from_account(self) -> Account | None:
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.from_account_id))
def to_dict(self) -> MessageFeedbackDict:
return {
@@ -1772,13 +1822,15 @@ class MessageFile(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False)
transfer_method: Mapped[FileTransferMethod] = mapped_column(
EnumText(FileTransferMethod, length=255), nullable=False
)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
belongs_to: Mapped[MessageFileBelongsTo | None] = mapped_column(
EnumText(MessageFileBelongsTo, length=255), nullable=True, default=None
)
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
@@ -1815,13 +1867,11 @@ class MessageAnnotation(Base):
@property
def account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
@property
def annotation_create_account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class AppAnnotationHitHistory(TypeBase):
@@ -1834,7 +1884,9 @@ class AppAnnotationHitHistory(TypeBase):
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
source: Mapped[str] = mapped_column(LongText, nullable=False)
@@ -1850,18 +1902,15 @@ class AppAnnotationHitHistory(TypeBase):
@property
def account(self):
account = (
db.session.query(Account)
return db.session.scalar(
select(Account)
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
.where(MessageAnnotation.id == self.annotation_id)
.first()
)
return account
@property
def annotation_create_account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class AppAnnotationSetting(TypeBase):
@@ -1894,12 +1943,9 @@ class AppAnnotationSetting(TypeBase):
def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding
collection_binding_detail = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == self.collection_binding_id)
.first()
return db.session.scalar(
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id)
)
return collection_binding_detail
class OperationLog(TypeBase):
@@ -2005,7 +2051,9 @@ class AppMCPServer(TypeBase):
def generate_server_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
while (
db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0
) > 0:
result = generate_string(n)
return result
@@ -2039,7 +2087,9 @@ class Site(Base):
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
_custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", LongText, default="")
customize_domain = mapped_column(String(255))
customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False)
customize_token_strategy: Mapped[CustomizeTokenStrategy] = mapped_column(
EnumText(CustomizeTokenStrategy, length=255), nullable=False
)
prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
status: Mapped[AppStatus] = mapped_column(
EnumText(AppStatus, length=255), nullable=False, server_default=sa.text("'normal'"), default=AppStatus.NORMAL
@@ -2066,7 +2116,7 @@ class Site(Base):
def generate_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(Site).where(Site.code == result).count() > 0:
while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0:
result = generate_string(n)
return result
@@ -2088,7 +2138,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field.
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
app_id = mapped_column(StringUUID, nullable=True)
tenant_id = mapped_column(StringUUID, nullable=True)
type = mapped_column(String(16), nullable=False)
type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False)
token: Mapped[str] = mapped_column(String(255), nullable=False)
last_used_at = mapped_column(sa.DateTime, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@@ -2114,7 +2164,7 @@ class UploadFile(Base):
# The `server_default` serves as a fallback mechanism.
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False)
key: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
@@ -2158,7 +2208,7 @@ class UploadFile(Base):
self,
*,
tenant_id: str,
storage_type: str,
storage_type: StorageType,
key: str,
name: str,
size: int,
@@ -2223,7 +2273,7 @@ class MessageChain(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[MessageChainType] = mapped_column(EnumText(MessageChainType, length=255), nullable=False)
input: Mapped[str | None] = mapped_column(LongText, nullable=True)
output: Mapped[str | None] = mapped_column(LongText, nullable=True)
created_at: Mapped[datetime] = mapped_column(
@@ -2398,7 +2448,7 @@ class Tag(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
type: Mapped[str] = mapped_column(String(16), nullable=False)
type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
@@ -2483,7 +2533,9 @@ class TenantCreditPool(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
pool_type: Mapped[ProviderQuotaType] = mapped_column(
EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial"
)
quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(