refactor(api): move llm quota deduction to app graph layer (#32786)

Authored by -LAN- on 2026-03-01 23:25:36 +08:00; committed by GitHub.
parent fa4b8910c8
commit ef2b5d6107
16 changed files with 434 additions and 95 deletions

View File

@@ -29,6 +29,8 @@ ignore_imports =
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
+ core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
+ core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
@@ -107,14 +109,12 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
+ core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
- core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.model_manager
core.workflow.nodes.llm.protocols -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> models.model
- core.workflow.nodes.llm.llm_utils -> models.provider
- core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
core.workflow.nodes.llm.node -> core.tools.signature
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
@@ -135,8 +135,8 @@ ignore_imports =
core.workflow.nodes.start.start_node -> core.app.app_config.entities
core.workflow.workflow_entry -> core.app.apps.exc
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
+ core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota
core.workflow.workflow_entry -> core.app.workflow.node_factory
- core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
@@ -180,7 +180,7 @@ ignore_imports =
core.workflow.workflow_entry -> extensions.otel.runtime
core.workflow.nodes.agent.agent_node -> models
core.workflow.nodes.base.node -> models.enums
- core.workflow.nodes.llm.llm_utils -> models.provider_ids
+ core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.llm.node -> models.model
core.workflow.workflow_entry -> models.enums
core.workflow.nodes.agent.agent_node -> services

View File

@@ -1 +1,5 @@
"""LLM-related application services."""
+ from .quota import deduct_llm_quota, ensure_llm_quota_available
+ __all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]
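The package root now re-exports both helpers, so callers elsewhere in this commit import them from one place instead of reaching into core.workflow.nodes.llm.llm_utils:

    from core.app.llm import deduct_llm_quota, ensure_llm_quota_available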

api/core/app/llm/quota.py (new file, 93 additions)
View File

@@ -0,0 +1,93 @@
from sqlalchemy import update
from sqlalchemy.orm import Session

from configs import dify_config
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID


def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
    provider_model_bundle = model_instance.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration

    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
        return

    provider_model = provider_configuration.get_provider_model(
        model_type=model_instance.model_type_instance.model_type,
        model=model_instance.model_name,
    )
    if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
        raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")


def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
    provider_model_bundle = model_instance.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration

    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
        return

    system_configuration = provider_configuration.system_configuration

    quota_unit = None
    for quota_configuration in system_configuration.quota_configurations:
        if quota_configuration.quota_type == system_configuration.current_quota_type:
            quota_unit = quota_configuration.quota_unit
            if quota_configuration.quota_limit == -1:
                return
            break

    used_quota = None
    if quota_unit:
        if quota_unit == QuotaUnit.TOKENS:
            used_quota = usage.total_tokens
        elif quota_unit == QuotaUnit.CREDITS:
            used_quota = dify_config.get_model_credits(model_instance.model_name)
        else:
            used_quota = 1

    if used_quota is not None and system_configuration.current_quota_type is not None:
        if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
            from services.credit_pool_service import CreditPoolService

            CreditPoolService.check_and_deduct_credits(
                tenant_id=tenant_id,
                credits_required=used_quota,
            )
        elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
            from services.credit_pool_service import CreditPoolService

            CreditPoolService.check_and_deduct_credits(
                tenant_id=tenant_id,
                credits_required=used_quota,
                pool_type="paid",
            )
        else:
            with Session(db.engine) as session:
                stmt = (
                    update(Provider)
                    .where(
                        Provider.tenant_id == tenant_id,
                        # TODO: Use provider name with prefix after the data migration.
                        Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
                        Provider.provider_type == ProviderType.SYSTEM.value,
                        Provider.quota_type == system_configuration.current_quota_type.value,
                        Provider.quota_limit > Provider.quota_used,
                    )
                    .values(
                        quota_used=Provider.quota_used + used_quota,
                        last_used=naive_utc_now(),
                    )
                )
                session.execute(stmt)
                session.commit()
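For orientation, a minimal sketch of how a caller outside the graph engine is expected to use these helpers. Only the two quota functions come from this commit; invoke_with_quota and the invoke_llm call are illustrative placeholders:

    from core.app.llm import deduct_llm_quota, ensure_llm_quota_available

    def invoke_with_quota(tenant_id, model_instance, prompt_messages):
        # Fail fast if the system provider's quota is already exhausted
        # (both helpers return early for non-system, bring-your-own-key providers).
        ensure_llm_quota_available(model_instance=model_instance)
        result = model_instance.invoke_llm(prompt_messages=prompt_messages, stream=False)  # illustrative call
        # Charge the tenant for the reported usage only after a successful call.
        deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=result.usage)
        return result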

View File

@@ -1,9 +1,11 @@
"""Workflow-level GraphEngine layers that depend on outer infrastructure.""" """Workflow-level GraphEngine layers that depend on outer infrastructure."""
from .llm_quota import LLMQuotaLayer
from .observability import ObservabilityLayer from .observability import ObservabilityLayer
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
__all__ = [ __all__ = [
"LLMQuotaLayer",
"ObservabilityLayer", "ObservabilityLayer",
"PersistenceWorkflowInfo", "PersistenceWorkflowInfo",
"WorkflowPersistenceLayer", "WorkflowPersistenceLayer",

View File

@@ -0,0 +1,128 @@
"""
LLM quota deduction layer for GraphEngine.
This layer centralizes model-quota deduction outside node implementations.
"""
import logging
from typing import TYPE_CHECKING, cast, final
from typing_extensions import override
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.workflow.enums import NodeType
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.nodes.base.node import Node
if TYPE_CHECKING:
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
logger = logging.getLogger(__name__)
@final
class LLMQuotaLayer(GraphEngineLayer):
"""Graph layer that applies LLM quota deduction after node execution."""
def __init__(self) -> None:
super().__init__()
self._abort_sent = False
@override
def on_graph_start(self) -> None:
self._abort_sent = False
@override
def on_event(self, event: GraphEngineEvent) -> None:
_ = event
@override
def on_graph_end(self, error: Exception | None) -> None:
_ = error
@override
def on_node_run_start(self, node: Node) -> None:
if self._abort_sent:
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
return
try:
ensure_llm_quota_available(model_instance=model_instance)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
@override
def on_node_run_end(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
return
try:
deduct_llm_quota(
tenant_id=node.tenant_id,
model_instance=model_instance,
usage=result_event.node_run_result.llm_usage,
)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
logger.warning("LLM quota deduction exceeded, node_id=%s, error=%s", node.id, exc)
except Exception:
logger.exception("LLM quota deduction failed, node_id=%s", node.id)
@staticmethod
def _set_stop_event(node: Node) -> None:
stop_event = getattr(node.graph_runtime_state, "stop_event", None)
if stop_event is not None:
stop_event.set()
def _send_abort_command(self, *, reason: str) -> None:
if not self.command_channel or self._abort_sent:
return
try:
self.command_channel.send_command(
AbortCommand(
command_type=CommandType.ABORT,
reason=reason,
)
)
self._abort_sent = True
except Exception:
logger.exception("Failed to send quota abort command")
@staticmethod
def _extract_model_instance(node: Node) -> ModelInstance | None:
try:
match node.node_type:
case NodeType.LLM:
return cast("LLMNode", node).model_instance
case NodeType.PARAMETER_EXTRACTOR:
return cast("ParameterExtractorNode", node).model_instance
case NodeType.QUESTION_CLASSIFIER:
return cast("QuestionClassifierNode", node).model_instance
case _:
return None
except AttributeError:
logger.warning(
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
node.id,
)
return None
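Attaching the layer is a single call on an existing engine, mirroring what workflow_entry, IterationNode and LoopNode do later in this diff; the engine construction itself is elided in this sketch:

    from core.app.workflow.layers.llm_quota import LLMQuotaLayer

    # graph_engine: an already-built GraphEngine instance.
    # The layer then pre-checks quota in on_node_run_start and deducts it in on_node_run_end.
    graph_engine.layer(LLMQuotaLayer())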

View File

@@ -2,6 +2,7 @@ import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
+ from core.app.llm import deduct_llm_quota
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import (
@@ -29,7 +30,6 @@ from core.plugin.entities.request import (
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
- from core.workflow.nodes.llm import llm_utils
from models.account import Tenant
@@ -63,16 +63,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunk, None, None]:
for chunk in response:
if chunk.delta.usage:
- llm_utils.deduct_llm_quota(
-     tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
- )
+ deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
- llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+ deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
@@ -126,16 +124,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
for chunk in response:
if chunk.delta.usage:
- llm_utils.deduct_llm_quota(
-     tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
- )
+ deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
- llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+ deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(
response: LLMResultWithStructuredOutput,

View File

@@ -8,6 +8,7 @@ from typing import Any, cast
logger = logging.getLogger(__name__)
+ from core.app.llm import deduct_llm_quota
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
@@ -35,7 +36,6 @@ from core.rag.models.document import AttachmentDocument, Document, MultimodalGen
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
- from core.workflow.nodes.llm import llm_utils
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs import helper
@@ -474,7 +474,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# Deduct quota for summary generation (same as workflow nodes)
try:
- llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
+ deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
except Exception as e:
# Log but don't fail summary generation if quota deduction fails
logger.warning("Failed to deduct quota for summary generation: %s", str(e))

View File

@@ -2,6 +2,7 @@ from collections.abc import Generator, Sequence
from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+ from core.app.llm import deduct_llm_quota
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
@@ -9,7 +10,6 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
- from core.workflow.nodes.llm import llm_utils
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@@ -162,7 +162,7 @@ class ReactMultiDatasetRouter:
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
# deduct quota
- llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
+ deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
return text, usage

View File

@@ -588,6 +588,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
+ from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@@ -642,5 +643,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
command_channel=InMemoryChannel(),  # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
+ graph_engine.layer(LLMQuotaLayer())
return graph_engine

View File

@@ -1,14 +1,11 @@
from collections.abc import Sequence
from typing import cast
- from sqlalchemy import select, update
+ from sqlalchemy import select
from sqlalchemy.orm import Session
- from configs import dify_config
- from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
- from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
@@ -17,10 +14,7 @@ from core.workflow.file.models import File
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from extensions.ext_database import db
- from libs.datetime_utils import naive_utc_now
from models.model import Conversation
- from models.provider import Provider, ProviderType
- from models.provider_ids import ModelProviderID
from .exc import InvalidVariableTypeError
@@ -68,68 +62,3 @@ def fetch_memory(
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
- def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
-     provider_model_bundle = model_instance.provider_model_bundle
-     provider_configuration = provider_model_bundle.configuration
-     if provider_configuration.using_provider_type != ProviderType.SYSTEM:
-         return
-     system_configuration = provider_configuration.system_configuration
-     quota_unit = None
-     for quota_configuration in system_configuration.quota_configurations:
-         if quota_configuration.quota_type == system_configuration.current_quota_type:
-             quota_unit = quota_configuration.quota_unit
-             if quota_configuration.quota_limit == -1:
-                 return
-             break
-     used_quota = None
-     if quota_unit:
-         if quota_unit == QuotaUnit.TOKENS:
-             used_quota = usage.total_tokens
-         elif quota_unit == QuotaUnit.CREDITS:
-             used_quota = dify_config.get_model_credits(model_instance.model_name)
-         else:
-             used_quota = 1
-     if used_quota is not None and system_configuration.current_quota_type is not None:
-         if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
-             from services.credit_pool_service import CreditPoolService
-             CreditPoolService.check_and_deduct_credits(
-                 tenant_id=tenant_id,
-                 credits_required=used_quota,
-             )
-         elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
-             from services.credit_pool_service import CreditPoolService
-             CreditPoolService.check_and_deduct_credits(
-                 tenant_id=tenant_id,
-                 credits_required=used_quota,
-                 pool_type="paid",
-             )
-         else:
-             with Session(db.engine) as session:
-                 stmt = (
-                     update(Provider)
-                     .where(
-                         Provider.tenant_id == tenant_id,
-                         # TODO: Use provider name with prefix after the data migration.
-                         Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
-                         Provider.provider_type == ProviderType.SYSTEM.value,
-                         Provider.quota_type == system_configuration.current_quota_type.value,
-                         Provider.quota_limit > Provider.quota_used,
-                     )
-                     .values(
-                         quota_used=Provider.quota_used + used_quota,
-                         last_used=naive_utc_now(),
-                     )
-                 )
-                 session.execute(stmt)
-                 session.commit()

View File

@@ -278,8 +278,6 @@ class LLMNode(Node[LLMNodeData]):
else None
)
- # deduct quota
- llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
elif isinstance(event, LLMStructuredOutput):
structured_output = event
@@ -1234,6 +1232,10 @@
def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled
+ @property
+ def model_instance(self) -> ModelInstance:
+     return self._model_instance
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole

View File

@@ -413,6 +413,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
+ from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@@ -454,5 +455,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
command_channel=InMemoryChannel(),  # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
+ graph_engine.layer(LLMQuotaLayer())
return graph_engine

View File

@@ -308,9 +308,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
usage = invoke_result.usage
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
- # deduct quota
- llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
return text, usage, tool_call
def _generate_function_call_prompt(
@@ -828,6 +825,10 @@
return rest_tokens
+ @property
+ def model_instance(self) -> ModelInstance:
+     return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@@ -240,6 +240,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
llm_usage=usage,
)
+ @property
+ def model_instance(self) -> ModelInstance:
+     return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@@ -6,6 +6,7 @@ from typing import Any, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
+ from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
@@ -106,6 +107,7 @@ class WorkflowEntry:
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
self.graph_engine.layer(limits_layer)
+ self.graph_engine.layer(LLMQuotaLayer())
# Add observability layer when OTel is enabled
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():

View File

@@ -0,0 +1,174 @@
import threading
from datetime import datetime
from unittest.mock import MagicMock, patch

from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.errors.error import QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.commands import CommandType
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult


def _build_succeeded_event() -> NodeRunSucceededEvent:
    return NodeRunSucceededEvent(
        id="execution-id",
        node_id="llm-node-id",
        node_type=NodeType.LLM,
        start_at=datetime.now(),
        node_run_result=NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs={"question": "hello"},
            llm_usage=LLMUsage.empty_usage(),
        ),
    )


def test_deduct_quota_called_for_successful_llm_node() -> None:
    layer = LLMQuotaLayer()
    node = MagicMock()
    node.id = "llm-node-id"
    node.execution_id = "execution-id"
    node.node_type = NodeType.LLM
    node.tenant_id = "tenant-id"
    node.model_instance = object()
    result_event = _build_succeeded_event()

    with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
        layer.on_node_run_end(node=node, error=None, result_event=result_event)

    mock_deduct.assert_called_once_with(
        tenant_id="tenant-id",
        model_instance=node.model_instance,
        usage=result_event.node_run_result.llm_usage,
    )


def test_deduct_quota_called_for_question_classifier_node() -> None:
    layer = LLMQuotaLayer()
    node = MagicMock()
    node.id = "question-classifier-node-id"
    node.execution_id = "execution-id"
    node.node_type = NodeType.QUESTION_CLASSIFIER
    node.tenant_id = "tenant-id"
    node.model_instance = object()
    result_event = _build_succeeded_event()

    with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
        layer.on_node_run_end(node=node, error=None, result_event=result_event)

    mock_deduct.assert_called_once_with(
        tenant_id="tenant-id",
        model_instance=node.model_instance,
        usage=result_event.node_run_result.llm_usage,
    )


def test_non_llm_node_is_ignored() -> None:
    layer = LLMQuotaLayer()
    node = MagicMock()
    node.id = "start-node-id"
    node.execution_id = "execution-id"
    node.node_type = NodeType.START
    node.tenant_id = "tenant-id"
    node._model_instance = object()
    result_event = _build_succeeded_event()

    with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
        layer.on_node_run_end(node=node, error=None, result_event=result_event)

    mock_deduct.assert_not_called()


def test_quota_error_is_handled_in_layer() -> None:
    layer = LLMQuotaLayer()
    node = MagicMock()
    node.id = "llm-node-id"
    node.execution_id = "execution-id"
    node.node_type = NodeType.LLM
    node.tenant_id = "tenant-id"
    node.model_instance = object()
    result_event = _build_succeeded_event()

    with patch(
        "core.app.workflow.layers.llm_quota.deduct_llm_quota",
        autospec=True,
        side_effect=ValueError("quota exceeded"),
    ):
        layer.on_node_run_end(node=node, error=None, result_event=result_event)


def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
    layer = LLMQuotaLayer()
    stop_event = threading.Event()
    layer.command_channel = MagicMock()
    node = MagicMock()
    node.id = "llm-node-id"
    node.execution_id = "execution-id"
    node.node_type = NodeType.LLM
    node.tenant_id = "tenant-id"
    node.model_instance = object()
    node.graph_runtime_state = MagicMock()
    node.graph_runtime_state.stop_event = stop_event
    result_event = _build_succeeded_event()

    with patch(
        "core.app.workflow.layers.llm_quota.deduct_llm_quota",
        autospec=True,
        side_effect=QuotaExceededError("No credits remaining"),
    ):
        layer.on_node_run_end(node=node, error=None, result_event=result_event)

    assert stop_event.is_set()
    layer.command_channel.send_command.assert_called_once()
    abort_command = layer.command_channel.send_command.call_args.args[0]
    assert abort_command.command_type == CommandType.ABORT
    assert abort_command.reason == "No credits remaining"


def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
    layer = LLMQuotaLayer()
    stop_event = threading.Event()
    layer.command_channel = MagicMock()
    node = MagicMock()
    node.id = "llm-node-id"
    node.node_type = NodeType.LLM
    node.model_instance = object()
    node.graph_runtime_state = MagicMock()
    node.graph_runtime_state.stop_event = stop_event

    with patch(
        "core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
        autospec=True,
        side_effect=QuotaExceededError("Model provider openai quota exceeded."),
    ):
        layer.on_node_run_start(node)

    assert stop_event.is_set()
    layer.command_channel.send_command.assert_called_once()
    abort_command = layer.command_channel.send_command.call_args.args[0]
    assert abort_command.command_type == CommandType.ABORT
    assert abort_command.reason == "Model provider openai quota exceeded."


def test_quota_precheck_passes_without_abort() -> None:
    layer = LLMQuotaLayer()
    stop_event = threading.Event()
    layer.command_channel = MagicMock()
    node = MagicMock()
    node.id = "llm-node-id"
    node.node_type = NodeType.LLM
    node.model_instance = object()
    node.graph_runtime_state = MagicMock()
    node.graph_runtime_state.stop_event = stop_event

    with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
        layer.on_node_run_start(node)

    assert not stop_event.is_set()
    mock_check.assert_called_once_with(model_instance=node.model_instance)
    layer.command_channel.send_command.assert_not_called()