mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 19:32:16 +08:00
refactor(api): move llm quota deduction to app graph layer (#32786)
This commit is contained in:
@@ -29,6 +29,8 @@ ignore_imports =
|
|||||||
|
|
||||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||||
|
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||||
|
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
|
||||||
|
|
||||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
||||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
||||||
@@ -107,14 +109,12 @@ ignore_imports =
|
|||||||
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
|
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
|
||||||
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
||||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||||
|
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
|
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
|
||||||
core.workflow.nodes.llm.llm_utils -> configs
|
|
||||||
core.workflow.nodes.llm.llm_utils -> core.model_manager
|
core.workflow.nodes.llm.llm_utils -> core.model_manager
|
||||||
core.workflow.nodes.llm.protocols -> core.model_manager
|
core.workflow.nodes.llm.protocols -> core.model_manager
|
||||||
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
|
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
|
||||||
core.workflow.nodes.llm.llm_utils -> models.model
|
core.workflow.nodes.llm.llm_utils -> models.model
|
||||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
|
||||||
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
|
|
||||||
core.workflow.nodes.llm.node -> core.tools.signature
|
core.workflow.nodes.llm.node -> core.tools.signature
|
||||||
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
|
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
|
||||||
@@ -135,8 +135,8 @@ ignore_imports =
|
|||||||
core.workflow.nodes.start.start_node -> core.app.app_config.entities
|
core.workflow.nodes.start.start_node -> core.app.app_config.entities
|
||||||
core.workflow.workflow_entry -> core.app.apps.exc
|
core.workflow.workflow_entry -> core.app.apps.exc
|
||||||
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
|
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
|
||||||
|
core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota
|
||||||
core.workflow.workflow_entry -> core.app.workflow.node_factory
|
core.workflow.workflow_entry -> core.app.workflow.node_factory
|
||||||
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
|
|
||||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
|
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||||
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||||
@@ -180,7 +180,7 @@ ignore_imports =
|
|||||||
core.workflow.workflow_entry -> extensions.otel.runtime
|
core.workflow.workflow_entry -> extensions.otel.runtime
|
||||||
core.workflow.nodes.agent.agent_node -> models
|
core.workflow.nodes.agent.agent_node -> models
|
||||||
core.workflow.nodes.base.node -> models.enums
|
core.workflow.nodes.base.node -> models.enums
|
||||||
core.workflow.nodes.llm.llm_utils -> models.provider_ids
|
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
|
||||||
core.workflow.nodes.llm.node -> models.model
|
core.workflow.nodes.llm.node -> models.model
|
||||||
core.workflow.workflow_entry -> models.enums
|
core.workflow.workflow_entry -> models.enums
|
||||||
core.workflow.nodes.agent.agent_node -> services
|
core.workflow.nodes.agent.agent_node -> services
|
||||||
|
|||||||
@@ -1 +1,5 @@
|
|||||||
"""LLM-related application services."""
|
"""LLM-related application services."""
|
||||||
|
|
||||||
|
from .quota import deduct_llm_quota, ensure_llm_quota_available
|
||||||
|
|
||||||
|
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]
|
||||||
|
|||||||
93
api/core/app/llm/quota.py
Normal file
93
api/core/app/llm/quota.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
from sqlalchemy import update
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.entities.model_entities import ModelStatus
|
||||||
|
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||||
|
from core.errors.error import QuotaExceededError
|
||||||
|
from core.model_manager import ModelInstance
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
from models.provider import Provider, ProviderType
|
||||||
|
from models.provider_ids import ModelProviderID
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
|
||||||
|
provider_model_bundle = model_instance.provider_model_bundle
|
||||||
|
provider_configuration = provider_model_bundle.configuration
|
||||||
|
|
||||||
|
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||||
|
return
|
||||||
|
|
||||||
|
provider_model = provider_configuration.get_provider_model(
|
||||||
|
model_type=model_instance.model_type_instance.model_type,
|
||||||
|
model=model_instance.model_name,
|
||||||
|
)
|
||||||
|
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||||
|
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
|
||||||
|
|
||||||
|
|
||||||
|
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||||
|
provider_model_bundle = model_instance.provider_model_bundle
|
||||||
|
provider_configuration = provider_model_bundle.configuration
|
||||||
|
|
||||||
|
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||||
|
return
|
||||||
|
|
||||||
|
system_configuration = provider_configuration.system_configuration
|
||||||
|
|
||||||
|
quota_unit = None
|
||||||
|
for quota_configuration in system_configuration.quota_configurations:
|
||||||
|
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||||
|
quota_unit = quota_configuration.quota_unit
|
||||||
|
|
||||||
|
if quota_configuration.quota_limit == -1:
|
||||||
|
return
|
||||||
|
|
||||||
|
break
|
||||||
|
|
||||||
|
used_quota = None
|
||||||
|
if quota_unit:
|
||||||
|
if quota_unit == QuotaUnit.TOKENS:
|
||||||
|
used_quota = usage.total_tokens
|
||||||
|
elif quota_unit == QuotaUnit.CREDITS:
|
||||||
|
used_quota = dify_config.get_model_credits(model_instance.model_name)
|
||||||
|
else:
|
||||||
|
used_quota = 1
|
||||||
|
|
||||||
|
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||||
|
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||||
|
from services.credit_pool_service import CreditPoolService
|
||||||
|
|
||||||
|
CreditPoolService.check_and_deduct_credits(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
credits_required=used_quota,
|
||||||
|
)
|
||||||
|
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||||
|
from services.credit_pool_service import CreditPoolService
|
||||||
|
|
||||||
|
CreditPoolService.check_and_deduct_credits(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
credits_required=used_quota,
|
||||||
|
pool_type="paid",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
stmt = (
|
||||||
|
update(Provider)
|
||||||
|
.where(
|
||||||
|
Provider.tenant_id == tenant_id,
|
||||||
|
# TODO: Use provider name with prefix after the data migration.
|
||||||
|
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||||
|
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||||
|
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||||
|
Provider.quota_limit > Provider.quota_used,
|
||||||
|
)
|
||||||
|
.values(
|
||||||
|
quota_used=Provider.quota_used + used_quota,
|
||||||
|
last_used=naive_utc_now(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.execute(stmt)
|
||||||
|
session.commit()
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
|
"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
|
||||||
|
|
||||||
|
from .llm_quota import LLMQuotaLayer
|
||||||
from .observability import ObservabilityLayer
|
from .observability import ObservabilityLayer
|
||||||
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"LLMQuotaLayer",
|
||||||
"ObservabilityLayer",
|
"ObservabilityLayer",
|
||||||
"PersistenceWorkflowInfo",
|
"PersistenceWorkflowInfo",
|
||||||
"WorkflowPersistenceLayer",
|
"WorkflowPersistenceLayer",
|
||||||
|
|||||||
128
api/core/app/workflow/layers/llm_quota.py
Normal file
128
api/core/app/workflow/layers/llm_quota.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""
|
||||||
|
LLM quota deduction layer for GraphEngine.
|
||||||
|
|
||||||
|
This layer centralizes model-quota deduction outside node implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, cast, final
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
||||||
|
from core.errors.error import QuotaExceededError
|
||||||
|
from core.model_manager import ModelInstance
|
||||||
|
from core.workflow.enums import NodeType
|
||||||
|
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||||
|
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||||
|
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
|
||||||
|
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||||
|
from core.workflow.nodes.base.node import Node
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
|
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||||
|
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
|
class LLMQuotaLayer(GraphEngineLayer):
|
||||||
|
"""Graph layer that applies LLM quota deduction after node execution."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._abort_sent = False
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_graph_start(self) -> None:
|
||||||
|
self._abort_sent = False
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_event(self, event: GraphEngineEvent) -> None:
|
||||||
|
_ = event
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_graph_end(self, error: Exception | None) -> None:
|
||||||
|
_ = error
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_node_run_start(self, node: Node) -> None:
|
||||||
|
if self._abort_sent:
|
||||||
|
return
|
||||||
|
|
||||||
|
model_instance = self._extract_model_instance(node)
|
||||||
|
if model_instance is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
ensure_llm_quota_available(model_instance=model_instance)
|
||||||
|
except QuotaExceededError as exc:
|
||||||
|
self._set_stop_event(node)
|
||||||
|
self._send_abort_command(reason=str(exc))
|
||||||
|
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_node_run_end(
|
||||||
|
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
|
||||||
|
) -> None:
|
||||||
|
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
|
||||||
|
return
|
||||||
|
|
||||||
|
model_instance = self._extract_model_instance(node)
|
||||||
|
if model_instance is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
deduct_llm_quota(
|
||||||
|
tenant_id=node.tenant_id,
|
||||||
|
model_instance=model_instance,
|
||||||
|
usage=result_event.node_run_result.llm_usage,
|
||||||
|
)
|
||||||
|
except QuotaExceededError as exc:
|
||||||
|
self._set_stop_event(node)
|
||||||
|
self._send_abort_command(reason=str(exc))
|
||||||
|
logger.warning("LLM quota deduction exceeded, node_id=%s, error=%s", node.id, exc)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("LLM quota deduction failed, node_id=%s", node.id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _set_stop_event(node: Node) -> None:
|
||||||
|
stop_event = getattr(node.graph_runtime_state, "stop_event", None)
|
||||||
|
if stop_event is not None:
|
||||||
|
stop_event.set()
|
||||||
|
|
||||||
|
def _send_abort_command(self, *, reason: str) -> None:
|
||||||
|
if not self.command_channel or self._abort_sent:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.command_channel.send_command(
|
||||||
|
AbortCommand(
|
||||||
|
command_type=CommandType.ABORT,
|
||||||
|
reason=reason,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._abort_sent = True
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to send quota abort command")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_model_instance(node: Node) -> ModelInstance | None:
|
||||||
|
try:
|
||||||
|
match node.node_type:
|
||||||
|
case NodeType.LLM:
|
||||||
|
return cast("LLMNode", node).model_instance
|
||||||
|
case NodeType.PARAMETER_EXTRACTOR:
|
||||||
|
return cast("ParameterExtractorNode", node).model_instance
|
||||||
|
case NodeType.QUESTION_CLASSIFIER:
|
||||||
|
return cast("QuestionClassifierNode", node).model_instance
|
||||||
|
case _:
|
||||||
|
return None
|
||||||
|
except AttributeError:
|
||||||
|
logger.warning(
|
||||||
|
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
|
||||||
|
node.id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
@@ -2,6 +2,7 @@ import tempfile
|
|||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from core.app.llm import deduct_llm_quota
|
||||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
from core.model_runtime.entities.llm_entities import (
|
from core.model_runtime.entities.llm_entities import (
|
||||||
@@ -29,7 +30,6 @@ from core.plugin.entities.request import (
|
|||||||
)
|
)
|
||||||
from core.tools.entities.tool_entities import ToolProviderType
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||||
from core.workflow.nodes.llm import llm_utils
|
|
||||||
from models.account import Tenant
|
from models.account import Tenant
|
||||||
|
|
||||||
|
|
||||||
@@ -63,16 +63,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
def handle() -> Generator[LLMResultChunk, None, None]:
|
def handle() -> Generator[LLMResultChunk, None, None]:
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if chunk.delta.usage:
|
if chunk.delta.usage:
|
||||||
llm_utils.deduct_llm_quota(
|
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
|
||||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
|
||||||
)
|
|
||||||
chunk.prompt_messages = []
|
chunk.prompt_messages = []
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
return handle()
|
return handle()
|
||||||
else:
|
else:
|
||||||
if response.usage:
|
if response.usage:
|
||||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||||
|
|
||||||
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
|
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
@@ -126,16 +124,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if chunk.delta.usage:
|
if chunk.delta.usage:
|
||||||
llm_utils.deduct_llm_quota(
|
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
|
||||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
|
||||||
)
|
|
||||||
chunk.prompt_messages = []
|
chunk.prompt_messages = []
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
return handle()
|
return handle()
|
||||||
else:
|
else:
|
||||||
if response.usage:
|
if response.usage:
|
||||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||||
|
|
||||||
def handle_non_streaming(
|
def handle_non_streaming(
|
||||||
response: LLMResultWithStructuredOutput,
|
response: LLMResultWithStructuredOutput,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from typing import Any, cast
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from core.app.llm import deduct_llm_quota
|
||||||
from core.entities.knowledge_entities import PreviewDetail
|
from core.entities.knowledge_entities import PreviewDetail
|
||||||
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
@@ -35,7 +36,6 @@ from core.rag.models.document import AttachmentDocument, Document, MultimodalGen
|
|||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||||
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
|
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
|
||||||
from core.workflow.nodes.llm import llm_utils
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories.file_factory import build_from_mapping
|
from factories.file_factory import build_from_mapping
|
||||||
from libs import helper
|
from libs import helper
|
||||||
@@ -474,7 +474,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||||||
|
|
||||||
# Deduct quota for summary generation (same as workflow nodes)
|
# Deduct quota for summary generation (same as workflow nodes)
|
||||||
try:
|
try:
|
||||||
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log but don't fail summary generation if quota deduction fails
|
# Log but don't fail summary generation if quota deduction fails
|
||||||
logger.warning("Failed to deduct quota for summary generation: %s", str(e))
|
logger.warning("Failed to deduct quota for summary generation: %s", str(e))
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from collections.abc import Generator, Sequence
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
|
from core.app.llm import deduct_llm_quota
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||||
@@ -9,7 +10,6 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
|||||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||||
from core.rag.retrieval.output_parser.react_output import ReactAction
|
from core.rag.retrieval.output_parser.react_output import ReactAction
|
||||||
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
||||||
from core.workflow.nodes.llm import llm_utils
|
|
||||||
|
|
||||||
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
|
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
|
||||||
|
|
||||||
@@ -162,7 +162,7 @@ class ReactMultiDatasetRouter:
|
|||||||
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||||
|
|
||||||
# deduct quota
|
# deduct quota
|
||||||
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||||
|
|
||||||
return text, usage
|
return text, usage
|
||||||
|
|
||||||
|
|||||||
@@ -588,6 +588,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||||||
|
|
||||||
def _create_graph_engine(self, index: int, item: object):
|
def _create_graph_engine(self, index: int, item: object):
|
||||||
# Import dependencies
|
# Import dependencies
|
||||||
|
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||||
from core.app.workflow.node_factory import DifyNodeFactory
|
from core.app.workflow.node_factory import DifyNodeFactory
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
@@ -642,5 +643,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||||
config=GraphEngineConfig(),
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
graph_engine.layer(LLMQuotaLayer())
|
||||||
|
|
||||||
return graph_engine
|
return graph_engine
|
||||||
|
|||||||
@@ -1,14 +1,11 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from sqlalchemy import select, update
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||||
@@ -17,10 +14,7 @@ from core.workflow.file.models import File
|
|||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
|
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
|
||||||
from models.model import Conversation
|
from models.model import Conversation
|
||||||
from models.provider import Provider, ProviderType
|
|
||||||
from models.provider_ids import ModelProviderID
|
|
||||||
|
|
||||||
from .exc import InvalidVariableTypeError
|
from .exc import InvalidVariableTypeError
|
||||||
|
|
||||||
@@ -68,68 +62,3 @@ def fetch_memory(
|
|||||||
|
|
||||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||||
return memory
|
return memory
|
||||||
|
|
||||||
|
|
||||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
|
|
||||||
provider_model_bundle = model_instance.provider_model_bundle
|
|
||||||
provider_configuration = provider_model_bundle.configuration
|
|
||||||
|
|
||||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
|
||||||
return
|
|
||||||
|
|
||||||
system_configuration = provider_configuration.system_configuration
|
|
||||||
|
|
||||||
quota_unit = None
|
|
||||||
for quota_configuration in system_configuration.quota_configurations:
|
|
||||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
|
||||||
quota_unit = quota_configuration.quota_unit
|
|
||||||
|
|
||||||
if quota_configuration.quota_limit == -1:
|
|
||||||
return
|
|
||||||
|
|
||||||
break
|
|
||||||
|
|
||||||
used_quota = None
|
|
||||||
if quota_unit:
|
|
||||||
if quota_unit == QuotaUnit.TOKENS:
|
|
||||||
used_quota = usage.total_tokens
|
|
||||||
elif quota_unit == QuotaUnit.CREDITS:
|
|
||||||
used_quota = dify_config.get_model_credits(model_instance.model_name)
|
|
||||||
else:
|
|
||||||
used_quota = 1
|
|
||||||
|
|
||||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
|
||||||
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
|
||||||
from services.credit_pool_service import CreditPoolService
|
|
||||||
|
|
||||||
CreditPoolService.check_and_deduct_credits(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
credits_required=used_quota,
|
|
||||||
)
|
|
||||||
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
|
||||||
from services.credit_pool_service import CreditPoolService
|
|
||||||
|
|
||||||
CreditPoolService.check_and_deduct_credits(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
credits_required=used_quota,
|
|
||||||
pool_type="paid",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
with Session(db.engine) as session:
|
|
||||||
stmt = (
|
|
||||||
update(Provider)
|
|
||||||
.where(
|
|
||||||
Provider.tenant_id == tenant_id,
|
|
||||||
# TODO: Use provider name with prefix after the data migration.
|
|
||||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
|
||||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
|
||||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
|
||||||
Provider.quota_limit > Provider.quota_used,
|
|
||||||
)
|
|
||||||
.values(
|
|
||||||
quota_used=Provider.quota_used + used_quota,
|
|
||||||
last_used=naive_utc_now(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
session.execute(stmt)
|
|
||||||
session.commit()
|
|
||||||
|
|||||||
@@ -278,8 +278,6 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# deduct quota
|
|
||||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
|
||||||
break
|
break
|
||||||
elif isinstance(event, LLMStructuredOutput):
|
elif isinstance(event, LLMStructuredOutput):
|
||||||
structured_output = event
|
structured_output = event
|
||||||
@@ -1234,6 +1232,10 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
def retry(self) -> bool:
|
def retry(self) -> bool:
|
||||||
return self.node_data.retry_config.retry_enabled
|
return self.node_data.retry_config.retry_enabled
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_instance(self) -> ModelInstance:
|
||||||
|
return self._model_instance
|
||||||
|
|
||||||
|
|
||||||
def _combine_message_content_with_role(
|
def _combine_message_content_with_role(
|
||||||
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
|
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
|
||||||
|
|||||||
@@ -413,6 +413,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
|
|
||||||
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
||||||
# Import dependencies
|
# Import dependencies
|
||||||
|
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||||
from core.app.workflow.node_factory import DifyNodeFactory
|
from core.app.workflow.node_factory import DifyNodeFactory
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
@@ -454,5 +455,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||||
config=GraphEngineConfig(),
|
config=GraphEngineConfig(),
|
||||||
)
|
)
|
||||||
|
graph_engine.layer(LLMQuotaLayer())
|
||||||
|
|
||||||
return graph_engine
|
return graph_engine
|
||||||
|
|||||||
@@ -308,9 +308,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
|||||||
usage = invoke_result.usage
|
usage = invoke_result.usage
|
||||||
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
|
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
|
||||||
|
|
||||||
# deduct quota
|
|
||||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
|
||||||
|
|
||||||
return text, usage, tool_call
|
return text, usage, tool_call
|
||||||
|
|
||||||
def _generate_function_call_prompt(
|
def _generate_function_call_prompt(
|
||||||
@@ -828,6 +825,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
|||||||
|
|
||||||
return rest_tokens
|
return rest_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_instance(self) -> ModelInstance:
|
||||||
|
return self._model_instance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@@ -240,6 +240,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
|||||||
llm_usage=usage,
|
llm_usage=usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_instance(self) -> ModelInstance:
|
||||||
|
return self._model_instance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import Any, cast
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.apps.exc import GenerateTaskStoppedError
|
from core.app.apps.exc import GenerateTaskStoppedError
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||||
from core.app.workflow.node_factory import DifyNodeFactory
|
from core.app.workflow.node_factory import DifyNodeFactory
|
||||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||||
@@ -106,6 +107,7 @@ class WorkflowEntry:
|
|||||||
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
||||||
)
|
)
|
||||||
self.graph_engine.layer(limits_layer)
|
self.graph_engine.layer(limits_layer)
|
||||||
|
self.graph_engine.layer(LLMQuotaLayer())
|
||||||
|
|
||||||
# Add observability layer when OTel is enabled
|
# Add observability layer when OTel is enabled
|
||||||
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
|
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
|
||||||
|
|||||||
@@ -0,0 +1,174 @@
|
|||||||
|
import threading
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||||
|
from core.errors.error import QuotaExceededError
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
|
from core.workflow.graph_engine.entities.commands import CommandType
|
||||||
|
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||||
|
from core.workflow.node_events import NodeRunResult
|
||||||
|
|
||||||
|
|
||||||
|
def _build_succeeded_event() -> NodeRunSucceededEvent:
|
||||||
|
return NodeRunSucceededEvent(
|
||||||
|
id="execution-id",
|
||||||
|
node_id="llm-node-id",
|
||||||
|
node_type=NodeType.LLM,
|
||||||
|
start_at=datetime.now(),
|
||||||
|
node_run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
inputs={"question": "hello"},
|
||||||
|
llm_usage=LLMUsage.empty_usage(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_deduct_quota_called_for_successful_llm_node() -> None:
|
||||||
|
layer = LLMQuotaLayer()
|
||||||
|
node = MagicMock()
|
||||||
|
node.id = "llm-node-id"
|
||||||
|
node.execution_id = "execution-id"
|
||||||
|
node.node_type = NodeType.LLM
|
||||||
|
node.tenant_id = "tenant-id"
|
||||||
|
node.model_instance = object()
|
||||||
|
|
||||||
|
result_event = _build_succeeded_event()
|
||||||
|
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||||
|
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||||
|
|
||||||
|
mock_deduct.assert_called_once_with(
|
||||||
|
tenant_id="tenant-id",
|
||||||
|
model_instance=node.model_instance,
|
||||||
|
usage=result_event.node_run_result.llm_usage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_deduct_quota_called_for_question_classifier_node() -> None:
|
||||||
|
layer = LLMQuotaLayer()
|
||||||
|
node = MagicMock()
|
||||||
|
node.id = "question-classifier-node-id"
|
||||||
|
node.execution_id = "execution-id"
|
||||||
|
node.node_type = NodeType.QUESTION_CLASSIFIER
|
||||||
|
node.tenant_id = "tenant-id"
|
||||||
|
node.model_instance = object()
|
||||||
|
|
||||||
|
result_event = _build_succeeded_event()
|
||||||
|
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||||
|
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||||
|
|
||||||
|
mock_deduct.assert_called_once_with(
|
||||||
|
tenant_id="tenant-id",
|
||||||
|
model_instance=node.model_instance,
|
||||||
|
usage=result_event.node_run_result.llm_usage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_llm_node_is_ignored() -> None:
|
||||||
|
layer = LLMQuotaLayer()
|
||||||
|
node = MagicMock()
|
||||||
|
node.id = "start-node-id"
|
||||||
|
node.execution_id = "execution-id"
|
||||||
|
node.node_type = NodeType.START
|
||||||
|
node.tenant_id = "tenant-id"
|
||||||
|
node._model_instance = object()
|
||||||
|
|
||||||
|
result_event = _build_succeeded_event()
|
||||||
|
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||||
|
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||||
|
|
||||||
|
mock_deduct.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_quota_error_is_handled_in_layer() -> None:
|
||||||
|
layer = LLMQuotaLayer()
|
||||||
|
node = MagicMock()
|
||||||
|
node.id = "llm-node-id"
|
||||||
|
node.execution_id = "execution-id"
|
||||||
|
node.node_type = NodeType.LLM
|
||||||
|
node.tenant_id = "tenant-id"
|
||||||
|
node.model_instance = object()
|
||||||
|
|
||||||
|
result_event = _build_succeeded_event()
|
||||||
|
with patch(
|
||||||
|
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||||
|
autospec=True,
|
||||||
|
side_effect=ValueError("quota exceeded"),
|
||||||
|
):
|
||||||
|
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||||
|
|
||||||
|
|
||||||
|
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
|
||||||
|
layer = LLMQuotaLayer()
|
||||||
|
stop_event = threading.Event()
|
||||||
|
layer.command_channel = MagicMock()
|
||||||
|
|
||||||
|
node = MagicMock()
|
||||||
|
node.id = "llm-node-id"
|
||||||
|
node.execution_id = "execution-id"
|
||||||
|
node.node_type = NodeType.LLM
|
||||||
|
node.tenant_id = "tenant-id"
|
||||||
|
node.model_instance = object()
|
||||||
|
node.graph_runtime_state = MagicMock()
|
||||||
|
node.graph_runtime_state.stop_event = stop_event
|
||||||
|
|
||||||
|
result_event = _build_succeeded_event()
|
||||||
|
with patch(
|
||||||
|
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||||
|
autospec=True,
|
||||||
|
side_effect=QuotaExceededError("No credits remaining"),
|
||||||
|
):
|
||||||
|
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||||
|
|
||||||
|
assert stop_event.is_set()
|
||||||
|
layer.command_channel.send_command.assert_called_once()
|
||||||
|
abort_command = layer.command_channel.send_command.call_args.args[0]
|
||||||
|
assert abort_command.command_type == CommandType.ABORT
|
||||||
|
assert abort_command.reason == "No credits remaining"
|
||||||
|
|
||||||
|
|
||||||
|
def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
|
||||||
|
layer = LLMQuotaLayer()
|
||||||
|
stop_event = threading.Event()
|
||||||
|
layer.command_channel = MagicMock()
|
||||||
|
|
||||||
|
node = MagicMock()
|
||||||
|
node.id = "llm-node-id"
|
||||||
|
node.node_type = NodeType.LLM
|
||||||
|
node.model_instance = object()
|
||||||
|
node.graph_runtime_state = MagicMock()
|
||||||
|
node.graph_runtime_state.stop_event = stop_event
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
|
||||||
|
autospec=True,
|
||||||
|
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
|
||||||
|
):
|
||||||
|
layer.on_node_run_start(node)
|
||||||
|
|
||||||
|
assert stop_event.is_set()
|
||||||
|
layer.command_channel.send_command.assert_called_once()
|
||||||
|
abort_command = layer.command_channel.send_command.call_args.args[0]
|
||||||
|
assert abort_command.command_type == CommandType.ABORT
|
||||||
|
assert abort_command.reason == "Model provider openai quota exceeded."
|
||||||
|
|
||||||
|
|
||||||
|
def test_quota_precheck_passes_without_abort() -> None:
|
||||||
|
layer = LLMQuotaLayer()
|
||||||
|
stop_event = threading.Event()
|
||||||
|
layer.command_channel = MagicMock()
|
||||||
|
|
||||||
|
node = MagicMock()
|
||||||
|
node.id = "llm-node-id"
|
||||||
|
node.node_type = NodeType.LLM
|
||||||
|
node.model_instance = object()
|
||||||
|
node.graph_runtime_state = MagicMock()
|
||||||
|
node.graph_runtime_state.stop_event = stop_event
|
||||||
|
|
||||||
|
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
|
||||||
|
layer.on_node_run_start(node)
|
||||||
|
|
||||||
|
assert not stop_event.is_set()
|
||||||
|
mock_check.assert_called_once_with(model_instance=node.model_instance)
|
||||||
|
layer.command_channel.send_command.assert_not_called()
|
||||||
Reference in New Issue
Block a user