mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 10:12:43 +08:00
refactor: use EnumText for Provider.quota_type and consolidate ProviderQuotaType (#34299)
This commit is contained in:
@@ -81,7 +81,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
|
||||
@@ -626,9 +626,8 @@ class ProviderManager:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM:
|
||||
continue
|
||||
|
||||
provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
|
||||
provider_record
|
||||
)
|
||||
if provider_record.quota_type is not None:
|
||||
provider_quota_to_provider_record_dict[provider_record.quota_type] = provider_record
|
||||
|
||||
for quota in configuration.quotas:
|
||||
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
|
||||
@@ -641,7 +640,7 @@ class ProviderManager:
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
provider_name=ModelProviderID(provider_name).provider_name,
|
||||
provider_type=ProviderType.SYSTEM,
|
||||
quota_type=quota.quota_type,
|
||||
quota_type=quota.quota_type, # type: ignore[arg-type]
|
||||
quota_limit=0, # type: ignore
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
@@ -921,9 +920,8 @@ class ProviderManager:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM:
|
||||
continue
|
||||
|
||||
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
|
||||
provider_record
|
||||
)
|
||||
if provider_record.quota_type is not None:
|
||||
quota_type_to_provider_records_dict[provider_record.quota_type] = provider_record # type: ignore[index]
|
||||
quota_configurations = []
|
||||
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
|
||||
@@ -157,7 +157,7 @@ def handle(sender: Message, **kwargs):
|
||||
tenant_id=tenant_id,
|
||||
provider_name=ModelProviderID(model_config.provider).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=provider_configuration.system_configuration.current_quota_type.value,
|
||||
quota_type=provider_configuration.system_configuration.current_quota_type,
|
||||
),
|
||||
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
|
||||
additional_filters=_ProviderUpdateAdditionalFilters(
|
||||
|
||||
@@ -13,7 +13,7 @@ from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .enums import CredentialSourceType, PaymentStatus
|
||||
from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
|
||||
@@ -29,24 +29,6 @@ class ProviderType(StrEnum):
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class ProviderQuotaType(StrEnum):
|
||||
PAID = auto()
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = auto()
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = auto()
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
def value_of(value: str) -> ProviderQuotaType:
|
||||
for member in ProviderQuotaType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class Provider(TypeBase):
|
||||
"""
|
||||
Provider model representing the API providers and their configurations.
|
||||
@@ -77,7 +59,9 @@ class Provider(TypeBase):
|
||||
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
|
||||
quota_type: Mapped[str | None] = mapped_column(String(40), nullable=True, server_default=text("''"), default="")
|
||||
quota_type: Mapped[ProviderQuotaType | None] = mapped_column(
|
||||
EnumText(ProviderQuotaType, length=40), nullable=True, server_default=text("''"), default=None
|
||||
)
|
||||
quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None)
|
||||
quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=0)
|
||||
|
||||
|
||||
@@ -144,8 +144,8 @@ class EnumText(TypeDecorator[_E | None], Generic[_E]):
|
||||
return dialect.type_descriptor(VARCHAR(self._length))
|
||||
|
||||
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
|
||||
if value is None:
|
||||
return value
|
||||
if value is None or value == "":
|
||||
return None
|
||||
# Type annotation guarantees value is str at this point
|
||||
return self._enum_class(value)
|
||||
|
||||
|
||||
@@ -202,7 +202,7 @@ class TestProviderModel:
|
||||
# Assert
|
||||
assert provider.provider_type == ProviderType.CUSTOM
|
||||
assert provider.is_valid is False
|
||||
assert provider.quota_type == ""
|
||||
assert provider.quota_type is None
|
||||
assert provider.quota_limit is None
|
||||
assert provider.quota_used == 0
|
||||
assert provider.credential_id is None
|
||||
|
||||
Reference in New Issue
Block a user