refactor: EnumText for preferred_provider_type MessageChain, Banner (#33696)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
tmimmanuel
2026-03-18 18:53:04 +00:00
committed by GitHub
parent 4254392221
commit 29577cac14
8 changed files with 27 additions and 14 deletions

View File

@@ -4,6 +4,7 @@ from flask_restx import Resource
from controllers.console import api from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled from controllers.console.explore.wraps import explore_banner_enabled
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import BannerStatus
from models.model import ExporleBanner from models.model import ExporleBanner
@@ -16,7 +17,7 @@ class BannerApi(Resource):
language = request.args.get("language", "en-US") language = request.args.get("language", "en-US")
# Build base query for enabled banners # Build base query for enabled banners
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled") base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
# Try to get banners in the requested language # Try to get banners in the requested language
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all() banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()

View File

@@ -1422,12 +1422,12 @@ class ProviderConfiguration(BaseModel):
preferred_model_provider = s.execute(stmt).scalars().first() preferred_model_provider = s.execute(stmt).scalars().first()
if preferred_model_provider: if preferred_model_provider:
preferred_model_provider.preferred_provider_type = provider_type.value preferred_model_provider.preferred_provider_type = provider_type
else: else:
preferred_model_provider = TenantPreferredModelProvider( preferred_model_provider = TenantPreferredModelProvider(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
provider_name=self.provider.provider, provider_name=self.provider.provider,
preferred_provider_type=provider_type.value, preferred_provider_type=provider_type,
) )
s.add(preferred_model_provider) s.add(preferred_model_provider)
s.commit() s.commit()

View File

@@ -195,7 +195,7 @@ class ProviderManager:
preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name) preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name)
if preferred_provider_type_record: if preferred_provider_type_record:
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type) preferred_provider_type = preferred_provider_type_record.preferred_provider_type
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled: elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
preferred_provider_type = ProviderType.SYSTEM preferred_provider_type = ProviderType.SYSTEM
elif custom_configuration.provider or custom_configuration.models: elif custom_configuration.provider or custom_configuration.models:

View File

@@ -29,7 +29,15 @@ from libs.uuid_utils import uuidv7
from .account import Account, Tenant from .account import Account, Tenant
from .base import Base, TypeBase, gen_uuidv4_string from .base import Base, TypeBase, gen_uuidv4_string
from .engine import db from .engine import db
from .enums import AppMCPServerStatus, AppStatus, ConversationStatus, CreatorUserRole, MessageStatus from .enums import (
AppMCPServerStatus,
AppStatus,
BannerStatus,
ConversationStatus,
CreatorUserRole,
MessageChainType,
MessageStatus,
)
from .provider_ids import GenericProviderID from .provider_ids import GenericProviderID
from .types import EnumText, LongText, StringUUID from .types import EnumText, LongText, StringUUID
@@ -925,8 +933,11 @@ class ExporleBanner(TypeBase):
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
link: Mapped[str] = mapped_column(String(255), nullable=False) link: Mapped[str] = mapped_column(String(255), nullable=False)
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False) sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
status: Mapped[str] = mapped_column( status: Mapped[BannerStatus] = mapped_column(
sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled" EnumText(BannerStatus, length=255),
nullable=False,
server_default=sa.text("'enabled'::character varying"),
default=BannerStatus.ENABLED,
) )
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
@@ -2206,7 +2217,7 @@ class MessageChain(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
) )
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False) type: Mapped[MessageChainType] = mapped_column(EnumText(MessageChainType, length=255), nullable=False)
input: Mapped[str | None] = mapped_column(LongText, nullable=True) input: Mapped[str | None] = mapped_column(LongText, nullable=True)
output: Mapped[str | None] = mapped_column(LongText, nullable=True) output: Mapped[str | None] = mapped_column(LongText, nullable=True)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(

View File

@@ -210,7 +210,7 @@ class TenantPreferredModelProvider(TypeBase):
) )
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False) preferred_provider_type: Mapped[ProviderType] = mapped_column(EnumText(ProviderType, length=40), nullable=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False DateTime, nullable=False, server_default=func.current_timestamp(), init=False
) )

View File

@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session
from enums.cloud_plan import CloudPlan from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import DataSourceType from models.enums import DataSourceType, MessageChainType
from models.model import ( from models.model import (
App, App,
AppAnnotationHitHistory, AppAnnotationHitHistory,
@@ -236,7 +236,7 @@ class TestMessagesCleanServiceIntegration:
# MessageChain # MessageChain
chain = MessageChain( chain = MessageChain(
message_id=message.id, message_id=message.id,
type="system", type=MessageChainType.SYSTEM,
input=json.dumps({"test": "input"}), input=json.dumps({"test": "input"}),
output=json.dumps({"test": "output"}), output=json.dumps({"test": "output"}),
) )

View File

@@ -2,6 +2,7 @@ from datetime import datetime
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import controllers.console.explore.banner as banner_module import controllers.console.explore.banner as banner_module
from models.enums import BannerStatus
def unwrap(func): def unwrap(func):
@@ -20,7 +21,7 @@ class TestBannerApi:
banner.content = {"text": "hello"} banner.content = {"text": "hello"}
banner.link = "https://example.com" banner.link = "https://example.com"
banner.sort = 1 banner.sort = 1
banner.status = "enabled" banner.status = BannerStatus.ENABLED
banner.created_at = datetime(2024, 1, 1) banner.created_at = datetime(2024, 1, 1)
query = MagicMock() query = MagicMock()
@@ -54,7 +55,7 @@ class TestBannerApi:
banner.content = {"text": "fallback"} banner.content = {"text": "fallback"}
banner.link = None banner.link = None
banner.sort = 1 banner.sort = 1
banner.status = "enabled" banner.status = BannerStatus.ENABLED
banner.created_at = None banner.created_at = None
query = MagicMock() query = MagicMock()

View File

@@ -410,7 +410,7 @@ def test_switch_preferred_provider_type_updates_existing_record_with_session() -
configuration.switch_preferred_provider_type(ProviderType.SYSTEM, session=session) configuration.switch_preferred_provider_type(ProviderType.SYSTEM, session=session)
assert existing_record.preferred_provider_type == ProviderType.SYSTEM.value assert existing_record.preferred_provider_type == ProviderType.SYSTEM
session.commit.assert_called_once() session.commit.assert_called_once()