refactor: use EnumText for Provider.quota_type and consolidate ProviderQuotaType (#34299)

This commit is contained in:
tmimmanuel
2026-03-31 02:29:57 +02:00
committed by GitHub
parent 15aa8071f8
commit 5897b28355
6 changed files with 14 additions and 32 deletions

View File

@@ -81,7 +81,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_limit > Provider.quota_used,
)
.values(

View File

@@ -626,9 +626,8 @@ class ProviderManager:
if provider_record.provider_type != ProviderType.SYSTEM:
continue
provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
provider_record
)
if provider_record.quota_type is not None:
provider_quota_to_provider_record_dict[provider_record.quota_type] = provider_record
for quota in configuration.quotas:
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
@@ -641,7 +640,7 @@ class ProviderManager:
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM,
quota_type=quota.quota_type,
quota_type=quota.quota_type, # type: ignore[arg-type]
quota_limit=0, # type: ignore
quota_used=0,
is_valid=True,
@@ -921,9 +920,8 @@ class ProviderManager:
if provider_record.provider_type != ProviderType.SYSTEM:
continue
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
provider_record
)
if provider_record.quota_type is not None:
quota_type_to_provider_records_dict[provider_record.quota_type] = provider_record # type: ignore[index]
quota_configurations = []
if dify_config.EDITION == "CLOUD":

View File

@@ -157,7 +157,7 @@ def handle(sender: Message, **kwargs):
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
quota_type=provider_configuration.system_configuration.current_quota_type,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(

View File

@@ -13,7 +13,7 @@ from libs.uuid_utils import uuidv7
from .base import TypeBase
from .engine import db
from .enums import CredentialSourceType, PaymentStatus
from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType
from .types import EnumText, LongText, StringUUID
@@ -29,24 +29,6 @@ class ProviderType(StrEnum):
raise ValueError(f"No matching enum found for value '{value}'")
class ProviderQuotaType(StrEnum):
PAID = auto()
"""hosted paid quota"""
FREE = auto()
"""third-party free quota"""
TRIAL = auto()
"""hosted trial quota"""
@staticmethod
def value_of(value: str) -> ProviderQuotaType:
for member in ProviderQuotaType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class Provider(TypeBase):
"""
Provider model representing the API providers and their configurations.
@@ -77,7 +59,9 @@ class Provider(TypeBase):
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
quota_type: Mapped[str | None] = mapped_column(String(40), nullable=True, server_default=text("''"), default="")
quota_type: Mapped[ProviderQuotaType | None] = mapped_column(
EnumText(ProviderQuotaType, length=40), nullable=True, server_default=text("''"), default=None
)
quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None)
quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=0)

View File

@@ -144,8 +144,8 @@ class EnumText(TypeDecorator[_E | None], Generic[_E]):
return dialect.type_descriptor(VARCHAR(self._length))
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
if value is None:
return value
if value is None or value == "":
return None
# Type annotation guarantees value is str at this point
return self._enum_class(value)

View File

@@ -202,7 +202,7 @@ class TestProviderModel:
# Assert
assert provider.provider_type == ProviderType.CUSTOM
assert provider.is_valid is False
assert provider.quota_type == ""
assert provider.quota_type is None
assert provider.quota_limit is None
assert provider.quota_used == 0
assert provider.credential_id is None