mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 20:09:20 +08:00
refactor: EnumText for preferred_provider_type MessageChain, Banner (#33696)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -4,6 +4,7 @@ from flask_restx import Resource
|
|||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.explore.wraps import explore_banner_enabled
|
from controllers.console.explore.wraps import explore_banner_enabled
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from models.enums import BannerStatus
|
||||||
from models.model import ExporleBanner
|
from models.model import ExporleBanner
|
||||||
|
|
||||||
|
|
||||||
@@ -16,7 +17,7 @@ class BannerApi(Resource):
|
|||||||
language = request.args.get("language", "en-US")
|
language = request.args.get("language", "en-US")
|
||||||
|
|
||||||
# Build base query for enabled banners
|
# Build base query for enabled banners
|
||||||
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
|
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
|
||||||
|
|
||||||
# Try to get banners in the requested language
|
# Try to get banners in the requested language
|
||||||
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
|
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
|
||||||
|
|||||||
@@ -1422,12 +1422,12 @@ class ProviderConfiguration(BaseModel):
|
|||||||
preferred_model_provider = s.execute(stmt).scalars().first()
|
preferred_model_provider = s.execute(stmt).scalars().first()
|
||||||
|
|
||||||
if preferred_model_provider:
|
if preferred_model_provider:
|
||||||
preferred_model_provider.preferred_provider_type = provider_type.value
|
preferred_model_provider.preferred_provider_type = provider_type
|
||||||
else:
|
else:
|
||||||
preferred_model_provider = TenantPreferredModelProvider(
|
preferred_model_provider = TenantPreferredModelProvider(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
provider_name=self.provider.provider,
|
provider_name=self.provider.provider,
|
||||||
preferred_provider_type=provider_type.value,
|
preferred_provider_type=provider_type,
|
||||||
)
|
)
|
||||||
s.add(preferred_model_provider)
|
s.add(preferred_model_provider)
|
||||||
s.commit()
|
s.commit()
|
||||||
|
|||||||
@@ -195,7 +195,7 @@ class ProviderManager:
|
|||||||
preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name)
|
preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name)
|
||||||
|
|
||||||
if preferred_provider_type_record:
|
if preferred_provider_type_record:
|
||||||
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
|
preferred_provider_type = preferred_provider_type_record.preferred_provider_type
|
||||||
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
|
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
|
||||||
preferred_provider_type = ProviderType.SYSTEM
|
preferred_provider_type = ProviderType.SYSTEM
|
||||||
elif custom_configuration.provider or custom_configuration.models:
|
elif custom_configuration.provider or custom_configuration.models:
|
||||||
|
|||||||
@@ -29,7 +29,15 @@ from libs.uuid_utils import uuidv7
|
|||||||
from .account import Account, Tenant
|
from .account import Account, Tenant
|
||||||
from .base import Base, TypeBase, gen_uuidv4_string
|
from .base import Base, TypeBase, gen_uuidv4_string
|
||||||
from .engine import db
|
from .engine import db
|
||||||
from .enums import AppMCPServerStatus, AppStatus, ConversationStatus, CreatorUserRole, MessageStatus
|
from .enums import (
|
||||||
|
AppMCPServerStatus,
|
||||||
|
AppStatus,
|
||||||
|
BannerStatus,
|
||||||
|
ConversationStatus,
|
||||||
|
CreatorUserRole,
|
||||||
|
MessageChainType,
|
||||||
|
MessageStatus,
|
||||||
|
)
|
||||||
from .provider_ids import GenericProviderID
|
from .provider_ids import GenericProviderID
|
||||||
from .types import EnumText, LongText, StringUUID
|
from .types import EnumText, LongText, StringUUID
|
||||||
|
|
||||||
@@ -925,8 +933,11 @@ class ExporleBanner(TypeBase):
|
|||||||
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
|
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
|
||||||
link: Mapped[str] = mapped_column(String(255), nullable=False)
|
link: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||||
status: Mapped[str] = mapped_column(
|
status: Mapped[BannerStatus] = mapped_column(
|
||||||
sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
|
EnumText(BannerStatus, length=255),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("'enabled'::character varying"),
|
||||||
|
default=BannerStatus.ENABLED,
|
||||||
)
|
)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||||
@@ -2206,7 +2217,7 @@ class MessageChain(TypeBase):
|
|||||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||||
)
|
)
|
||||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
type: Mapped[MessageChainType] = mapped_column(EnumText(MessageChainType, length=255), nullable=False)
|
||||||
input: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
input: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||||
output: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
output: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
|||||||
@@ -210,7 +210,7 @@ class TenantPreferredModelProvider(TypeBase):
|
|||||||
)
|
)
|
||||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
preferred_provider_type: Mapped[ProviderType] = mapped_column(EnumText(ProviderType, length=40), nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session
|
|||||||
from enums.cloud_plan import CloudPlan
|
from enums.cloud_plan import CloudPlan
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||||
from models.enums import DataSourceType
|
from models.enums import DataSourceType, MessageChainType
|
||||||
from models.model import (
|
from models.model import (
|
||||||
App,
|
App,
|
||||||
AppAnnotationHitHistory,
|
AppAnnotationHitHistory,
|
||||||
@@ -236,7 +236,7 @@ class TestMessagesCleanServiceIntegration:
|
|||||||
# MessageChain
|
# MessageChain
|
||||||
chain = MessageChain(
|
chain = MessageChain(
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
type="system",
|
type=MessageChainType.SYSTEM,
|
||||||
input=json.dumps({"test": "input"}),
|
input=json.dumps({"test": "input"}),
|
||||||
output=json.dumps({"test": "output"}),
|
output=json.dumps({"test": "output"}),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from datetime import datetime
|
|||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import controllers.console.explore.banner as banner_module
|
import controllers.console.explore.banner as banner_module
|
||||||
|
from models.enums import BannerStatus
|
||||||
|
|
||||||
|
|
||||||
def unwrap(func):
|
def unwrap(func):
|
||||||
@@ -20,7 +21,7 @@ class TestBannerApi:
|
|||||||
banner.content = {"text": "hello"}
|
banner.content = {"text": "hello"}
|
||||||
banner.link = "https://example.com"
|
banner.link = "https://example.com"
|
||||||
banner.sort = 1
|
banner.sort = 1
|
||||||
banner.status = "enabled"
|
banner.status = BannerStatus.ENABLED
|
||||||
banner.created_at = datetime(2024, 1, 1)
|
banner.created_at = datetime(2024, 1, 1)
|
||||||
|
|
||||||
query = MagicMock()
|
query = MagicMock()
|
||||||
@@ -54,7 +55,7 @@ class TestBannerApi:
|
|||||||
banner.content = {"text": "fallback"}
|
banner.content = {"text": "fallback"}
|
||||||
banner.link = None
|
banner.link = None
|
||||||
banner.sort = 1
|
banner.sort = 1
|
||||||
banner.status = "enabled"
|
banner.status = BannerStatus.ENABLED
|
||||||
banner.created_at = None
|
banner.created_at = None
|
||||||
|
|
||||||
query = MagicMock()
|
query = MagicMock()
|
||||||
|
|||||||
@@ -410,7 +410,7 @@ def test_switch_preferred_provider_type_updates_existing_record_with_session() -
|
|||||||
|
|
||||||
configuration.switch_preferred_provider_type(ProviderType.SYSTEM, session=session)
|
configuration.switch_preferred_provider_type(ProviderType.SYSTEM, session=session)
|
||||||
|
|
||||||
assert existing_record.preferred_provider_type == ProviderType.SYSTEM.value
|
assert existing_record.preferred_provider_type == ProviderType.SYSTEM
|
||||||
session.commit.assert_called_once()
|
session.commit.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user