mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 16:39:26 +08:00
refactor: use EnumText for TidbAuthBinding.status and MessageFile.type (#33975)
This commit is contained in:
@@ -33,6 +33,7 @@ from core.rag.models.document import Document
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.dataset import Dataset, TidbAuthBinding
|
from models.dataset import Dataset, TidbAuthBinding
|
||||||
|
from models.enums import TidbAuthBindingStatus
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from qdrant_client import grpc # noqa
|
from qdrant_client import grpc # noqa
|
||||||
@@ -452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|||||||
password=new_cluster["password"],
|
password=new_cluster["password"],
|
||||||
tenant_id=dataset.tenant_id,
|
tenant_id=dataset.tenant_id,
|
||||||
active=True,
|
active=True,
|
||||||
status="ACTIVE",
|
status=TidbAuthBindingStatus.ACTIVE,
|
||||||
)
|
)
|
||||||
db.session.add(new_tidb_auth_binding)
|
db.session.add(new_tidb_auth_binding)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from configs import dify_config
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.dataset import TidbAuthBinding
|
from models.dataset import TidbAuthBinding
|
||||||
|
from models.enums import TidbAuthBindingStatus
|
||||||
|
|
||||||
|
|
||||||
class TidbService:
|
class TidbService:
|
||||||
@@ -170,7 +171,7 @@ class TidbService:
|
|||||||
userPrefix = item["userPrefix"]
|
userPrefix = item["userPrefix"]
|
||||||
if state == "ACTIVE" and len(userPrefix) > 0:
|
if state == "ACTIVE" and len(userPrefix) > 0:
|
||||||
cluster_info = tidb_serverless_list_map[item["clusterId"]]
|
cluster_info = tidb_serverless_list_map[item["clusterId"]]
|
||||||
cluster_info.status = "ACTIVE"
|
cluster_info.status = TidbAuthBindingStatus.ACTIVE
|
||||||
cluster_info.account = f"{userPrefix}.root"
|
cluster_info.account = f"{userPrefix}.root"
|
||||||
db.session.add(cluster_info)
|
db.session.add(cluster_info)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from .enums import (
|
|||||||
SegmentStatus,
|
SegmentStatus,
|
||||||
SegmentType,
|
SegmentType,
|
||||||
SummaryStatus,
|
SummaryStatus,
|
||||||
|
TidbAuthBindingStatus,
|
||||||
)
|
)
|
||||||
from .model import App, Tag, TagBinding, UploadFile
|
from .model import App, Tag, TagBinding, UploadFile
|
||||||
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
|
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
|
||||||
@@ -1242,7 +1243,9 @@ class TidbAuthBinding(TypeBase):
|
|||||||
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||||
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
|
status: Mapped[TidbAuthBindingStatus] = mapped_column(
|
||||||
|
EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'")
|
||||||
|
)
|
||||||
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from configs import dify_config
|
|||||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||||
from core.tools.signature import sign_tool_file
|
from core.tools.signature import sign_tool_file
|
||||||
from dify_graph.enums import WorkflowExecutionStatus
|
from dify_graph.enums import WorkflowExecutionStatus
|
||||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||||
from dify_graph.file import helpers as file_helpers
|
from dify_graph.file import helpers as file_helpers
|
||||||
from extensions.storage.storage_type import StorageType
|
from extensions.storage.storage_type import StorageType
|
||||||
from libs.helper import generate_string # type: ignore[import-not-found]
|
from libs.helper import generate_string # type: ignore[import-not-found]
|
||||||
@@ -1785,7 +1785,7 @@ class MessageFile(TypeBase):
|
|||||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||||
)
|
)
|
||||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False)
|
||||||
transfer_method: Mapped[FileTransferMethod] = mapped_column(
|
transfer_method: Mapped[FileTransferMethod] = mapped_column(
|
||||||
EnumText(FileTransferMethod, length=255), nullable=False
|
EnumText(FileTransferMethod, length=255), nullable=False
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from configs import dify_config
|
|||||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import TidbAuthBinding
|
from models.dataset import TidbAuthBinding
|
||||||
|
from models.enums import TidbAuthBindingStatus
|
||||||
|
|
||||||
|
|
||||||
@app.celery.task(queue="dataset")
|
@app.celery.task(queue="dataset")
|
||||||
@@ -57,7 +58,7 @@ def create_clusters(batch_size):
|
|||||||
account=new_cluster["account"],
|
account=new_cluster["account"],
|
||||||
password=new_cluster["password"],
|
password=new_cluster["password"],
|
||||||
active=False,
|
active=False,
|
||||||
status="CREATING",
|
status=TidbAuthBindingStatus.CREATING,
|
||||||
)
|
)
|
||||||
db.session.add(tidb_auth_binding)
|
db.session.add(tidb_auth_binding)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from configs import dify_config
|
|||||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import TidbAuthBinding
|
from models.dataset import TidbAuthBinding
|
||||||
|
from models.enums import TidbAuthBindingStatus
|
||||||
|
|
||||||
|
|
||||||
@app.celery.task(queue="dataset")
|
@app.celery.task(queue="dataset")
|
||||||
@@ -18,7 +19,10 @@ def update_tidb_serverless_status_task():
|
|||||||
try:
|
try:
|
||||||
# check the number of idle tidb serverless
|
# check the number of idle tidb serverless
|
||||||
tidb_serverless_list = db.session.scalars(
|
tidb_serverless_list = db.session.scalars(
|
||||||
select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
|
select(TidbAuthBinding).where(
|
||||||
|
TidbAuthBinding.active == False,
|
||||||
|
TidbAuthBinding.status == TidbAuthBindingStatus.CREATING,
|
||||||
|
)
|
||||||
).all()
|
).all()
|
||||||
if len(tidb_serverless_list) == 0:
|
if len(tidb_serverless_list) == 0:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import pytest
|
|||||||
from faker import Faker
|
from faker import Faker
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from dify_graph.file.enums import FileType
|
||||||
from enums.cloud_plan import CloudPlan
|
from enums.cloud_plan import CloudPlan
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||||
@@ -253,7 +254,7 @@ class TestMessagesCleanServiceIntegration:
|
|||||||
# MessageFile
|
# MessageFile
|
||||||
file = MessageFile(
|
file = MessageFile(
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
type="image",
|
type=FileType.IMAGE,
|
||||||
transfer_method="local_file",
|
transfer_method="local_file",
|
||||||
url="http://example.com/test.jpg",
|
url="http://example.com/test.jpg",
|
||||||
belongs_to=MessageFileBelongsTo.USER,
|
belongs_to=MessageFileBelongsTo.USER,
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from core.app.entities.task_entities import MessageEndStreamResponse
|
from core.app.entities.task_entities import MessageEndStreamResponse
|
||||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||||
from dify_graph.file.enums import FileTransferMethod
|
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||||
from models.model import MessageFile, UploadFile
|
from models.model import MessageFile, UploadFile
|
||||||
|
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ class TestMessageEndStreamResponseFiles:
|
|||||||
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
|
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
|
||||||
message_file.upload_file_id = str(uuid.uuid4())
|
message_file.upload_file_id = str(uuid.uuid4())
|
||||||
message_file.url = None
|
message_file.url = None
|
||||||
message_file.type = "image"
|
message_file.type = FileType.IMAGE
|
||||||
return message_file
|
return message_file
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -63,7 +63,7 @@ class TestMessageEndStreamResponseFiles:
|
|||||||
message_file.transfer_method = FileTransferMethod.REMOTE_URL
|
message_file.transfer_method = FileTransferMethod.REMOTE_URL
|
||||||
message_file.upload_file_id = None
|
message_file.upload_file_id = None
|
||||||
message_file.url = "https://example.com/image.jpg"
|
message_file.url = "https://example.com/image.jpg"
|
||||||
message_file.type = "image"
|
message_file.type = FileType.IMAGE
|
||||||
return message_file
|
return message_file
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -75,7 +75,7 @@ class TestMessageEndStreamResponseFiles:
|
|||||||
message_file.transfer_method = FileTransferMethod.TOOL_FILE
|
message_file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||||
message_file.upload_file_id = None
|
message_file.upload_file_id = None
|
||||||
message_file.url = "tool_file_123.png"
|
message_file.url = "tool_file_123.png"
|
||||||
message_file.type = "image"
|
message_file.type = FileType.IMAGE
|
||||||
return message_file
|
return message_file
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
Reference in New Issue
Block a user