mirror of
https://github.com/langgenius/dify.git
synced 2026-04-11 21:01:25 +08:00
Compare commits
73 Commits
refactor/s
...
fix/enterp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d22f0dfdac | ||
|
|
ef7918764b | ||
|
|
cf616ddc58 | ||
|
|
da06f738f7 | ||
|
|
c0b05db9d7 | ||
|
|
248504acef | ||
|
|
d105a2f568 | ||
|
|
64fcd9859f | ||
|
|
0f938d453c | ||
|
|
6625828246 | ||
|
|
7a1f0e3258 | ||
|
|
9a682f1009 | ||
|
|
877de7fb22 | ||
|
|
d81684d8d1 | ||
|
|
c900460ab3 | ||
|
|
5afb24f461 | ||
|
|
808002fbbd | ||
|
|
757fabda1e | ||
|
|
858ccd8746 | ||
|
|
ea35ee0a3e | ||
|
|
0e9dc86f3b | ||
|
|
0ed39d81e9 | ||
|
|
2b739b9544 | ||
|
|
22e82297c5 | ||
|
|
7ef139cadd | ||
|
|
bf5a327156 | ||
|
|
d94af41f07 | ||
|
|
8d8552cbb9 | ||
|
|
58524fd7fd | ||
|
|
2d7bffcc11 | ||
|
|
5025e29220 | ||
|
|
3cdc9c119e | ||
|
|
18ba367b11 | ||
|
|
d0bd74fccb | ||
|
|
5ccbc00eb9 | ||
|
|
94603b5408 | ||
|
|
8d4bd5636b | ||
|
|
ee0c4a8852 | ||
|
|
6032c598b0 | ||
|
|
afdd5b6c86 | ||
|
|
9acdfbde2f | ||
|
|
1977e68b2d | ||
|
|
e9a7e8f77f | ||
|
|
9e2b28c950 | ||
|
|
affd07ae94 | ||
|
|
111c76b71f | ||
|
|
793d22754e | ||
|
|
b62965034e | ||
|
|
016d72a8c6 | ||
|
|
08b8eff933 | ||
|
|
579cdea820 | ||
|
|
125f7e3ab4 | ||
|
|
400ed2fd72 | ||
|
|
840a8f3fc2 | ||
|
|
b4a5296fd1 | ||
|
|
d7c3ae50dc | ||
|
|
b921711e9e | ||
|
|
fb38ad84e1 | ||
|
|
91c854b5be | ||
|
|
d35b231941 | ||
|
|
849b4b8c40 | ||
|
|
990e8feee8 | ||
|
|
53641019b1 | ||
|
|
d1f10ff301 | ||
|
|
c8027e168b | ||
|
|
aae3f76999 | ||
|
|
2860c72b03 | ||
|
|
fcb53383df | ||
|
|
540e1db83c | ||
|
|
2f75e38c08 | ||
|
|
cd03e0a9ef | ||
|
|
df2421d187 | ||
|
|
0ba321d840 |
1
.codex/skills/component-refactoring
Symbolic link
1
.codex/skills/component-refactoring
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/component-refactoring
|
||||
1
.codex/skills/frontend-code-review
Symbolic link
1
.codex/skills/frontend-code-review
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/frontend-code-review
|
||||
1
.codex/skills/frontend-testing
Symbolic link
1
.codex/skills/frontend-testing
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/frontend-testing
|
||||
1
.codex/skills/orpc-contract-first
Symbolic link
1
.codex/skills/orpc-contract-first
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/orpc-contract-first
|
||||
7
.github/CODEOWNERS
vendored
7
.github/CODEOWNERS
vendored
@@ -24,10 +24,6 @@
|
||||
/api/services/tools/mcp_tools_manage_service.py @Nov1c444
|
||||
/api/controllers/mcp/ @Nov1c444
|
||||
/api/controllers/console/app/mcp_server.py @Nov1c444
|
||||
|
||||
# Backend - Tests
|
||||
/api/tests/ @laipz8200 @QuantumGhost
|
||||
|
||||
/api/tests/**/*mcp* @Nov1c444
|
||||
|
||||
# Backend - Workflow - Engine (Core graph execution engine)
|
||||
@@ -238,9 +234,6 @@
|
||||
# Frontend - Base Components
|
||||
/web/app/components/base/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Base Components Tests
|
||||
/web/app/components/base/**/*.spec.tsx @hyoban @CodingOnStar
|
||||
|
||||
# Frontend - Utils and Hooks
|
||||
/web/utils/classnames.ts @iamjoel @zxhlyh
|
||||
/web/utils/time.ts @iamjoel @zxhlyh
|
||||
|
||||
23
.github/workflows/autofix.yml
vendored
23
.github/workflows/autofix.yml
vendored
@@ -79,6 +79,29 @@ jobs:
|
||||
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
|
||||
find . -name "*.py.bak" -type f -delete
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install web dependencies
|
||||
run: |
|
||||
cd web
|
||||
pnpm install --frozen-lockfile
|
||||
|
||||
- name: ESLint autofix
|
||||
run: |
|
||||
cd web
|
||||
pnpm lint:fix || true
|
||||
|
||||
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
|
||||
- name: mdformat
|
||||
run: |
|
||||
|
||||
8
.github/workflows/deploy-hitl.yml
vendored
8
.github/workflows/deploy-hitl.yml
vendored
@@ -4,7 +4,8 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "build/feat/hitl"
|
||||
- "feat/hitl-frontend"
|
||||
- "feat/hitl-backend"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@@ -13,7 +14,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'build/feat/hitl'
|
||||
(
|
||||
github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
|
||||
github.event.workflow_run.head_branch == 'feat/hitl-backend'
|
||||
)
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v1
|
||||
|
||||
2
.github/workflows/web-tests.yml
vendored
2
.github/workflows/web-tests.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm test:ci
|
||||
run: pnpm test:coverage
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
|
||||
@@ -136,6 +136,7 @@ ignore_imports =
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
||||
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
|
||||
core.workflow.nodes.llm.node -> core.tools.signature
|
||||
core.workflow.nodes.template_transform.template_transform_node -> configs
|
||||
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_manager
|
||||
|
||||
@@ -106,10 +106,10 @@ ignore = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests,
|
||||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests,
|
||||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
"PT019", # @patch-injected params look like unused fixtures
|
||||
]
|
||||
"controllers/console/explore/trial.py" = ["TID251"]
|
||||
"controllers/console/human_input_form.py" = ["TID251"]
|
||||
|
||||
@@ -122,7 +122,8 @@ These commands assume you start from the repository root.
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||
# Note: enterprise_telemetry queue is only used in Enterprise Edition
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,enterprise_telemetry
|
||||
```
|
||||
|
||||
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
|
||||
|
||||
@@ -1,16 +1,45 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from flask import request
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from controllers.console.error import UnauthorizedAndForceLogout
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Console bootstrap APIs exempt from license check.
|
||||
# Defined at module level to avoid per-request tuple construction.
|
||||
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
|
||||
# - setup: install/setup status check (AppInitializer)
|
||||
# - init: init password validation for fresh install (InitPasswordPopup)
|
||||
# - login: auto-login after setup completion (InstallForm)
|
||||
# - features: billing/plan features (ProviderContextProvider)
|
||||
# - account/profile: login check + user profile (AppContextProvider, useIsLogin)
|
||||
# - workspaces/current: workspace + model providers (AppContextProvider)
|
||||
# - version: version check (AppContextProvider)
|
||||
# - activate/check: invitation link validation (signin page)
|
||||
# Without these exemptions, the signin page triggers location.reload()
|
||||
# on unauthorized_and_force_logout, causing an infinite loop.
|
||||
_CONSOLE_EXEMPT_PREFIXES = (
|
||||
"/console/api/system-features",
|
||||
"/console/api/setup",
|
||||
"/console/api/init",
|
||||
"/console/api/login",
|
||||
"/console/api/features",
|
||||
"/console/api/account/profile",
|
||||
"/console/api/workspaces/current",
|
||||
"/console/api/version",
|
||||
"/console/api/activate/check",
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
@@ -31,6 +60,39 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
init_request_context()
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
# Enterprise license validation for API endpoints (both console and webapp)
|
||||
# When license expires, block all API access except bootstrap endpoints needed
|
||||
# for the frontend to load the license expiration page without infinite reloads.
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
is_console_api = request.path.startswith("/console/api/")
|
||||
is_webapp_api = request.path.startswith("/api/") and not is_console_api
|
||||
|
||||
if is_console_api or is_webapp_api:
|
||||
if is_console_api:
|
||||
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
|
||||
else: # webapp API
|
||||
is_exempt = request.path.startswith("/api/system-features")
|
||||
|
||||
if not is_exempt:
|
||||
try:
|
||||
# Check license status with caching (10 min TTL)
|
||||
license_status = EnterpriseService.get_cached_license_status()
|
||||
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
|
||||
# Cookie clearing is handled by register_external_error_handlers
|
||||
# in libs/external_api.py which detects the error code and calls
|
||||
# build_force_logout_cookie_headers(). Frontend then checks
|
||||
# code === 'unauthorized_and_force_logout' and calls location.reload().
|
||||
raise UnauthorizedAndForceLogout(
|
||||
f"Enterprise license is {license_status}. Please contact your administrator."
|
||||
)
|
||||
except UnauthorizedAndForceLogout:
|
||||
raise
|
||||
except Exception:
|
||||
# If license check fails, log but don't block the request.
|
||||
# This prevents service disruption if enterprise API is temporarily
|
||||
# unavailable.
|
||||
logger.exception("Failed to check enterprise license status")
|
||||
|
||||
# add after request hook for injecting trace headers from OpenTelemetry span context
|
||||
# Only adds headers when OTEL is enabled and has valid context
|
||||
@dify_app.after_request
|
||||
@@ -81,6 +143,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_commands,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_enterprise_telemetry,
|
||||
ext_fastopenapi,
|
||||
ext_forward_refs,
|
||||
ext_hosting_provider,
|
||||
@@ -131,6 +194,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_commands,
|
||||
ext_fastopenapi,
|
||||
ext_otel,
|
||||
ext_enterprise_telemetry,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
]
|
||||
|
||||
@@ -30,6 +30,7 @@ from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.db_migration_lock import DbMigrationAutoRenewLock
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, password_pattern, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
@@ -54,6 +55,8 @@ from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DB_UPGRADE_LOCK_TTL_SECONDS = 60
|
||||
|
||||
|
||||
@click.command("reset-password", help="Reset the account password.")
|
||||
@click.option("--email", prompt=True, help="Account email to reset password for")
|
||||
@@ -727,8 +730,15 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
||||
@click.command("upgrade-db", help="Upgrade the database")
|
||||
def upgrade_db():
|
||||
click.echo("Preparing database migration...")
|
||||
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
|
||||
lock = DbMigrationAutoRenewLock(
|
||||
redis_client=redis_client,
|
||||
name="db_upgrade_lock",
|
||||
ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS,
|
||||
logger=logger,
|
||||
log_context="db_migration",
|
||||
)
|
||||
if lock.acquire(blocking=False):
|
||||
migration_succeeded = False
|
||||
try:
|
||||
click.echo(click.style("Starting database migration.", fg="green"))
|
||||
|
||||
@@ -737,6 +747,7 @@ def upgrade_db():
|
||||
|
||||
flask_migrate.upgrade()
|
||||
|
||||
migration_succeeded = True
|
||||
click.echo(click.style("Database migration successful!", fg="green"))
|
||||
|
||||
except Exception as e:
|
||||
@@ -744,7 +755,8 @@ def upgrade_db():
|
||||
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
|
||||
raise SystemExit(1)
|
||||
finally:
|
||||
lock.release()
|
||||
status = "successful" if migration_succeeded else "failed"
|
||||
lock.release_safely(status=status)
|
||||
else:
|
||||
click.echo("Database migration skipped")
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings
|
||||
from libs.file_utils import search_file_upwards
|
||||
|
||||
from .deploy import DeploymentConfig
|
||||
from .enterprise import EnterpriseFeatureConfig
|
||||
from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig
|
||||
from .extra import ExtraServiceConfig
|
||||
from .feature import FeatureConfig
|
||||
from .middleware import MiddlewareConfig
|
||||
@@ -73,6 +73,8 @@ class DifyConfig(
|
||||
# Enterprise feature configs
|
||||
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
|
||||
EnterpriseFeatureConfig,
|
||||
# Enterprise telemetry configs
|
||||
EnterpriseTelemetryConfig,
|
||||
):
|
||||
model_config = SettingsConfigDict(
|
||||
# read from dotenv format config file
|
||||
|
||||
@@ -18,3 +18,49 @@ class EnterpriseFeatureConfig(BaseSettings):
|
||||
description="Allow customization of the enterprise logo.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class EnterpriseTelemetryConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for enterprise telemetry.
|
||||
"""
|
||||
|
||||
ENTERPRISE_TELEMETRY_ENABLED: bool = Field(
|
||||
description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_ENDPOINT: str = Field(
|
||||
description="Enterprise OTEL collector endpoint.",
|
||||
default="",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_HEADERS: str = Field(
|
||||
description="Auth headers for OTLP export (key=value,key2=value2).",
|
||||
default="",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_PROTOCOL: str = Field(
|
||||
description="OTLP protocol: 'http' or 'grpc' (default: http).",
|
||||
default="http",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_API_KEY: str = Field(
|
||||
description="Bearer token for enterprise OTLP export authentication.",
|
||||
default="",
|
||||
)
|
||||
|
||||
ENTERPRISE_INCLUDE_CONTENT: bool = Field(
|
||||
description="Include input/output content in traces (privacy toggle).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
ENTERPRISE_SERVICE_NAME: str = Field(
|
||||
description="Service name for OTEL resource.",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE: float = Field(
|
||||
description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).",
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
@@ -1155,16 +1155,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
default=0,
|
||||
)
|
||||
|
||||
# API token last_used_at batch update
|
||||
ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: bool = Field(
|
||||
description="Enable periodic batch update of API token last_used_at timestamps",
|
||||
default=True,
|
||||
)
|
||||
API_TOKEN_LAST_USED_UPDATE_INTERVAL: int = Field(
|
||||
description="Interval in minutes for batch updating API token last_used_at (default 30)",
|
||||
default=30,
|
||||
)
|
||||
|
||||
# Trigger provider refresh (simple version)
|
||||
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
|
||||
description="Enable trigger provider refresh poller",
|
||||
|
||||
@@ -10,7 +10,6 @@ from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
@@ -132,11 +131,6 @@ class BaseApiKeyResource(Resource):
|
||||
if key is None:
|
||||
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")
|
||||
|
||||
# Invalidate cache before deleting from database
|
||||
# Type assertion: key is guaranteed to be non-None here because abort() raises
|
||||
assert key is not None # nosec - for type checker only
|
||||
ApiTokenCache.delete(key.token, key.type)
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@@ -660,6 +660,19 @@ class AppCopyApi(Resource):
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Inherit web app permission from original app
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
try:
|
||||
# Get the original app's access mode
|
||||
original_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_model.id)
|
||||
access_mode = original_settings.access_mode
|
||||
except Exception:
|
||||
# If original app has no settings (old app), default to public to match fallback behavior
|
||||
access_mode = "public"
|
||||
|
||||
# Apply the same access mode to the copied app
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, access_mode)
|
||||
|
||||
stmt = select(App).where(App.id == result.app_id)
|
||||
app = session.scalar(stmt)
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ class InstructionGeneratePayload(BaseModel):
|
||||
instruction: str = Field(..., description="Instruction for generation")
|
||||
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
|
||||
ideal_output: str = Field(default="", description="Expected ideal output")
|
||||
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
|
||||
|
||||
|
||||
class InstructionTemplatePayload(BaseModel):
|
||||
@@ -66,10 +67,17 @@ class RuleGenerateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
account, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=args.no_variable,
|
||||
user_id=account.id,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
@@ -95,12 +103,16 @@ class RuleCodeGenerateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
account, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
args=args,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.code_language,
|
||||
user_id=account.id,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@@ -127,12 +139,15 @@ class RuleStructuredOutputGenerateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
account, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
tenant_id=current_tenant_id,
|
||||
args=args,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
user_id=account.id,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@@ -159,14 +174,14 @@ class InstructionGenerateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
account, current_tenant_id = current_account_with_tenant()
|
||||
app_id = args.app_id or args.flow_id
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args.language)), None
|
||||
)
|
||||
code_template = code_provider.get_default_code() if code_provider else ""
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
if not app:
|
||||
@@ -183,33 +198,33 @@ class InstructionGenerateApi(Resource):
|
||||
case "llm":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
args=RuleGeneratePayload(
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
),
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
)
|
||||
case "agent":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
args=RuleGeneratePayload(
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
),
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
)
|
||||
case "code":
|
||||
return LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
args=RuleCodeGeneratePayload(
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.language,
|
||||
),
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.language,
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
)
|
||||
case _:
|
||||
return {"error": f"invalid node type: {node_type}"}
|
||||
if args.node_id == "" and args.current != "": # For legacy app without a workflow
|
||||
if args.node_id == "" and args.current != "":
|
||||
return LLMGenerator.instruction_modify_legacy(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args.flow_id,
|
||||
@@ -217,8 +232,10 @@ class InstructionGenerateApi(Resource):
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
)
|
||||
if args.node_id != "" and args.current != "": # For workflow node
|
||||
if args.node_id != "" and args.current != "":
|
||||
return LLMGenerator.instruction_modify_workflow(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args.flow_id,
|
||||
@@ -228,6 +245,8 @@ class InstructionGenerateApi(Resource):
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
workflow_service=WorkflowService(),
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
)
|
||||
return {"error": "incompatible parameters"}, 400
|
||||
except ProviderTokenNotInitError as ex:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest
|
||||
@@ -77,7 +78,10 @@ class TraceAppConfigApi(Resource):
|
||||
|
||||
try:
|
||||
result = OpsService.create_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=app_id,
|
||||
tracing_provider=args.tracing_provider,
|
||||
tracing_config=args.tracing_config,
|
||||
account_id=current_user.id,
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigIsExist()
|
||||
@@ -102,7 +106,10 @@ class TraceAppConfigApi(Resource):
|
||||
|
||||
try:
|
||||
result = OpsService.update_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=app_id,
|
||||
tracing_provider=args.tracing_provider,
|
||||
tracing_config=args.tracing_config,
|
||||
account_id=current_user.id,
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
@@ -124,7 +131,9 @@ class TraceAppConfigApi(Resource):
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
result = OpsService.delete_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, account_id=current_user.id
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@@ -55,7 +55,6 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
@@ -821,11 +820,6 @@ class DatasetApiDeleteApi(Resource):
|
||||
if key is None:
|
||||
console_ns.abort(404, message="API key not found")
|
||||
|
||||
# Invalidate cache before deleting from database
|
||||
# Type assertion: key is guaranteed to be non-None here because abort() raises
|
||||
assert key is not None # nosec - for type checker only
|
||||
ApiTokenCache.delete(key.token, key.type)
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import logging
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import Parameters as ParametersResponse
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
@@ -118,56 +117,7 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel
|
||||
workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
|
||||
|
||||
|
||||
# Pydantic models for request validation
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowRunRequest(BaseModel):
|
||||
inputs: dict
|
||||
files: list | None = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
inputs: dict
|
||||
query: str
|
||||
files: list | None = None
|
||||
conversation_id: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
retriever_from: str = "explore_app"
|
||||
|
||||
|
||||
class TextToSpeechRequest(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
inputs: dict
|
||||
query: str = ""
|
||||
files: list | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
retriever_from: str = "explore_app"
|
||||
|
||||
|
||||
# Register schemas for Swagger documentation
|
||||
console_ns.schema_model(
|
||||
WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
|
||||
def post(self, trial_app):
|
||||
"""
|
||||
Run workflow
|
||||
@@ -179,8 +129,10 @@ class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
request_data = WorkflowRunRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
assert current_user is not None
|
||||
try:
|
||||
app_id = app_model.id
|
||||
@@ -231,7 +183,6 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[ChatRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
@@ -239,14 +190,14 @@ class TrialChatApi(TrialAppResource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
request_data = ChatRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
|
||||
# Validate UUID values if provided
|
||||
if args.get("conversation_id"):
|
||||
args["conversation_id"] = uuid_value(args["conversation_id"])
|
||||
if args.get("parent_message_id"):
|
||||
args["parent_message_id"] = uuid_value(args["parent_message_id"])
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@@ -369,16 +320,20 @@ class TrialChatAudioApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
request_data = TextToSpeechRequest.model_validate(console_ns.payload)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = request_data.message_id
|
||||
text = request_data.text
|
||||
voice = request_data.voice
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@@ -416,15 +371,19 @@ class TrialChatTextApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
request_data = CompletionRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@@ -120,7 +120,7 @@ class TagUpdateDeleteApi(Resource):
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
|
||||
@@ -878,7 +878,11 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id
|
||||
tenant_id=current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
provider=provider,
|
||||
id=payload.id,
|
||||
account=current_user,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -396,7 +396,7 @@ class DatasetApi(DatasetApiResource):
|
||||
try:
|
||||
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
return "", 204
|
||||
return 204
|
||||
else:
|
||||
raise NotFound("Dataset not found.")
|
||||
except services.errors.dataset.DatasetInUseError:
|
||||
@@ -557,7 +557,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag(payload.tag_id)
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags/binding")
|
||||
@@ -581,7 +581,7 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags/unbinding")
|
||||
@@ -605,7 +605,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/tags")
|
||||
|
||||
@@ -746,4 +746,4 @@ class DocumentApi(DatasetApiResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
@@ -128,7 +128,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")
|
||||
|
||||
@@ -233,7 +233,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_segment")
|
||||
@@ -499,7 +499,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
except ChildChunkDeleteIndexServiceError as e:
|
||||
raise ChildChunkDeleteIndexError(str(e))
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_child_chunk")
|
||||
|
||||
@@ -1,24 +1,27 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache, fetch_token_with_single_flight, record_token_usage
|
||||
from services.end_user_service import EndUserService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@@ -293,14 +296,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
||||
|
||||
def validate_and_get_api_token(scope: str | None = None):
|
||||
"""
|
||||
Validate and get API token with Redis caching.
|
||||
|
||||
This function uses a two-tier approach:
|
||||
1. First checks Redis cache for the token
|
||||
2. If not cached, queries database and caches the result
|
||||
|
||||
The last_used_at field is updated asynchronously via Celery task
|
||||
to avoid blocking the request.
|
||||
Validate and get API token.
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header is None or " " not in auth_header:
|
||||
@@ -312,18 +308,29 @@ def validate_and_get_api_token(scope: str | None = None):
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
||||
|
||||
# Try to get token from cache first
|
||||
# Returns a CachedApiToken (plain Python object), not a SQLAlchemy model
|
||||
cached_token = ApiTokenCache.get(auth_token, scope)
|
||||
if cached_token is not None:
|
||||
logger.debug("Token validation served from cache for scope: %s", scope)
|
||||
# Record usage in Redis for later batch update (no Celery task per request)
|
||||
record_token_usage(auth_token, scope)
|
||||
return cast(ApiToken, cached_token)
|
||||
current_time = naive_utc_now()
|
||||
cutoff_time = current_time - timedelta(minutes=1)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
update_stmt = (
|
||||
update(ApiToken)
|
||||
.where(
|
||||
ApiToken.token == auth_token,
|
||||
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
|
||||
ApiToken.type == scope,
|
||||
)
|
||||
.values(last_used_at=current_time)
|
||||
)
|
||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||
result = session.execute(update_stmt)
|
||||
api_token = session.scalar(stmt)
|
||||
|
||||
# Cache miss - use Redis lock for single-flight mode
|
||||
# This ensures only one request queries DB for the same token concurrently
|
||||
return fetch_token_with_single_flight(auth_token, scope)
|
||||
if hasattr(result, "rowcount") and result.rowcount > 0:
|
||||
session.commit()
|
||||
|
||||
if not api_token:
|
||||
raise Unauthorized("Access token is invalid")
|
||||
|
||||
return api_token
|
||||
|
||||
|
||||
class DatasetApiResource(Resource):
|
||||
|
||||
@@ -79,7 +79,7 @@ class BaseAgentRunner(AppRunner):
|
||||
self.model_instance = model_instance
|
||||
|
||||
# init callback
|
||||
self.agent_callback = DifyAgentCallbackHandler()
|
||||
self.agent_callback = DifyAgentCallbackHandler(tenant_id=tenant_id)
|
||||
# init dataset tools
|
||||
hit_callback = DatasetIndexToolCallbackHandler(
|
||||
queue_manager=queue_manager,
|
||||
|
||||
@@ -63,6 +63,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
@@ -564,7 +566,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle stop events."""
|
||||
_ = trace_manager
|
||||
resolved_state = None
|
||||
if self._workflow_run_id:
|
||||
resolved_state = self._resolve_graph_runtime_state(graph_runtime_state)
|
||||
@@ -579,8 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
)
|
||||
|
||||
with self._database_session() as session:
|
||||
# Save message
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
|
||||
|
||||
yield workflow_finish_resp
|
||||
elif event.stopped_by in (
|
||||
@@ -589,8 +589,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
):
|
||||
# When hitting input-moderation or annotation-reply, the workflow will not start
|
||||
with self._database_session() as session:
|
||||
# Save message
|
||||
self._save_message(session=session)
|
||||
self._save_message(session=session, trace_manager=trace_manager)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
@@ -599,6 +598,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
event: QueueAdvancedChatMessageEndEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle advanced chat message end events."""
|
||||
@@ -616,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
# Save message
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
@@ -770,7 +770,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if self._conversation_name_generate_thread:
|
||||
logger.debug("Conversation name generation running as daemon thread")
|
||||
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
|
||||
def _save_message(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
):
|
||||
message = self._get_message(session=session)
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
@@ -826,6 +832,22 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
]
|
||||
session.add_all(message_files)
|
||||
|
||||
if trace_manager:
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
),
|
||||
payload={
|
||||
"conversation_id": str(message.conversation_id),
|
||||
"message_id": str(message.id),
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
|
||||
"""Bootstrap the cached runtime state from the queue manager when present."""
|
||||
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
|
||||
@@ -147,9 +147,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
|
||||
extras = {
|
||||
extras: dict[str, Any] = {
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
parent_trace_context = args.get("_parent_trace_context")
|
||||
if parent_trace_context:
|
||||
extras["parent_trace_context"] = parent_trace_context
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
|
||||
# trigger shouldn't prepare user inputs
|
||||
|
||||
@@ -45,6 +45,8 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.file import helpers as file_helpers
|
||||
from core.file.enums import FileTransferMethod
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
@@ -52,14 +54,16 @@ from core.model_runtime.entities.message_entities import (
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.tools.signature import sign_tool_file
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -409,10 +413,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
|
||||
)
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
),
|
||||
payload={
|
||||
"conversation_id": self._conversation_id,
|
||||
"message_id": self._message_id,
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
message_was_created.send(
|
||||
@@ -463,6 +476,85 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
|
||||
def _record_files(self):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
|
||||
if not message_files:
|
||||
return None
|
||||
|
||||
files_list = []
|
||||
upload_file_ids = [
|
||||
mf.upload_file_id
|
||||
for mf in message_files
|
||||
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
|
||||
]
|
||||
upload_files_map = {}
|
||||
if upload_file_ids:
|
||||
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
|
||||
upload_files_map = {uf.id: uf for uf in upload_files}
|
||||
|
||||
for message_file in message_files:
|
||||
upload_file = None
|
||||
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
|
||||
upload_file = upload_files_map.get(message_file.upload_file_id)
|
||||
|
||||
url = None
|
||||
filename = "file"
|
||||
mime_type = "application/octet-stream"
|
||||
size = 0
|
||||
extension = ""
|
||||
|
||||
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
url = message_file.url
|
||||
if message_file.url:
|
||||
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
|
||||
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if upload_file:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
|
||||
filename = upload_file.name
|
||||
mime_type = upload_file.mime_type or "application/octet-stream"
|
||||
size = upload_file.size or 0
|
||||
extension = f".{upload_file.extension}" if upload_file.extension else ""
|
||||
elif message_file.upload_file_id:
|
||||
# Fallback: generate URL even if upload_file not found
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
|
||||
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
|
||||
# For tool files, use URL directly if it's HTTP, otherwise sign it
|
||||
if message_file.url.startswith("http"):
|
||||
url = message_file.url
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
else:
|
||||
# Extract tool file id and extension from URL
|
||||
url_parts = message_file.url.split("/")
|
||||
if url_parts:
|
||||
file_part = url_parts[-1].split("?")[0] # Remove query params first
|
||||
# Use rsplit to correctly handle filenames with multiple dots
|
||||
if "." in file_part:
|
||||
tool_file_id, ext = file_part.rsplit(".", 1)
|
||||
extension = f".{ext}"
|
||||
else:
|
||||
tool_file_id = file_part
|
||||
extension = ".bin"
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
filename = file_part
|
||||
|
||||
transfer_method_value = message_file.transfer_method
|
||||
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
|
||||
file_dict = {
|
||||
"related_id": message_file.id,
|
||||
"extension": extension,
|
||||
"filename": filename,
|
||||
"size": size,
|
||||
"mime_type": mime_type,
|
||||
"transfer_method": transfer_method_value,
|
||||
"type": message_file.type,
|
||||
"url": url or "",
|
||||
"upload_file_id": message_file.upload_file_id or message_file.id,
|
||||
"remote_url": remote_url,
|
||||
}
|
||||
files_list.append(file_dict)
|
||||
return files_list or None
|
||||
|
||||
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
||||
"""
|
||||
Agent message to stream response.
|
||||
|
||||
@@ -64,7 +64,13 @@ class MessageCycleManager:
|
||||
|
||||
# Use SQLAlchemy 2.x style session.scalar(select(...))
|
||||
with session_factory.create_session() as session:
|
||||
message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))
|
||||
message_file = session.scalar(
|
||||
select(MessageFile)
|
||||
.where(
|
||||
MessageFile.message_id == message_id,
|
||||
)
|
||||
.where(MessageFile.belongs_to == "assistant")
|
||||
)
|
||||
|
||||
if message_file:
|
||||
self._message_has_file.add(message_id)
|
||||
|
||||
@@ -15,8 +15,7 @@ from datetime import datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
|
||||
from core.workflow.enums import (
|
||||
@@ -373,6 +372,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
self._enqueue_node_trace_task(domain_execution)
|
||||
|
||||
def _fail_running_node_executions(self, *, error_message: str) -> None:
|
||||
now = naive_utc_now()
|
||||
@@ -390,17 +390,138 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
|
||||
external_trace_id = None
|
||||
parent_trace_context = None
|
||||
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
|
||||
external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
|
||||
parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context")
|
||||
|
||||
trace_task = TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_execution=execution,
|
||||
conversation_id=conversation_id,
|
||||
user_id=self._trace_manager.user_id,
|
||||
external_trace_id=external_trace_id,
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
user_id=self._trace_manager.user_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
),
|
||||
payload={
|
||||
"workflow_execution": execution,
|
||||
"conversation_id": conversation_id,
|
||||
"user_id": self._trace_manager.user_id,
|
||||
"external_trace_id": external_trace_id,
|
||||
"parent_trace_context": parent_trace_context,
|
||||
},
|
||||
),
|
||||
trace_manager=self._trace_manager,
|
||||
)
|
||||
|
||||
def _enqueue_node_trace_task(self, domain_execution: WorkflowNodeExecution) -> None:
|
||||
if not self._trace_manager:
|
||||
return
|
||||
|
||||
execution = self._get_workflow_execution()
|
||||
meta = domain_execution.metadata or {}
|
||||
|
||||
parent_trace_context = None
|
||||
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
|
||||
parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context")
|
||||
|
||||
node_data: dict[str, Any] = {
|
||||
"workflow_id": domain_execution.workflow_id,
|
||||
"workflow_execution_id": execution.id_,
|
||||
"tenant_id": self._application_generate_entity.app_config.tenant_id,
|
||||
"app_id": self._application_generate_entity.app_config.app_id,
|
||||
"node_execution_id": domain_execution.id,
|
||||
"node_id": domain_execution.node_id,
|
||||
"node_type": str(domain_execution.node_type.value),
|
||||
"title": domain_execution.title,
|
||||
"status": str(domain_execution.status.value),
|
||||
"error": domain_execution.error,
|
||||
"elapsed_time": domain_execution.elapsed_time,
|
||||
"index": domain_execution.index,
|
||||
"predecessor_node_id": domain_execution.predecessor_node_id,
|
||||
"created_at": domain_execution.created_at,
|
||||
"finished_at": domain_execution.finished_at,
|
||||
"total_tokens": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0),
|
||||
"prompt_tokens": meta.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS),
|
||||
"completion_tokens": meta.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS),
|
||||
"total_price": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0),
|
||||
"currency": meta.get(WorkflowNodeExecutionMetadataKey.CURRENCY),
|
||||
"tool_name": (meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name")
|
||||
if isinstance(meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict)
|
||||
else None,
|
||||
"iteration_id": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID),
|
||||
"iteration_index": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX),
|
||||
"loop_id": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_ID),
|
||||
"loop_index": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX),
|
||||
"parallel_id": meta.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID),
|
||||
"node_inputs": dict(domain_execution.inputs) if domain_execution.inputs else None,
|
||||
"node_outputs": dict(domain_execution.outputs) if domain_execution.outputs else None,
|
||||
"process_data": dict(domain_execution.process_data) if domain_execution.process_data else None,
|
||||
}
|
||||
node_data["invoke_from"] = self._application_generate_entity.invoke_from.value
|
||||
node_data["user_id"] = self._system_variables().get(SystemVariableKey.USER_ID.value)
|
||||
|
||||
# Extract model info from process_data — LLM nodes store provider/model there,
|
||||
if domain_execution.process_data:
|
||||
if mp := domain_execution.process_data.get("model_provider"):
|
||||
node_data["model_provider"] = mp
|
||||
if mn := domain_execution.process_data.get("model_name"):
|
||||
node_data["model_name"] = mn
|
||||
|
||||
if domain_execution.node_type.value == "knowledge-retrieval" and domain_execution.outputs:
|
||||
results = domain_execution.outputs.get("result") or []
|
||||
dataset_ids: list[str] = []
|
||||
dataset_names: list[str] = []
|
||||
for doc in results:
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
doc_meta = doc.get("metadata") or {}
|
||||
did = doc_meta.get("dataset_id")
|
||||
dname = doc_meta.get("dataset_name")
|
||||
if did and did not in dataset_ids:
|
||||
dataset_ids.append(did)
|
||||
if dname and dname not in dataset_names:
|
||||
dataset_names.append(dname)
|
||||
if dataset_ids:
|
||||
node_data["dataset_ids"] = dataset_ids
|
||||
if dataset_names:
|
||||
node_data["dataset_names"] = dataset_names
|
||||
|
||||
tool_info = meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO)
|
||||
if isinstance(tool_info, dict):
|
||||
plugin_id = tool_info.get("plugin_unique_identifier")
|
||||
if plugin_id:
|
||||
node_data["plugin_name"] = plugin_id
|
||||
credential_id = tool_info.get("credential_id")
|
||||
if credential_id:
|
||||
node_data["credential_id"] = credential_id
|
||||
node_data["credential_provider_type"] = tool_info.get("provider_type")
|
||||
|
||||
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
|
||||
if conversation_id:
|
||||
node_data["conversation_id"] = conversation_id
|
||||
|
||||
if parent_trace_context:
|
||||
node_data["parent_trace_context"] = parent_trace_context
|
||||
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.NODE_EXECUTION_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=node_data.get("tenant_id"),
|
||||
user_id=node_data.get("user_id"),
|
||||
app_id=node_data.get("app_id"),
|
||||
),
|
||||
payload={"node_execution_data": node_data},
|
||||
),
|
||||
trace_manager=self._trace_manager,
|
||||
)
|
||||
self._trace_manager.add_trace_task(trace_task)
|
||||
|
||||
def _system_variables(self) -> Mapping[str, Any]:
|
||||
runtime_state = self.graph_runtime_state
|
||||
|
||||
@@ -47,7 +47,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
template_transform_max_output_length: int | None = None,
|
||||
http_request_http_client: HttpClientProtocol | None = None,
|
||||
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
http_request_file_manager: FileManagerProtocol | None = None,
|
||||
@@ -69,9 +68,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
self._template_transform_max_output_length = (
|
||||
template_transform_max_output_length or dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
)
|
||||
self._http_request_http_client = http_request_http_client or ssrf_proxy
|
||||
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
|
||||
self._http_request_file_manager = http_request_file_manager or file_manager
|
||||
@@ -126,7 +122,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
template_renderer=self._template_renderer,
|
||||
max_output_length=self._template_transform_max_output_length,
|
||||
)
|
||||
|
||||
if node_type == NodeType.HTTP_REQUEST:
|
||||
|
||||
@@ -4,8 +4,9 @@ from typing import Any, TextIO, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
@@ -36,13 +37,15 @@ class DifyAgentCallbackHandler(BaseModel):
|
||||
|
||||
color: str | None = ""
|
||||
current_loop: int = 1
|
||||
tenant_id: str | None = None
|
||||
|
||||
def __init__(self, color: str | None = None, tenant_id: str | None = None):
    """Initialize the callback handler.

    Args:
        color: Color stored on the handler; falls back to ``"green"`` when
            omitted or empty.
        tenant_id: Tenant identifier kept on the handler so it can be
            attached to the telemetry context of emitted tool traces.
    """
    super().__init__()
    # Fall back to the default color when none is specified.
    self.color = color or "green"
    self.current_loop = 1
    self.tenant_id = tenant_id
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
@@ -71,15 +74,23 @@ class DifyAgentCallbackHandler(BaseModel):
|
||||
print_text("\n")
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.TOOL_TRACE,
|
||||
message_id=message_id,
|
||||
tool_name=tool_name,
|
||||
tool_inputs=tool_inputs,
|
||||
tool_outputs=tool_outputs,
|
||||
timer=timer,
|
||||
)
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.TOOL_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=trace_manager.app_id,
|
||||
user_id=trace_manager.user_id,
|
||||
),
|
||||
payload={
|
||||
"message_id": message_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_inputs": tool_inputs,
|
||||
"tool_outputs": tool_outputs,
|
||||
"timer": timer,
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any):
|
||||
|
||||
@@ -6,8 +6,7 @@ from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.download import download_with_size_limit
|
||||
from core.plugin.entities.marketplace import MarketplacePluginDeclaration, MarketplacePluginSnapshot
|
||||
from extensions.ext_redis import redis_client
|
||||
from core.plugin.entities.marketplace import MarketplacePluginDeclaration
|
||||
|
||||
marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL))
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,37 +43,28 @@ def batch_fetch_plugin_by_ids(plugin_ids: list[str]) -> list[dict]:
|
||||
return data.get("data", {}).get("plugins", [])
|
||||
|
||||
|
||||
def batch_fetch_plugin_manifests_ignore_deserialization_error(
    plugin_ids: list[str],
) -> Sequence[MarketplacePluginDeclaration]:
    """Fetch marketplace plugin declarations in one batch, skipping malformed entries.

    Args:
        plugin_ids: Identifiers of the plugins to look up. An empty list
            short-circuits without issuing a request.

    Returns:
        The declarations that validated successfully. Entries that fail
        validation are logged and left out instead of failing the whole batch.

    Raises:
        httpx.HTTPStatusError: If the marketplace answers with an error status.
    """
    if not plugin_ids:
        return []

    url = str(marketplace_api_url / "api/v1/plugins/batch")
    response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
    response.raise_for_status()

    # Tolerate a response without "data"/"plugins" (or with null values) rather
    # than raising KeyError, matching how batch_fetch_plugin_by_ids reads it.
    plugins = (response.json().get("data") or {}).get("plugins") or []

    result: list[MarketplacePluginDeclaration] = []
    for plugin in plugins:
        try:
            result.append(MarketplacePluginDeclaration.model_validate(plugin))
        except Exception:
            # The entry may not even be a mapping; never let the log call itself raise.
            plugin_id = plugin.get("plugin_id", "unknown") if isinstance(plugin, dict) else "unknown"
            logger.exception("Failed to deserialize marketplace plugin manifest for %s", plugin_id)

    return result
|
||||
|
||||
|
||||
def record_install_plugin_event(plugin_unique_identifier: str):
    """Report a plugin installation to the marketplace install-count endpoint.

    Args:
        plugin_unique_identifier: Unique identifier of the installed plugin,
            sent as ``unique_identifier`` in the JSON body.

    Raises:
        httpx.HTTPStatusError: If the marketplace answers with an error status.
    """
    endpoint = marketplace_api_url / "api/v1/stats/plugins/install_count"
    body = {"unique_identifier": plugin_unique_identifier}
    httpx.post(str(endpoint), json=body).raise_for_status()
|
||||
|
||||
|
||||
def fetch_global_plugin_manifest(cache_key_prefix: str, cache_ttl: int) -> None:
    """
    Download the marketplace plugin manifest and store each snapshot in Redis.

    Intended to run once per check cycle so the instance-level cache is populated.

    Args:
        cache_key_prefix: Prefix prepended to each plugin id to form the Redis key
        cache_ttl: Expiry of every cached entry, in seconds

    Raises:
        httpx.HTTPError: If the HTTP request fails
        Exception: If any other error occurs during fetching or caching
    """
    manifest_url = marketplace_api_url / "api/v1/dist/plugins/manifest.json"
    version_header = {"X-Dify-Version": dify_config.project.version}
    response = httpx.get(str(manifest_url), headers=version_header, timeout=30)
    response.raise_for_status()

    # Validate every manifest entry and cache it under its own key.
    for entry in response.json().get("plugins", []):
        snapshot = MarketplacePluginSnapshot.model_validate(entry)
        cache_key = f"{cache_key_prefix}{snapshot.plugin_id}"
        redis_client.setex(name=cache_key, time=cache_ttl, value=snapshot.model_dump_json())
|
||||
|
||||
@@ -9,6 +9,7 @@ class RuleGeneratePayload(BaseModel):
|
||||
instruction: str = Field(..., description="Rule generation instruction")
|
||||
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
|
||||
no_variable: bool = Field(default=False, description="Whether to exclude variables")
|
||||
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
|
||||
|
||||
|
||||
class RuleCodeGeneratePayload(RuleGeneratePayload):
|
||||
@@ -18,3 +19,4 @@ class RuleCodeGeneratePayload(RuleGeneratePayload):
|
||||
class RuleStructuredOutputPayload(BaseModel):
    """Request payload for structured-output generation."""

    # Required instruction text for the generation.
    instruction: str = Field(
        ...,
        description="Structured output generation instruction",
    )
    # Required model configuration; populated from the "model_config" key via the alias.
    model_config_data: ModelConfig = Field(
        ...,
        alias="model_config",
        description="Model configuration",
    )
    # Optional app id, used for prompt generation tracing.
    app_id: str | None = Field(
        default=None,
        description="App ID for prompt generation tracing",
    )
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Protocol, cast
|
||||
import json_repair
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
|
||||
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
||||
from core.llm_generator.prompts import (
|
||||
@@ -27,10 +26,11 @@ from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.entities.trace_entity import OperationType
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
@@ -74,7 +74,7 @@ class LLMGenerator:
|
||||
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
|
||||
)
|
||||
answer = response.message.get_text_content()
|
||||
if answer == "":
|
||||
if answer is None:
|
||||
return ""
|
||||
try:
|
||||
result_dict = json.loads(answer)
|
||||
@@ -96,15 +96,17 @@ class LLMGenerator:
|
||||
name = name[:75] + "..."
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(app_id=app_id)
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.GENERATE_NAME_TRACE,
|
||||
conversation_id=conversation_id,
|
||||
generate_conversation_name=name,
|
||||
inputs=prompt,
|
||||
timer=timer,
|
||||
tenant_id=tenant_id,
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE,
|
||||
context=TelemetryContext(tenant_id=tenant_id, app_id=app_id),
|
||||
payload={
|
||||
"conversation_id": conversation_id,
|
||||
"generate_conversation_name": name,
|
||||
"inputs": prompt,
|
||||
"timer": timer,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -153,19 +155,27 @@ class LLMGenerator:
|
||||
return questions
|
||||
|
||||
@classmethod
|
||||
def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload):
|
||||
def generate_rule_config(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
instruction: str,
|
||||
model_config: ModelConfig,
|
||||
no_variable: bool,
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
output_parser = RuleConfigGeneratorOutputParser()
|
||||
|
||||
error = ""
|
||||
error_step = ""
|
||||
rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
|
||||
model_parameters = args.model_config_data.completion_params
|
||||
if args.no_variable:
|
||||
model_parameters = model_config.completion_params
|
||||
if no_variable:
|
||||
prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
|
||||
|
||||
prompt_generate = prompt_template.format(
|
||||
inputs={
|
||||
"TASK_DESCRIPTION": args.instruction,
|
||||
"TASK_DESCRIPTION": instruction,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
@@ -177,26 +187,45 @@ class LLMGenerator:
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=args.model_config_data.provider,
|
||||
model=args.model_config_data.name,
|
||||
provider=model_config.provider,
|
||||
model=model_config.name,
|
||||
)
|
||||
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
llm_result = None
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
rule_config["prompt"] = response.message.get_text_content()
|
||||
rule_config["prompt"] = llm_result.message.get_text_content() or ""
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate rule config"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
rule_config["error"] = str(e)
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate rule config"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", model_config.name)
|
||||
rule_config["error"] = str(e)
|
||||
error = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
if user_id:
|
||||
prompt_value = rule_config.get("prompt", "")
|
||||
generated_output = str(prompt_value) if prompt_value else ""
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.RULE_GENERATE,
|
||||
instruction=instruction,
|
||||
generated_output=generated_output,
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error or None,
|
||||
)
|
||||
|
||||
return rule_config
|
||||
|
||||
# get rule config prompt, parameter and statement
|
||||
@@ -211,7 +240,7 @@ class LLMGenerator:
|
||||
# format the prompt_generate_prompt
|
||||
prompt_generate_prompt = prompt_template.format(
|
||||
inputs={
|
||||
"TASK_DESCRIPTION": args.instruction,
|
||||
"TASK_DESCRIPTION": instruction,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
@@ -222,84 +251,125 @@ class LLMGenerator:
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=args.model_config_data.provider,
|
||||
model=args.model_config_data.name,
|
||||
provider=model_config.provider,
|
||||
model=model_config.name,
|
||||
)
|
||||
|
||||
try:
|
||||
llm_result = None
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
# the first step to generate the task prompt
|
||||
prompt_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
try:
|
||||
# the first step to generate the task prompt
|
||||
prompt_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
llm_result = prompt_content
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate prefix prompt"
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
if user_id:
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.RULE_GENERATE,
|
||||
instruction=instruction,
|
||||
generated_output="",
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error,
|
||||
)
|
||||
|
||||
return rule_config
|
||||
|
||||
rule_config["prompt"] = prompt_content.message.get_text_content() or ""
|
||||
|
||||
if not isinstance(prompt_content.message.content, str):
|
||||
raise NotImplementedError("prompt content is not a string")
|
||||
parameter_generate_prompt = parameter_template.format(
|
||||
inputs={
|
||||
"INPUT_TEXT": prompt_content.message.content,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate prefix prompt"
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
|
||||
|
||||
return rule_config
|
||||
|
||||
rule_config["prompt"] = prompt_content.message.get_text_content()
|
||||
|
||||
parameter_generate_prompt = parameter_template.format(
|
||||
inputs={
|
||||
"INPUT_TEXT": prompt_content.message.get_text_content(),
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
|
||||
|
||||
# the second step to generate the task_parameter and task_statement
|
||||
statement_generate_prompt = statement_template.format(
|
||||
inputs={
|
||||
"TASK_DESCRIPTION": args.instruction,
|
||||
"INPUT_TEXT": prompt_content.message.get_text_content(),
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
|
||||
|
||||
try:
|
||||
parameter_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
|
||||
# the second step to generate the task_parameter and task_statement
|
||||
statement_generate_prompt = statement_template.format(
|
||||
inputs={
|
||||
"TASK_DESCRIPTION": instruction,
|
||||
"INPUT_TEXT": prompt_content.message.content,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate variables"
|
||||
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
|
||||
|
||||
try:
|
||||
statement_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["opening_statement"] = statement_content.message.get_text_content()
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate conversation opener"
|
||||
try:
|
||||
parameter_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["variables"] = re.findall(
|
||||
r'"\s*([^"]+)\s*"', prompt_content.message.get_text_content() or ""
|
||||
)
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate variables"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
rule_config["error"] = str(e)
|
||||
try:
|
||||
statement_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["opening_statement"] = statement_content.message.get_text_content() or ""
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate conversation opener"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", model_config.name)
|
||||
rule_config["error"] = str(e)
|
||||
error = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
if user_id:
|
||||
generated_output = rule_config.get("prompt", "")
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.RULE_GENERATE,
|
||||
instruction=instruction,
|
||||
generated_output=str(generated_output) if generated_output else "",
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error or None,
|
||||
)
|
||||
|
||||
return rule_config
|
||||
|
||||
@classmethod
|
||||
def generate_code(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
args: RuleCodeGeneratePayload,
|
||||
instruction: str,
|
||||
model_config: ModelConfig,
|
||||
code_language: str = "javascript",
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
if args.code_language == "python":
|
||||
if code_language == "python":
|
||||
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||
else:
|
||||
prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||
|
||||
prompt = prompt_template.format(
|
||||
inputs={
|
||||
"INSTRUCTION": args.instruction,
|
||||
"CODE_LANGUAGE": args.code_language,
|
||||
"INSTRUCTION": instruction,
|
||||
"CODE_LANGUAGE": code_language,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
@@ -308,28 +378,49 @@ class LLMGenerator:
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=args.model_config_data.provider,
|
||||
model=args.model_config_data.name,
|
||||
provider=model_config.provider,
|
||||
model=model_config.name,
|
||||
)
|
||||
|
||||
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||
model_parameters = args.model_config_data.completion_params
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
model_parameters = model_config.completion_params
|
||||
|
||||
llm_result = None
|
||||
error = None
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
generated_code = llm_result.message.get_text_content() or ""
|
||||
result = {"code": generated_code, "language": code_language, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
result = {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to invoke LLM model, model: %s, language: %s", model_config.name, code_language
|
||||
)
|
||||
error = str(e)
|
||||
result = {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
if user_id:
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.CODE_GENERATE,
|
||||
instruction=instruction,
|
||||
generated_output=result.get("code", ""),
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error,
|
||||
)
|
||||
|
||||
generated_code = response.message.get_text_content()
|
||||
return {"code": generated_code, "language": args.code_language, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language
|
||||
)
|
||||
return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"}
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
|
||||
@@ -355,49 +446,81 @@ class LLMGenerator:
|
||||
raise TypeError("Expected LLMResult when stream=False")
|
||||
response = result
|
||||
|
||||
answer = response.message.get_text_content()
|
||||
answer = response.message.get_text_content() or ""
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
def generate_structured_output(
    cls,
    tenant_id: str,
    instruction: str,
    model_config: ModelConfig,
    user_id: str | None = None,
    app_id: str | None = None,
):
    """Ask the configured LLM for a JSON schema and return it pretty-printed.

    Args:
        tenant_id: Tenant whose model credentials are used for the invocation.
        instruction: User instruction sent as the user prompt message.
        model_config: Provider, model name and completion parameters to use.
        user_id: When given, a prompt-generation trace is emitted for this user.
        app_id: Optional app id forwarded to the prompt-generation trace.

    Returns:
        A dict with ``"output"`` (the JSON text, empty on failure) and
        ``"error"`` (empty on success). Failures are reported through
        ``"error"`` rather than raised.
    """
    model_manager = ModelManager()
    model_instance = model_manager.get_model_instance(
        tenant_id=tenant_id,
        model_type=ModelType.LLM,
        provider=model_config.provider,
        model=model_config.name,
    )

    prompt_messages = [
        SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
        UserPromptMessage(content=instruction),
    ]
    model_parameters = model_config.completion_params

    llm_result = None
    error = None
    result = {"output": "", "error": ""}

    with measure_time() as timer:
        try:
            llm_result = model_instance.invoke_llm(
                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
            )

            raw_content = llm_result.message.content

            if not isinstance(raw_content, str):
                raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")

            # Strict parsing first; fall back to json_repair for slightly malformed output.
            try:
                parsed_content = json.loads(raw_content)
            except json.JSONDecodeError:
                parsed_content = json_repair.loads(raw_content)

            if not isinstance(parsed_content, dict | list):
                raise ValueError(f"Failed to parse structured output from llm: {raw_content}")

            generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
            result = {"output": generated_json_schema, "error": ""}

        except InvokeError as e:
            error = str(e)
            result = {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
        except Exception as e:
            logger.exception("Failed to invoke LLM model, model: %s", model_config.name)
            error = str(e)
            result = {"output": "", "error": f"An unexpected error occurred: {str(e)}"}

    # NOTE(review): emitted after the timing block closes, presumably so the
    # timer is complete when the trace reads it — confirm against measure_time.
    if user_id:
        cls._emit_prompt_generation_trace(
            tenant_id=tenant_id,
            user_id=user_id,
            app_id=app_id,
            operation_type=OperationType.STRUCTURED_OUTPUT,
            instruction=instruction,
            generated_output=result.get("output", ""),
            llm_result=llm_result,
            model_config=model_config,
            timer=timer,
            error=error,
        )

    return result
|
||||
|
||||
@staticmethod
|
||||
def instruction_modify_legacy(
|
||||
@@ -407,12 +530,14 @@ class LLMGenerator:
|
||||
instruction: str,
|
||||
model_config: ModelConfig,
|
||||
ideal_output: str | None,
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
last_run: Message | None = (
|
||||
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
|
||||
)
|
||||
if not last_run:
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
result = LLMGenerator.__instruction_modify_common(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
last_run=None,
|
||||
@@ -421,22 +546,28 @@ class LLMGenerator:
|
||||
instruction=instruction,
|
||||
node_type="llm",
|
||||
ideal_output=ideal_output,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
)
|
||||
last_run_dict = {
|
||||
"query": last_run.query,
|
||||
"answer": last_run.answer,
|
||||
"error": last_run.error,
|
||||
}
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
last_run=last_run_dict,
|
||||
current=current,
|
||||
error_message=str(last_run.error),
|
||||
instruction=instruction,
|
||||
node_type="llm",
|
||||
ideal_output=ideal_output,
|
||||
)
|
||||
else:
|
||||
last_run_dict = {
|
||||
"query": last_run.query,
|
||||
"answer": last_run.answer,
|
||||
"error": last_run.error,
|
||||
}
|
||||
result = LLMGenerator.__instruction_modify_common(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
last_run=last_run_dict,
|
||||
current=current,
|
||||
error_message=str(last_run.error),
|
||||
instruction=instruction,
|
||||
node_type="llm",
|
||||
ideal_output=ideal_output,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def instruction_modify_workflow(
|
||||
@@ -448,6 +579,8 @@ class LLMGenerator:
|
||||
model_config: ModelConfig,
|
||||
ideal_output: str | None,
|
||||
workflow_service: WorkflowServiceInterface,
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
session = db.session()
|
||||
|
||||
@@ -478,6 +611,8 @@ class LLMGenerator:
|
||||
instruction=instruction,
|
||||
node_type=node_type,
|
||||
ideal_output=ideal_output,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
)
|
||||
|
||||
def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
|
||||
@@ -511,6 +646,8 @@ class LLMGenerator:
|
||||
instruction=instruction,
|
||||
node_type=last_run.node_type,
|
||||
ideal_output=ideal_output,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -523,6 +660,8 @@ class LLMGenerator:
|
||||
instruction: str,
|
||||
node_type: str,
|
||||
ideal_output: str | None,
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
LAST_RUN = "{{#last_run#}}"
|
||||
CURRENT = "{{#current#}}"
|
||||
@@ -562,24 +701,120 @@ class LLMGenerator:
|
||||
]
|
||||
model_parameters = {"temperature": 0.4}
|
||||
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
llm_result = None
|
||||
error = None
|
||||
result = {}
|
||||
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
generated_raw = llm_result.message.get_text_content()
|
||||
first_brace = generated_raw.find("{")
|
||||
last_brace = generated_raw.rfind("}")
|
||||
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
|
||||
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
|
||||
json_str = generated_raw[first_brace : last_brace + 1]
|
||||
data = json_repair.loads(json_str)
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
|
||||
result = data
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
result = {"error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True)
|
||||
error = str(e)
|
||||
result = {"error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
if user_id:
|
||||
generated_output = ""
|
||||
if isinstance(result, dict):
|
||||
for key in ["prompt", "code", "output", "modified"]:
|
||||
if result.get(key):
|
||||
generated_output = str(result[key])
|
||||
break
|
||||
|
||||
LLMGenerator._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.INSTRUCTION_MODIFY,
|
||||
instruction=instruction,
|
||||
generated_output=generated_output,
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error,
|
||||
)
|
||||
|
||||
generated_raw = response.message.get_text_content()
|
||||
first_brace = generated_raw.find("{")
|
||||
last_brace = generated_raw.rfind("}")
|
||||
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
|
||||
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
|
||||
json_str = generated_raw[first_brace : last_brace + 1]
|
||||
data = json_repair.loads(json_str)
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
|
||||
return data
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
return {"error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True)
|
||||
return {"error": f"An unexpected error occurred: {str(e)}"}
|
||||
return result
|
||||
|
||||
@classmethod
def _emit_prompt_generation_trace(
    cls,
    tenant_id: str,
    user_id: str,
    app_id: str | None,
    operation_type: OperationType,
    instruction: str,
    generated_output: str,
    llm_result: LLMResult | None,
    model_config: ModelConfig | None = None,
    timer=None,
    error: str | None = None,
):
    """Emit a ``PROMPT_GENERATION_TRACE`` telemetry event for one generation call.

    When ``llm_result`` is truthy, token counts, model name, latency, price
    and currency are read from it. Otherwise token counts are reported as
    zero, price and currency as ``None``, the model identity is taken from
    ``model_config`` (empty strings when it is missing) and latency is
    computed from ``timer``.

    Args:
        tenant_id: Tenant placed in the telemetry context and payload.
        user_id: User placed in the telemetry context and payload.
        app_id: Optional app placed in the telemetry context and payload.
        operation_type: Forwarded unchanged as ``payload["operation_type"]``.
        instruction: Forwarded unchanged as ``payload["instruction"]``.
        generated_output: Forwarded unchanged as ``payload["generated_output"]``.
        llm_result: Source of usage figures when available.
        model_config: Preferred source of the provider name; also supplies
            the model name when ``llm_result`` is missing.
        timer: Optional mapping read via ``.get("start")`` / ``.get("end")``;
            the difference of the two must expose ``total_seconds()``.
            It is also forwarded unchanged as ``payload["timer"]``.
        error: Optional error text forwarded as ``payload["error"]``.
    """
    if not llm_result:
        # No usage data available: report zeros and fall back to the config.
        prompt_tokens = completion_tokens = total_tokens = 0
        total_price = currency = None
        model_provider = model_config.provider if model_config else ""
        model_name = model_config.name if model_config else ""
        started, ended = (timer.get("start"), timer.get("end")) if timer else (None, None)
        latency = (ended - started).total_seconds() if started and ended else 0.0
    else:
        usage = llm_result.usage
        model_name = llm_result.model
        prompt_tokens = usage.prompt_tokens
        completion_tokens = usage.completion_tokens
        total_tokens = usage.total_tokens

        # Prefer the configured provider; otherwise use the "<provider>/..."
        # prefix of the model name when it has one.
        if model_config and model_config.provider:
            model_provider = model_config.provider
        elif "/" in model_name:
            model_provider = model_name.split("/")[0]
        else:
            model_provider = ""

        latency = usage.latency
        raw_price = usage.total_price
        # A falsy price (including zero) is reported as None.
        total_price = float(raw_price) if raw_price else None
        currency = usage.currency

    payload = {
        "tenant_id": tenant_id,
        "user_id": user_id,
        "app_id": app_id,
        "operation_type": operation_type,
        "instruction": instruction,
        "generated_output": generated_output,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "model_provider": model_provider,
        "model_name": model_name,
        "latency": latency,
        "total_price": total_price,
        "currency": currency,
        "timer": timer,
        "error": error,
    }
    event = TelemetryEvent(
        name=TraceTaskName.PROMPT_GENERATION_TRACE,
        context=TelemetryContext(tenant_id=tenant_id, user_id=user_id, app_id=app_id),
        payload=payload,
    )
    telemetry_emit(event)
|
||||
|
||||
@@ -15,16 +15,23 @@ class TraceContextFilter(logging.Filter):
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
# Preserve explicit trace_id set by the caller (e.g. emit_metric_only_event)
|
||||
existing_trace_id = getattr(record, "trace_id", "")
|
||||
if not existing_trace_id:
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
|
||||
# Set trace_id (fallback to ContextVar if no OTEL context)
|
||||
if trace_id:
|
||||
record.trace_id = trace_id
|
||||
# Set trace_id (fallback to ContextVar if no OTEL context)
|
||||
if trace_id:
|
||||
record.trace_id = trace_id
|
||||
else:
|
||||
record.trace_id = get_trace_id()
|
||||
|
||||
record.span_id = span_id or ""
|
||||
else:
|
||||
record.trace_id = get_trace_id()
|
||||
|
||||
record.span_id = span_id or ""
|
||||
# Keep existing trace_id; only fill span_id if missing
|
||||
if not getattr(record, "span_id", ""):
|
||||
record.span_id = ""
|
||||
|
||||
# For backward compatibility, also set req_id
|
||||
record.req_id = get_request_id()
|
||||
@@ -55,9 +62,12 @@ class IdentityContextFilter(logging.Filter):
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
identity = self._extract_identity()
|
||||
record.tenant_id = identity.get("tenant_id", "")
|
||||
record.user_id = identity.get("user_id", "")
|
||||
record.user_type = identity.get("user_type", "")
|
||||
if not getattr(record, "tenant_id", ""):
|
||||
record.tenant_id = identity.get("tenant_id", "")
|
||||
if not getattr(record, "user_id", ""):
|
||||
record.user_id = identity.get("user_id", "")
|
||||
if not getattr(record, "user_type", ""):
|
||||
record.user_type = identity.get("user_type", "")
|
||||
return True
|
||||
|
||||
def _extract_identity(self) -> dict[str, str]:
|
||||
|
||||
@@ -5,9 +5,10 @@ from typing import Any
|
||||
from core.app.app_config.entities import AppConfig
|
||||
from core.moderation.base import ModerationAction, ModerationError
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.ops.utils import measure_time
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,14 +50,18 @@ class InputModeration:
|
||||
moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.MODERATION_TRACE,
|
||||
message_id=message_id,
|
||||
moderation_result=moderation_result,
|
||||
inputs=inputs,
|
||||
timer=timer,
|
||||
)
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.MODERATION_TRACE,
|
||||
context=TelemetryContext(tenant_id=tenant_id, app_id=app_id),
|
||||
payload={
|
||||
"message_id": message_id,
|
||||
"moderation_result": moderation_result,
|
||||
"inputs": inputs,
|
||||
"timer": timer,
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
if not moderation_result.flagged:
|
||||
|
||||
@@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
||||
class BaseTraceInfo(BaseModel):
|
||||
message_id: str | None = None
|
||||
message_data: Any | None = None
|
||||
inputs: Union[str, dict[str, Any], list] | None = None
|
||||
outputs: Union[str, dict[str, Any], list] | None = None
|
||||
inputs: Union[str, dict[str, Any], list[Any]] | None = None
|
||||
outputs: Union[str, dict[str, Any], list[Any]] | None = None
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
metadata: dict[str, Any]
|
||||
@@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel):
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
def ensure_type(cls, v):
|
||||
def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None:
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str | dict | list):
|
||||
@@ -27,6 +27,48 @@ class BaseTraceInfo(BaseModel):
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def resolved_trace_id(self) -> str | None:
|
||||
"""Get trace_id with intelligent fallback.
|
||||
|
||||
Priority:
|
||||
1. External trace_id (from X-Trace-Id header)
|
||||
2. workflow_run_id (if this trace type has it)
|
||||
3. message_id (as final fallback)
|
||||
"""
|
||||
if self.trace_id:
|
||||
return self.trace_id
|
||||
|
||||
# Try workflow_run_id (only exists on workflow-related traces)
|
||||
workflow_run_id = getattr(self, "workflow_run_id", None)
|
||||
if workflow_run_id:
|
||||
return workflow_run_id
|
||||
|
||||
# Final fallback to message_id
|
||||
return str(self.message_id) if self.message_id else None
|
||||
|
||||
@property
|
||||
def resolved_parent_context(self) -> tuple[str | None, str | None]:
|
||||
"""Resolve cross-workflow parent linking from metadata.
|
||||
|
||||
Extracts typed parent IDs from the untyped ``parent_trace_context``
|
||||
metadata dict (set by tool_node when invoking nested workflows).
|
||||
|
||||
Returns:
|
||||
(trace_correlation_override, parent_span_id_source) where
|
||||
trace_correlation_override is the outer workflow_run_id and
|
||||
parent_span_id_source is the outer node_execution_id.
|
||||
"""
|
||||
parent_ctx = self.metadata.get("parent_trace_context")
|
||||
if not isinstance(parent_ctx, dict):
|
||||
return None, None
|
||||
trace_override = parent_ctx.get("parent_workflow_run_id")
|
||||
parent_span = parent_ctx.get("parent_node_execution_id")
|
||||
return (
|
||||
trace_override if isinstance(trace_override, str) else None,
|
||||
parent_span if isinstance(parent_span, str) else None,
|
||||
)
|
||||
|
||||
@field_serializer("start_time", "end_time")
|
||||
def serialize_datetime(self, dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
@@ -48,10 +90,14 @@ class WorkflowTraceInfo(BaseTraceInfo):
|
||||
workflow_run_version: str
|
||||
error: str | None = None
|
||||
total_tokens: int
|
||||
prompt_tokens: int | None = None
|
||||
completion_tokens: int | None = None
|
||||
file_list: list[str]
|
||||
query: str
|
||||
metadata: dict[str, Any]
|
||||
|
||||
invoked_by: str | None = None
|
||||
|
||||
|
||||
class MessageTraceInfo(BaseTraceInfo):
|
||||
conversation_model: str
|
||||
@@ -59,7 +105,7 @@ class MessageTraceInfo(BaseTraceInfo):
|
||||
answer_tokens: int
|
||||
total_tokens: int
|
||||
error: str | None = None
|
||||
file_list: Union[str, dict[str, Any], list] | None = None
|
||||
file_list: Union[str, dict[str, Any], list[Any]] | None = None
|
||||
message_file_data: Any | None = None
|
||||
conversation_mode: str
|
||||
gen_ai_server_time_to_first_token: float | None = None
|
||||
@@ -106,7 +152,7 @@ class ToolTraceInfo(BaseTraceInfo):
|
||||
tool_config: dict[str, Any]
|
||||
time_cost: Union[int, float]
|
||||
tool_parameters: dict[str, Any]
|
||||
file_url: Union[str, None, list] = None
|
||||
file_url: Union[str, None, list[str]] = None
|
||||
|
||||
|
||||
class GenerateNameTraceInfo(BaseTraceInfo):
|
||||
@@ -114,6 +160,79 @@ class GenerateNameTraceInfo(BaseTraceInfo):
|
||||
tenant_id: str
|
||||
|
||||
|
||||
class PromptGenerationTraceInfo(BaseTraceInfo):
    """Trace record for prompt generation operations (rule-generate, code-generate, etc.)."""

    # Who triggered the operation.
    tenant_id: str
    user_id: str
    app_id: str | None = None

    # What was requested.
    operation_type: str
    instruction: str

    # Token usage.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

    # Model that served the request.
    model_provider: str
    model_name: str

    latency: float

    # Cost figures; optional because they are not always available.
    total_price: float | None = None
    currency: str | None = None

    error: str | None = None

    # Field names start with "model_", which Pydantic's default protected
    # namespace would flag; an empty tuple disables that check.
    model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class WorkflowNodeTraceInfo(BaseTraceInfo):
    """Trace record carrying the fields of one workflow node execution."""

    # Identifiers of the workflow, run, tenant and node.
    workflow_id: str
    workflow_run_id: str
    tenant_id: str
    node_execution_id: str
    node_id: str
    node_type: str
    title: str

    # Outcome.
    status: str
    error: str | None = None
    elapsed_time: float

    # Position within the run.
    index: int
    predecessor_node_id: str | None = None

    # Aggregate usage and cost; default to zero when not supplied.
    total_tokens: int = 0
    total_price: float = 0.0
    currency: str | None = None

    # Model details; optional.
    model_provider: str | None = None
    model_name: str | None = None
    prompt_tokens: int | None = None
    completion_tokens: int | None = None

    tool_name: str | None = None

    # Iteration / loop / parallel identifiers; optional.
    iteration_id: str | None = None
    iteration_index: int | None = None
    loop_id: str | None = None
    loop_index: int | None = None
    parallel_id: str | None = None

    # Raw node payloads.
    node_inputs: Mapping[str, Any] | None = None
    node_outputs: Mapping[str, Any] | None = None
    process_data: Mapping[str, Any] | None = None

    invoked_by: str | None = None

    # Field names start with "model_", which Pydantic's default protected
    # namespace would flag; an empty tuple disables that check.
    model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class DraftNodeExecutionTrace(WorkflowNodeTraceInfo):
    """Distinct type name for ``WorkflowNodeTraceInfo``; adds no fields or behavior."""
|
||||
|
||||
|
||||
class TaskData(BaseModel):
|
||||
app_id: str
|
||||
trace_info_type: str
|
||||
@@ -128,16 +247,38 @@ trace_info_info_map = {
|
||||
"DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
|
||||
"ToolTraceInfo": ToolTraceInfo,
|
||||
"GenerateNameTraceInfo": GenerateNameTraceInfo,
|
||||
"PromptGenerationTraceInfo": PromptGenerationTraceInfo,
|
||||
"WorkflowNodeTraceInfo": WorkflowNodeTraceInfo,
|
||||
"DraftNodeExecutionTrace": DraftNodeExecutionTrace,
|
||||
}
|
||||
|
||||
|
||||
class OperationType(StrEnum):
|
||||
"""Operation type for token metric labels.
|
||||
|
||||
Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output``
|
||||
counters so consumers can break down token usage by operation.
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
NODE_EXECUTION = "node_execution"
|
||||
MESSAGE = "message"
|
||||
RULE_GENERATE = "rule_generate"
|
||||
CODE_GENERATE = "code_generate"
|
||||
STRUCTURED_OUTPUT = "structured_output"
|
||||
INSTRUCTION_MODIFY = "instruction_modify"
|
||||
|
||||
|
||||
class TraceTaskName(StrEnum):
|
||||
CONVERSATION_TRACE = "conversation"
|
||||
WORKFLOW_TRACE = "workflow"
|
||||
DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution"
|
||||
MESSAGE_TRACE = "message"
|
||||
MODERATION_TRACE = "moderation"
|
||||
SUGGESTED_QUESTION_TRACE = "suggested_question"
|
||||
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
|
||||
TOOL_TRACE = "tool"
|
||||
GENERATE_NAME_TRACE = "generate_conversation_name"
|
||||
PROMPT_GENERATION_TRACE = "prompt_generation"
|
||||
DATASOURCE_TRACE = "datasource"
|
||||
NODE_EXECUTION_TRACE = "node_execution"
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from langfuse import Langfuse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
@@ -30,7 +31,7 @@ from core.ops.utils import filter_none_values
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, Message, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import MessageStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -71,7 +72,50 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
metadata = trace_info.metadata
|
||||
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
||||
|
||||
if trace_info.message_id:
|
||||
# Check for parent_trace_context to detect nested workflow
|
||||
parent_trace_context = trace_info.metadata.get("parent_trace_context")
|
||||
|
||||
if parent_trace_context:
|
||||
# Nested workflow: create span under outer trace
|
||||
outer_trace_id = parent_trace_context.get("trace_id")
|
||||
parent_node_execution_id = parent_trace_context.get("parent_node_execution_id")
|
||||
parent_conversation_id = parent_trace_context.get("parent_conversation_id")
|
||||
parent_workflow_run_id = parent_trace_context.get("parent_workflow_run_id")
|
||||
|
||||
# Resolve outer trace_id: try message_id lookup first, fallback to workflow_run_id
|
||||
if parent_conversation_id:
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
with session_factory() as session:
|
||||
message_data_stmt = select(Message.id).where(
|
||||
Message.conversation_id == parent_conversation_id,
|
||||
Message.workflow_run_id == parent_workflow_run_id,
|
||||
)
|
||||
resolved_message_id = session.scalar(message_data_stmt)
|
||||
if resolved_message_id:
|
||||
outer_trace_id = resolved_message_id
|
||||
else:
|
||||
outer_trace_id = parent_workflow_run_id
|
||||
else:
|
||||
outer_trace_id = parent_workflow_run_id
|
||||
|
||||
# Create inner workflow span under outer trace
|
||||
workflow_span_data = LangfuseSpan(
|
||||
id=trace_info.workflow_run_id,
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
input=dict(trace_info.workflow_run_inputs),
|
||||
output=dict(trace_info.workflow_run_outputs),
|
||||
trace_id=outer_trace_id,
|
||||
parent_observation_id=parent_node_execution_id,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
metadata=metadata,
|
||||
level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
|
||||
status_message=trace_info.error or "",
|
||||
)
|
||||
self.add_span(langfuse_span_data=workflow_span_data)
|
||||
# Use outer_trace_id for all node spans/generations
|
||||
trace_id = outer_trace_id
|
||||
elif trace_info.message_id:
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
name = TraceTaskName.MESSAGE_TRACE
|
||||
trace_data = LangfuseTrace(
|
||||
@@ -174,6 +218,11 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
}
|
||||
)
|
||||
|
||||
# Determine parent_observation_id for nested workflows
|
||||
node_parent_observation_id = None
|
||||
if parent_trace_context or trace_info.message_id:
|
||||
node_parent_observation_id = trace_info.workflow_run_id
|
||||
|
||||
# add generation span
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
total_token = metadata.get("total_tokens", 0)
|
||||
@@ -206,7 +255,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
|
||||
parent_observation_id=node_parent_observation_id,
|
||||
usage=generation_usage,
|
||||
)
|
||||
|
||||
@@ -225,7 +274,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
|
||||
parent_observation_id=node_parent_observation_id,
|
||||
)
|
||||
|
||||
self.add_span(langfuse_span_data=span_data)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import cast
|
||||
|
||||
from langsmith import Client
|
||||
from langsmith.schemas import RunBase
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
@@ -30,7 +31,7 @@ from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, Message, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -64,7 +65,35 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
# Check for parent_trace_context for cross-workflow linking
|
||||
parent_trace_context = trace_info.metadata.get("parent_trace_context")
|
||||
|
||||
if parent_trace_context:
|
||||
# Inner workflow: resolve outer trace_id and link to parent node
|
||||
outer_trace_id = parent_trace_context.get("parent_workflow_run_id")
|
||||
|
||||
# Try to resolve message_id from conversation_id if available
|
||||
if parent_trace_context.get("parent_conversation_id"):
|
||||
try:
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
with session_factory() as session:
|
||||
message_data_stmt = select(Message.id).where(
|
||||
Message.conversation_id == parent_trace_context["parent_conversation_id"],
|
||||
Message.workflow_run_id == parent_trace_context["parent_workflow_run_id"],
|
||||
)
|
||||
resolved_message_id = session.scalar(message_data_stmt)
|
||||
if resolved_message_id:
|
||||
outer_trace_id = resolved_message_id
|
||||
except Exception as e:
|
||||
logger.debug("Failed to resolve message_id from conversation_id: %s", str(e))
|
||||
|
||||
trace_id = outer_trace_id
|
||||
parent_run_id = parent_trace_context.get("parent_node_execution_id")
|
||||
else:
|
||||
# Outer workflow: existing behavior
|
||||
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
parent_run_id = trace_info.message_id or None
|
||||
|
||||
if trace_info.start_time is None:
|
||||
trace_info.start_time = datetime.now()
|
||||
message_dotted_order = (
|
||||
@@ -78,7 +107,8 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
metadata = trace_info.metadata
|
||||
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
||||
|
||||
if trace_info.message_id:
|
||||
# Only create message_run for outer workflows (no parent_trace_context)
|
||||
if trace_info.message_id and not parent_trace_context:
|
||||
message_run = LangSmithRunModel(
|
||||
id=trace_info.message_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
@@ -121,9 +151,9 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
},
|
||||
error=trace_info.error,
|
||||
tags=["workflow"],
|
||||
parent_run_id=trace_info.message_id or None,
|
||||
parent_run_id=parent_run_id,
|
||||
trace_id=trace_id,
|
||||
dotted_order=workflow_dotted_order,
|
||||
dotted_order=None if parent_trace_context else workflow_dotted_order,
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
|
||||
@@ -21,19 +21,26 @@ from core.ops.entities.config_entity import (
|
||||
)
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
DraftNodeExecutionTrace,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
PromptGenerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
TaskData,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowNodeTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import get_message_data
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.account import Tenant
|
||||
from models.dataset import Dataset
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
from models.provider import Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderType
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from models.workflow import WorkflowAppLog
|
||||
from tasks.ops_trace_task import process_trace_tasks
|
||||
|
||||
@@ -43,6 +50,139 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]:
|
||||
"""Return (app_name, workspace_name) for the given IDs. Falls back to empty strings."""
|
||||
app_name = ""
|
||||
workspace_name = ""
|
||||
if not app_id and not tenant_id:
|
||||
return app_name, workspace_name
|
||||
with Session(db.engine) as session:
|
||||
if app_id:
|
||||
name = session.scalar(select(App.name).where(App.id == app_id))
|
||||
if name:
|
||||
app_name = name
|
||||
if tenant_id:
|
||||
name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id))
|
||||
if name:
|
||||
workspace_name = name
|
||||
return app_name, workspace_name
|
||||
|
||||
|
||||
# Tool provider type -> ORM model queried by `_lookup_credential_name`
# (which selects that model's ``name`` by ``id``).
_PROVIDER_TYPE_TO_MODEL: dict[str, type] = {
    "builtin": BuiltinToolProvider,
    "plugin": BuiltinToolProvider,  # same model as "builtin"
    "api": ApiToolProvider,
    "workflow": WorkflowToolProvider,
    "mcp": MCPToolProvider,
}
|
||||
|
||||
|
||||
def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str:
|
||||
if not credential_id:
|
||||
return ""
|
||||
model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "")
|
||||
if not model_cls:
|
||||
return ""
|
||||
with Session(db.engine) as session:
|
||||
name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id))
|
||||
return str(name) if name else ""
|
||||
|
||||
|
||||
def _lookup_llm_credential_info(
|
||||
tenant_id: str | None, provider: str | None, model: str | None, model_type: str | None = "llm"
|
||||
) -> tuple[str | None, str]:
|
||||
"""
|
||||
Lookup LLM credential ID and name for the given provider and model.
|
||||
Returns (credential_id, credential_name).
|
||||
|
||||
Handles async timing issues gracefully - if credential is deleted between lookups,
|
||||
returns the ID but empty name rather than failing.
|
||||
"""
|
||||
if not tenant_id or not provider:
|
||||
return None, ""
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
# Try to find provider-level or model-level configuration
|
||||
provider_record = session.scalar(
|
||||
select(Provider).where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM,
|
||||
)
|
||||
)
|
||||
|
||||
if not provider_record:
|
||||
return None, ""
|
||||
|
||||
# Check if there's a model-specific config
|
||||
credential_id = None
|
||||
credential_name = ""
|
||||
is_model_level = False
|
||||
|
||||
if model and provider_record.credential_id:
|
||||
# Try model-level first
|
||||
model_record = session.scalar(
|
||||
select(ProviderModel).where(
|
||||
ProviderModel.tenant_id == tenant_id,
|
||||
ProviderModel.provider_name == provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type,
|
||||
)
|
||||
)
|
||||
|
||||
if model_record and model_record.credential_id:
|
||||
credential_id = model_record.credential_id
|
||||
is_model_level = True
|
||||
|
||||
if not credential_id and provider_record.credential_id:
|
||||
# Fall back to provider-level credential
|
||||
credential_id = provider_record.credential_id
|
||||
is_model_level = False
|
||||
|
||||
# Lookup credential_name if we have credential_id
|
||||
if credential_id:
|
||||
try:
|
||||
if is_model_level:
|
||||
# Query ProviderModelCredential
|
||||
cred_name = session.scalar(
|
||||
select(ProviderModelCredential.credential_name).where(
|
||||
ProviderModelCredential.id == credential_id
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Query ProviderCredential
|
||||
cred_name = session.scalar(
|
||||
select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id)
|
||||
)
|
||||
|
||||
if cred_name:
|
||||
credential_name = str(cred_name)
|
||||
except Exception as e:
|
||||
# Credential might have been deleted between lookups (async timing)
|
||||
# Return ID but empty name rather than failing
|
||||
logger.warning(
|
||||
"Failed to lookup credential name for credential_id=%s (provider=%s, model=%s): %s",
|
||||
credential_id,
|
||||
provider,
|
||||
model,
|
||||
str(e),
|
||||
)
|
||||
|
||||
return credential_id, credential_name
|
||||
except Exception as e:
|
||||
# Database query failed or other unexpected error
|
||||
# Return empty rather than propagating error to telemetry emission
|
||||
logger.warning(
|
||||
"Failed to lookup LLM credential info for tenant_id=%s, provider=%s, model=%s: %s",
|
||||
tenant_id,
|
||||
provider,
|
||||
model,
|
||||
str(e),
|
||||
)
|
||||
return None, ""
|
||||
|
||||
|
||||
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
||||
def __getitem__(self, provider: str) -> dict[str, Any]:
|
||||
match provider:
|
||||
@@ -317,6 +457,10 @@ class OpsTraceManager:
|
||||
if app_id is None:
|
||||
return None
|
||||
|
||||
# Handle storage_id format (tenant-{uuid}) - not a real app_id
|
||||
if isinstance(app_id, str) and app_id.startswith("tenant-"):
|
||||
return None
|
||||
|
||||
app: App | None = db.session.query(App).where(App.id == app_id).first()
|
||||
|
||||
if app is None:
|
||||
@@ -479,6 +623,56 @@ class TraceTask:
|
||||
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
return cls._workflow_run_repo
|
||||
|
||||
@classmethod
def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str:
    """Build a prefixed user identifier from trace metadata.

    The first truthy entry wins, in this order:
        1. ``from_end_user_id`` -> ``"end_user:<id>"``
        2. ``from_account_id``  -> ``"account:<id>"``
        3. ``user_id``          -> ``"user:<id>"``

    Returns ``"anonymous"`` when none of the keys holds a truthy value.
    """
    lookup_order = (
        ("from_end_user_id", "end_user"),
        ("from_account_id", "account"),
        ("user_id", "user"),
    )
    for metadata_key, prefix in lookup_order:
        identifier = metadata.get(metadata_key)
        if identifier:
            return f"{prefix}:{identifier}"
    return "anonymous"
|
||||
|
||||
@classmethod
def _calculate_workflow_token_split(cls, workflow_run_id: str, tenant_id: str) -> tuple[int, int]:
    """Sum prompt and completion tokens over the node executions of one run.

    Loads every ``WorkflowNodeExecutionModel`` row matching the tenant and
    workflow run, then adds up the ``PROMPT_TOKENS`` and ``COMPLETION_TOKENS``
    entries of each row's ``execution_metadata_dict``. Entries that are
    missing or ``None`` are skipped.

    Returns:
        ``(total_prompt_tokens, total_completion_tokens)``.
    """
    from core.workflow.enums import WorkflowNodeExecutionMetadataKey
    from models.workflow import WorkflowNodeExecutionModel

    executions_query = select(WorkflowNodeExecutionModel).where(
        WorkflowNodeExecutionModel.tenant_id == tenant_id,
        WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
    )

    prompt_total = 0
    completion_total = 0
    with Session(db.engine) as session:
        for execution in session.scalars(executions_query).all():
            execution_metadata = execution.execution_metadata_dict
            prompt_count = execution_metadata.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS)
            completion_count = execution_metadata.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS)
            if prompt_count is not None:
                prompt_total += prompt_count
            if completion_count is not None:
                completion_total += completion_count

    return (prompt_total, completion_total)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trace_type: Any,
|
||||
@@ -499,6 +693,8 @@ class TraceTask:
|
||||
self.app_id = None
|
||||
self.trace_id = None
|
||||
self.kwargs = kwargs
|
||||
if user_id is not None and "user_id" not in self.kwargs:
|
||||
self.kwargs["user_id"] = user_id
|
||||
external_trace_id = kwargs.get("external_trace_id")
|
||||
if external_trace_id:
|
||||
self.trace_id = external_trace_id
|
||||
@@ -512,7 +708,7 @@ class TraceTask:
|
||||
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
|
||||
workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
|
||||
),
|
||||
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
|
||||
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs),
|
||||
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
|
||||
message_id=self.message_id, timer=self.timer, **self.kwargs
|
||||
),
|
||||
@@ -528,6 +724,9 @@ class TraceTask:
|
||||
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
|
||||
conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
|
||||
),
|
||||
TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs),
|
||||
TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs),
|
||||
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs),
|
||||
}
|
||||
|
||||
return preprocess_map.get(self.trace_type, lambda: None)()
|
||||
@@ -563,6 +762,10 @@ class TraceTask:
|
||||
|
||||
total_tokens = workflow_run.total_tokens
|
||||
|
||||
prompt_tokens, completion_tokens = self._calculate_workflow_token_split(
|
||||
workflow_run_id=workflow_run_id, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
file_list = workflow_run_inputs.get("sys.file") or []
|
||||
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
|
||||
|
||||
@@ -583,7 +786,14 @@ class TraceTask:
|
||||
)
|
||||
message_id = session.scalar(message_data_stmt)
|
||||
|
||||
metadata = {
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
if is_enterprise_telemetry_enabled():
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id)
|
||||
else:
|
||||
app_name, workspace_name = "", ""
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
"workflow_id": workflow_id,
|
||||
"conversation_id": conversation_id,
|
||||
"workflow_run_id": workflow_run_id,
|
||||
@@ -596,8 +806,14 @@ class TraceTask:
|
||||
"triggered_from": workflow_run.triggered_from,
|
||||
"user_id": user_id,
|
||||
"app_id": workflow_run.app_id,
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
}
|
||||
|
||||
parent_trace_context = self.kwargs.get("parent_trace_context")
|
||||
if parent_trace_context:
|
||||
metadata["parent_trace_context"] = parent_trace_context
|
||||
|
||||
workflow_trace_info = WorkflowTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
workflow_data=workflow_run.to_dict(),
|
||||
@@ -612,6 +828,8 @@ class TraceTask:
|
||||
workflow_run_version=workflow_run_version,
|
||||
error=error,
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
file_list=file_list,
|
||||
query=query,
|
||||
metadata=metadata,
|
||||
@@ -619,10 +837,11 @@ class TraceTask:
|
||||
message_id=message_id,
|
||||
start_time=workflow_run.created_at,
|
||||
end_time=workflow_run.finished_at,
|
||||
invoked_by=self._get_user_id_from_metadata(metadata),
|
||||
)
|
||||
return workflow_trace_info
|
||||
|
||||
def message_trace(self, message_id: str | None):
|
||||
def message_trace(self, message_id: str | None, **kwargs):
|
||||
if not message_id:
|
||||
return {}
|
||||
message_data = get_message_data(message_id)
|
||||
@@ -645,6 +864,19 @@ class TraceTask:
|
||||
|
||||
streaming_metrics = self._extract_streaming_metrics(message_data)
|
||||
|
||||
tenant_id = ""
|
||||
with Session(db.engine) as session:
|
||||
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
|
||||
if tid:
|
||||
tenant_id = str(tid)
|
||||
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
if is_enterprise_telemetry_enabled():
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
|
||||
else:
|
||||
app_name, workspace_name = "", ""
|
||||
|
||||
metadata = {
|
||||
"conversation_id": message_data.conversation_id,
|
||||
"ls_provider": message_data.model_provider,
|
||||
@@ -656,7 +888,14 @@ class TraceTask:
|
||||
"workflow_run_id": message_data.workflow_run_id,
|
||||
"from_source": message_data.from_source,
|
||||
"message_id": message_id,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": message_data.app_id,
|
||||
"user_id": message_data.from_end_user_id or message_data.from_account_id,
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
message_tokens = message_data.message_tokens
|
||||
|
||||
@@ -673,7 +912,9 @@ class TraceTask:
|
||||
outputs=message_data.answer,
|
||||
file_list=file_list,
|
||||
start_time=created_at,
|
||||
end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
|
||||
end_time=message_data.updated_at
|
||||
if message_data.updated_at and message_data.updated_at > created_at
|
||||
else created_at + timedelta(seconds=message_data.provider_response_latency),
|
||||
metadata=metadata,
|
||||
message_file_data=message_file_data,
|
||||
conversation_mode=conversation_mode,
|
||||
@@ -698,6 +939,8 @@ class TraceTask:
|
||||
"preset_response": moderation_result.preset_response,
|
||||
"query": moderation_result.query,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_id = None
|
||||
@@ -739,6 +982,8 @@ class TraceTask:
|
||||
"workflow_run_id": message_data.workflow_run_id,
|
||||
"from_source": message_data.from_source,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_id = None
|
||||
@@ -778,6 +1023,52 @@ class TraceTask:
|
||||
if not message_data:
|
||||
return {}
|
||||
|
||||
tenant_id = ""
|
||||
with Session(db.engine) as session:
|
||||
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
|
||||
if tid:
|
||||
tenant_id = str(tid)
|
||||
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
if is_enterprise_telemetry_enabled():
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
|
||||
else:
|
||||
app_name, workspace_name = "", ""
|
||||
|
||||
doc_list = [doc.model_dump() for doc in documents] if documents else []
|
||||
dataset_ids: set[str] = set()
|
||||
for doc in doc_list:
|
||||
doc_meta = doc.get("metadata") or {}
|
||||
did = doc_meta.get("dataset_id")
|
||||
if did:
|
||||
dataset_ids.add(did)
|
||||
|
||||
embedding_models: dict[str, dict[str, str]] = {}
|
||||
if dataset_ids:
|
||||
with Session(db.engine) as session:
|
||||
rows = session.execute(
|
||||
select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where(
|
||||
Dataset.id.in_(list(dataset_ids))
|
||||
)
|
||||
).all()
|
||||
for row in rows:
|
||||
embedding_models[str(row[0])] = {
|
||||
"embedding_model": row[1] or "",
|
||||
"embedding_model_provider": row[2] or "",
|
||||
}
|
||||
|
||||
# Extract rerank model info from retrieval_model kwargs
|
||||
rerank_model_provider = ""
|
||||
rerank_model_name = ""
|
||||
if "retrieval_model" in kwargs:
|
||||
retrieval_model = kwargs["retrieval_model"]
|
||||
if isinstance(retrieval_model, dict):
|
||||
reranking_model = retrieval_model.get("reranking_model")
|
||||
if isinstance(reranking_model, dict):
|
||||
rerank_model_provider = reranking_model.get("reranking_provider_name", "")
|
||||
rerank_model_name = reranking_model.get("reranking_model_name", "")
|
||||
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
"ls_provider": message_data.model_provider,
|
||||
@@ -788,13 +1079,23 @@ class TraceTask:
|
||||
"agent_based": message_data.agent_based,
|
||||
"workflow_run_id": message_data.workflow_run_id,
|
||||
"from_source": message_data.from_source,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": message_data.app_id,
|
||||
"user_id": message_data.from_end_user_id or message_data.from_account_id,
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
"embedding_models": embedding_models,
|
||||
"rerank_model_provider": rerank_model_provider,
|
||||
"rerank_model_name": rerank_model_name,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
message_id=message_id,
|
||||
inputs=message_data.query or message_data.inputs,
|
||||
documents=[doc.model_dump() for doc in documents] if documents else [],
|
||||
documents=doc_list,
|
||||
start_time=timer.get("start"),
|
||||
end_time=timer.get("end"),
|
||||
metadata=metadata,
|
||||
@@ -837,6 +1138,10 @@ class TraceTask:
|
||||
"error": error,
|
||||
"tool_parameters": tool_parameters,
|
||||
}
|
||||
if message_data.workflow_run_id:
|
||||
metadata["workflow_run_id"] = message_data.workflow_run_id
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
file_url = ""
|
||||
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
|
||||
@@ -891,6 +1196,8 @@ class TraceTask:
|
||||
"conversation_id": conversation_id,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
generate_name_trace_info = GenerateNameTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
@@ -905,6 +1212,182 @@ class TraceTask:
|
||||
|
||||
return generate_name_trace_info
|
||||
|
||||
def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict:
    """Assemble the trace record for a prompt-generation operation.

    All inputs are taken from ``kwargs`` (tenant/user/app identifiers, the
    instruction and generated output, token counts, model identity, latency,
    pricing, an optional ``timer`` mapping and an optional error). Missing
    keys fall back to the defaults used below.

    Returns:
        A ``PromptGenerationTraceInfo`` populated from ``kwargs``.
    """
    tenant_id = kwargs.get("tenant_id", "")
    user_id = kwargs.get("user_id", "")
    app_id = kwargs.get("app_id")
    operation_type = kwargs.get("operation_type", "")
    instruction = kwargs.get("instruction", "")
    model_provider = kwargs.get("model_provider", "")
    model_name = kwargs.get("model_name", "")

    # The timer is optional; without it both timestamps stay unset.
    timer = kwargs.get("timer")
    if timer:
        start_time, end_time = timer.get("start"), timer.get("end")
    else:
        start_time = end_time = None

    # Names are only resolved when the event is tied to an app.
    if app_id:
        app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id)
    else:
        app_name = workspace_name = None

    metadata = {
        "tenant_id": tenant_id,
        "user_id": user_id,
        "app_id": app_id or "",
        "app_name": app_name,
        "workspace_name": workspace_name,
        "operation_type": operation_type,
        "model_provider": model_provider,
        "model_name": model_name,
    }
    node_execution_id = kwargs.get("node_execution_id")
    if node_execution_id:
        metadata["node_execution_id"] = node_execution_id

    return PromptGenerationTraceInfo(
        trace_id=self.trace_id,
        inputs=instruction,
        outputs=kwargs.get("generated_output", ""),
        start_time=start_time,
        end_time=end_time,
        metadata=metadata,
        tenant_id=tenant_id,
        user_id=user_id,
        app_id=app_id,
        operation_type=operation_type,
        instruction=instruction,
        prompt_tokens=kwargs.get("prompt_tokens", 0),
        completion_tokens=kwargs.get("completion_tokens", 0),
        total_tokens=kwargs.get("total_tokens", 0),
        model_provider=model_provider,
        model_name=model_name,
        latency=kwargs.get("latency", 0.0),
        total_price=kwargs.get("total_price"),
        currency=kwargs.get("currency"),
        error=kwargs.get("error"),
    )
|
||||
|
||||
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
    """Build the trace record for a single workflow node execution.

    Reads everything from ``kwargs["node_execution_data"]``. When enterprise
    telemetry is enabled, app/workspace names and the credential name are
    looked up; otherwise they are reported as empty strings.

    Returns:
        ``{}`` when no node execution data was supplied, otherwise a
        ``WorkflowNodeTraceInfo``.
    """
    node_data: dict = kwargs.get("node_execution_data", {})
    if not node_data:
        return {}

    from core.telemetry.gateway import is_enterprise_telemetry_enabled

    # Evaluate the flag once so both lookups below see the same answer.
    enterprise_enabled = is_enterprise_telemetry_enabled()

    credential_id = node_data.get("credential_id")
    if enterprise_enabled:
        app_name, workspace_name = _lookup_app_and_workspace_names(
            node_data.get("app_id"), node_data.get("tenant_id")
        )
        # Try tool credential lookup first
        credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type"))
        # If no credential_id found (e.g., LLM nodes), try LLM credential lookup
        if not credential_id:
            llm_cred_id, llm_cred_name = _lookup_llm_credential_info(
                tenant_id=node_data.get("tenant_id"),
                provider=node_data.get("model_provider"),
                model=node_data.get("model_name"),
                model_type="llm",
            )
            if llm_cred_id:
                credential_id = llm_cred_id
                credential_name = llm_cred_name
    else:
        app_name, workspace_name = "", ""
        credential_name = ""

    metadata: dict[str, Any] = {
        "tenant_id": node_data.get("tenant_id"),
        "app_id": node_data.get("app_id"),
        "app_name": app_name,
        "workspace_name": workspace_name,
        "user_id": node_data.get("user_id"),
        "invoke_from": node_data.get("invoke_from"),
        # Use the resolved id (which may come from the LLM credential fallback)
        # so it stays paired with ``credential_name``; re-reading node_data here
        # would drop the fallback result.
        "credential_id": credential_id,
        "credential_name": credential_name,
        "dataset_ids": node_data.get("dataset_ids"),
        "dataset_names": node_data.get("dataset_names"),
        "plugin_name": node_data.get("plugin_name"),
    }

    parent_trace_context = node_data.get("parent_trace_context")
    if parent_trace_context:
        metadata["parent_trace_context"] = parent_trace_context

    message_id: str | None = None
    conversation_id = node_data.get("conversation_id")
    workflow_execution_id = node_data.get("workflow_execution_id")
    # The message lookup is skipped when a parent trace context is present.
    if conversation_id and workflow_execution_id and not parent_trace_context:
        with Session(db.engine) as session:
            msg_id = session.scalar(
                select(Message.id).where(
                    Message.conversation_id == conversation_id,
                    Message.workflow_run_id == workflow_execution_id,
                )
            )
            if msg_id:
                message_id = str(msg_id)
                metadata["message_id"] = message_id
    # NOTE(review): nesting reconstructed from a flattened diff; this check is
    # assumed to be independent of the message lookup above -- confirm.
    if conversation_id:
        metadata["conversation_id"] = conversation_id

    return WorkflowNodeTraceInfo(
        trace_id=self.trace_id,
        message_id=message_id,
        start_time=node_data.get("created_at"),
        end_time=node_data.get("finished_at"),
        metadata=metadata,
        workflow_id=node_data.get("workflow_id", ""),
        workflow_run_id=node_data.get("workflow_execution_id", ""),
        tenant_id=node_data.get("tenant_id", ""),
        node_execution_id=node_data.get("node_execution_id", ""),
        node_id=node_data.get("node_id", ""),
        node_type=node_data.get("node_type", ""),
        title=node_data.get("title", ""),
        status=node_data.get("status", ""),
        error=node_data.get("error"),
        elapsed_time=node_data.get("elapsed_time", 0.0),
        index=node_data.get("index", 0),
        predecessor_node_id=node_data.get("predecessor_node_id"),
        total_tokens=node_data.get("total_tokens", 0),
        total_price=node_data.get("total_price", 0.0),
        currency=node_data.get("currency"),
        model_provider=node_data.get("model_provider"),
        model_name=node_data.get("model_name"),
        prompt_tokens=node_data.get("prompt_tokens"),
        completion_tokens=node_data.get("completion_tokens"),
        tool_name=node_data.get("tool_name"),
        iteration_id=node_data.get("iteration_id"),
        iteration_index=node_data.get("iteration_index"),
        loop_id=node_data.get("loop_id"),
        loop_index=node_data.get("loop_index"),
        parallel_id=node_data.get("parallel_id"),
        node_inputs=node_data.get("node_inputs"),
        node_outputs=node_data.get("node_outputs"),
        process_data=node_data.get("process_data"),
        invoked_by=self._get_user_id_from_metadata(metadata),
    )
|
||||
|
||||
def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict:
    """Build a draft-node trace by re-wrapping the regular node trace.

    Delegates to ``node_execution_trace``; a populated
    ``WorkflowNodeTraceInfo`` is converted into a ``DraftNodeExecutionTrace``
    with the same field values, and any other result (such as the empty
    dict) is returned unchanged.
    """
    base_trace = self.node_execution_trace(**kwargs)
    if base_trace and isinstance(base_trace, WorkflowNodeTraceInfo):
        return DraftNodeExecutionTrace(**base_trace.model_dump())
    return base_trace
|
||||
|
||||
def _extract_streaming_metrics(self, message_data) -> dict:
|
||||
if not message_data.message_metadata:
|
||||
return {}
|
||||
@@ -938,13 +1421,17 @@ class TraceQueueManager:
|
||||
self.user_id = user_id
|
||||
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
|
||||
self.flask_app = current_app._get_current_object() # type: ignore
|
||||
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
|
||||
if trace_manager_timer is None:
|
||||
self.start_timer()
|
||||
|
||||
def add_trace_task(self, trace_task: TraceTask):
|
||||
global trace_manager_timer, trace_manager_queue
|
||||
try:
|
||||
if self.trace_instance:
|
||||
if self._enterprise_telemetry_enabled or self.trace_instance:
|
||||
trace_task.app_id = self.app_id
|
||||
trace_manager_queue.put(trace_task)
|
||||
except Exception:
|
||||
@@ -980,20 +1467,27 @@ class TraceQueueManager:
|
||||
def send_to_celery(self, tasks: list[TraceTask]):
|
||||
with self.flask_app.app_context():
|
||||
for task in tasks:
|
||||
if task.app_id is None:
|
||||
continue
|
||||
storage_id = task.app_id
|
||||
if storage_id is None:
|
||||
tenant_id = task.kwargs.get("tenant_id")
|
||||
if tenant_id:
|
||||
storage_id = f"tenant-{tenant_id}"
|
||||
else:
|
||||
logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type)
|
||||
continue
|
||||
|
||||
file_id = uuid4().hex
|
||||
trace_info = task.execute()
|
||||
|
||||
task_data = TaskData(
|
||||
app_id=task.app_id,
|
||||
app_id=storage_id,
|
||||
trace_info_type=type(trace_info).__name__,
|
||||
trace_info=trace_info.model_dump() if trace_info else None,
|
||||
)
|
||||
file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
|
||||
file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json"
|
||||
storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
|
||||
file_info = {
|
||||
"file_id": file_id,
|
||||
"app_id": task.app_id,
|
||||
"app_id": storage_id,
|
||||
}
|
||||
process_trace_tasks.delay(file_info) # type: ignore
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field, computed_field, model_validator
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
@@ -48,15 +48,3 @@ class MarketplacePluginDeclaration(BaseModel):
|
||||
if "tool" in data and not data["tool"]:
|
||||
del data["tool"]
|
||||
return data
|
||||
|
||||
|
||||
class MarketplacePluginSnapshot(BaseModel):
|
||||
org: str
|
||||
name: str
|
||||
latest_version: str
|
||||
latest_package_identifier: str
|
||||
latest_package_url: str
|
||||
|
||||
@computed_field
|
||||
def plugin_id(self) -> str:
|
||||
return f"{self.org}/{self.name}"
|
||||
|
||||
@@ -27,8 +27,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
@@ -56,6 +55,8 @@ from core.rag.retrieval.template_prompts import (
|
||||
METADATA_FILTER_USER_PROMPT_2,
|
||||
METADATA_FILTER_USER_PROMPT_3,
|
||||
)
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
@@ -728,10 +729,21 @@ class DatasetRetrieval:
|
||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
)
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
|
||||
)
|
||||
app_config = self.application_generate_entity.app_config if self.application_generate_entity else None
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=app_config.tenant_id if app_config else None,
|
||||
app_id=app_config.app_id if app_config else None,
|
||||
),
|
||||
payload={
|
||||
"message_id": message_id,
|
||||
"documents": documents,
|
||||
"timer": timer,
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
def _on_query(
|
||||
|
||||
43
api/core/telemetry/__init__.py
Normal file
43
api/core/telemetry/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Telemetry facade.
|
||||
|
||||
Thin public API for emitting telemetry events. All routing logic
|
||||
lives in ``core.telemetry.gateway`` which is shared by both CE and EE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.telemetry.events import TelemetryContext, TelemetryEvent
|
||||
from core.telemetry.gateway import emit as gateway_emit
|
||||
from core.telemetry.gateway import get_trace_task_to_case
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
|
||||
def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None:
    """Emit a telemetry event through the shared gateway.

    The event's ``TraceTaskName`` is translated into a ``TelemetryCase``;
    events whose name has no mapping are dropped without further action.
    Otherwise the context is flattened into a plain dict and the call is
    forwarded to ``core.telemetry.gateway.emit()``.
    """
    case = get_trace_task_to_case().get(event.name)
    if case is not None:
        ctx = event.context
        context: dict[str, object] = {
            "tenant_id": ctx.tenant_id,
            "user_id": ctx.user_id,
            "app_id": ctx.app_id,
        }
        gateway_emit(case, context, event.payload, trace_manager)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TelemetryContext",
|
||||
"TelemetryEvent",
|
||||
"TraceTaskName",
|
||||
"emit",
|
||||
]
|
||||
21
api/core/telemetry/events.py
Normal file
21
api/core/telemetry/events.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TelemetryContext:
|
||||
tenant_id: str | None = None
|
||||
user_id: str | None = None
|
||||
app_id: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TelemetryEvent:
    """Immutable telemetry event: a name, its context and a payload dict."""

    # Trace task name identifying the kind of event.
    name: TraceTaskName
    # Tenant/user/app identifiers for the event.
    context: TelemetryContext
    # Event-specific data.
    payload: dict[str, Any]
|
||||
233
api/core/telemetry/gateway.py
Normal file
233
api/core/telemetry/gateway.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Telemetry gateway — single routing layer for all editions.
|
||||
|
||||
Maps ``TelemetryCase`` → ``CaseRoute`` and dispatches events to either
|
||||
the CE/EE trace pipeline (``TraceQueueManager``) or the enterprise-only
|
||||
metric/log Celery queue.
|
||||
|
||||
This module lives in ``core/`` so both CE and EE share one routing table
|
||||
and one ``emit()`` entry point. No separate enterprise gateway module is
|
||||
needed — enterprise-specific dispatch (Celery task, payload offloading)
|
||||
is handled here behind lazy imports that no-op in CE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from enterprise.telemetry.contracts import TelemetryCase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routing table — authoritative mapping for all editions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_case_to_trace_task: dict | None = None
|
||||
_case_routing: dict | None = None
|
||||
|
||||
|
||||
def _get_case_to_trace_task() -> dict:
    """Return the ``TelemetryCase`` -> ``TraceTaskName`` table.

    The table is built on the first call (importing ``TelemetryCase`` at
    that point) and the same dict object is returned on later calls.
    """
    global _case_to_trace_task
    if _case_to_trace_task is not None:
        return _case_to_trace_task

    from enterprise.telemetry.contracts import TelemetryCase

    mapping = {
        TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE,
        TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE,
        TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE,
        TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
        TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE,
        TelemetryCase.TOOL_EXECUTION: TraceTaskName.TOOL_TRACE,
        TelemetryCase.MODERATION_CHECK: TraceTaskName.MODERATION_TRACE,
        TelemetryCase.SUGGESTED_QUESTION: TraceTaskName.SUGGESTED_QUESTION_TRACE,
        TelemetryCase.DATASET_RETRIEVAL: TraceTaskName.DATASET_RETRIEVAL_TRACE,
        TelemetryCase.GENERATE_NAME: TraceTaskName.GENERATE_NAME_TRACE,
    }
    _case_to_trace_task = mapping
    return mapping
|
||||
|
||||
|
||||
def get_trace_task_to_case() -> dict:
    """Return TraceTaskName → TelemetryCase (inverse of _get_case_to_trace_task).

    A fresh dict is built on every call.
    """
    inverse: dict = {}
    for case, task_name in _get_case_to_trace_task().items():
        inverse[task_name] = case
    return inverse
|
||||
|
||||
|
||||
def _get_case_routing() -> dict:
    """Return the ``TelemetryCase`` -> ``CaseRoute`` routing table.

    Built on the first call (importing the enterprise contract types at
    that point) and the same dict object is returned on later calls.
    """
    global _case_routing
    if _case_routing is not None:
        return _case_routing

    from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase

    # (signal type, CE-eligible, cases) — listed in table order.
    route_groups = (
        # TRACE — CE-eligible (flow in both CE and EE)
        (
            SignalType.TRACE,
            True,
            (
                TelemetryCase.WORKFLOW_RUN,
                TelemetryCase.MESSAGE_RUN,
                TelemetryCase.TOOL_EXECUTION,
                TelemetryCase.MODERATION_CHECK,
                TelemetryCase.SUGGESTED_QUESTION,
                TelemetryCase.DATASET_RETRIEVAL,
                TelemetryCase.GENERATE_NAME,
            ),
        ),
        # TRACE — enterprise-only
        (
            SignalType.TRACE,
            False,
            (
                TelemetryCase.NODE_EXECUTION,
                TelemetryCase.DRAFT_NODE_EXECUTION,
                TelemetryCase.PROMPT_GENERATION,
            ),
        ),
        # METRIC_LOG — enterprise-only (signal-driven, not trace)
        (
            SignalType.METRIC_LOG,
            False,
            (
                TelemetryCase.APP_CREATED,
                TelemetryCase.APP_UPDATED,
                TelemetryCase.APP_DELETED,
                TelemetryCase.FEEDBACK_CREATED,
            ),
        ),
    )

    routing: dict = {}
    for signal_type, ce_eligible, cases in route_groups:
        for case in cases:
            routing[case] = CaseRoute(signal_type=signal_type, ce_eligible=ce_eligible)

    _case_routing = routing
    return routing
|
||||
|
||||
# Public exports for tests and external consumers.
# NOTE: both tables are built right here, when this module is imported, so the
# function-local enterprise contract imports inside the builders run at import time.
CASE_TO_TRACE_TASK = _get_case_to_trace_task()
CASE_ROUTING = _get_case_routing()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def is_enterprise_telemetry_enabled() -> bool:
    """Report whether enterprise telemetry is enabled.

    Delegates to the enterprise exporter's function of the same name. Any
    exception raised while importing or calling it is treated as "disabled".
    """
    try:
        from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled as exporter_flag

        enabled = exporter_flag()
    except Exception:
        # Best-effort: import or call failures mean telemetry is off.
        enabled = False
    return enabled
|
||||
|
||||
|
||||
def _handle_payload_sizing(
    payload: dict[str, Any],
    tenant_id: str,
    event_id: str,
) -> tuple[dict[str, Any], str | None]:
    """Decide whether a payload travels inline or via object storage.

    The payload is JSON-serialized and measured in UTF-8 bytes. Payloads
    larger than ``PAYLOAD_SIZE_THRESHOLD_BYTES`` are saved under
    ``telemetry/{tenant_id}/{event_id}.json`` and replaced by an empty dict.

    Returns:
        ``(payload_for_envelope, storage_key)``. ``storage_key`` is ``None``
        whenever the payload stays inline, which includes the cases where
        serialization or the storage write fails.
    """
    try:
        encoded = json.dumps(payload).encode("utf-8")
    except (TypeError, ValueError):
        logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id)
        return payload, None

    size_in_bytes = len(encoded)
    if size_in_bytes > PAYLOAD_SIZE_THRESHOLD_BYTES:
        key = f"telemetry/{tenant_id}/{event_id}.json"
        try:
            storage.save(key, encoded)
            logger.debug("Stored large payload to storage: key=%s, size=%d", key, size_in_bytes)
            return {}, key
        except Exception:
            # Offloading is best-effort; fall through and keep the payload inline.
            logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True)

    return payload, None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def emit(
    case: TelemetryCase,
    context: dict[str, Any],
    payload: dict[str, Any],
    trace_manager: TraceQueueManager | None = None,
) -> None:
    """Dispatch a telemetry event according to the routing table.

    Cases missing from the routing table are logged and dropped. Cases that
    are not CE-eligible are dropped when enterprise telemetry is disabled.
    Remaining events go to the trace pipeline when the route's signal type
    is ``"trace"`` and to the metric/log dispatcher otherwise.
    """
    route = _get_case_routing().get(case)
    if route is None:
        logger.warning("Unknown telemetry case: %s, dropping event", case)
        return

    enterprise_only = not route.ce_eligible
    if enterprise_only and not is_enterprise_telemetry_enabled():
        logger.debug("Dropping EE-only event: case=%s (EE disabled)", case)
        return

    if route.signal_type == "trace":
        _emit_trace(case, context, payload, trace_manager)
        return
    _emit_metric_log(case, context, payload)
|
||||
|
||||
|
||||
def _emit_trace(
    case: TelemetryCase,
    context: dict[str, Any],
    payload: dict[str, Any],
    trace_manager: TraceQueueManager | None,
) -> None:
    """Enqueue a trace task for ``case`` on a trace queue manager.

    Uses the supplied ``trace_manager`` when one is given; otherwise a new
    ``TraceQueueManager`` is created from the context's ``app_id`` and
    ``user_id``. Cases without a ``TraceTaskName`` mapping are logged and
    skipped. The payload is expanded into the ``TraceTask`` keyword arguments.
    """
    from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager
    from core.ops.ops_trace_manager import TraceTask

    task_name = _get_case_to_trace_task().get(case)
    if task_name is None:
        logger.warning("No TraceTaskName mapping for case: %s", case)
        return

    if trace_manager:
        queue_manager = trace_manager
    else:
        queue_manager = LocalTraceQueueManager(
            app_id=context.get("app_id"),
            user_id=context.get("user_id"),
        )

    task = TraceTask(task_name, **payload)
    queue_manager.add_trace_task(task)
    logger.debug("Enqueued trace task: case=%s, app_id=%s", case, context.get("app_id"))
|
||||
|
||||
|
||||
def _emit_metric_log(
    case: TelemetryCase,
    context: dict[str, Any],
    payload: dict[str, Any],
) -> None:
    """Build envelope and dispatch to enterprise Celery queue.

    No-ops when the enterprise telemetry task is not importable (CE mode).

    The payload is passed through ``_handle_payload_sizing``; when it was
    offloaded, the storage key is recorded in the envelope metadata under
    ``payload_ref``.
    """
    try:
        from tasks.enterprise_telemetry_task import process_enterprise_telemetry
    except ImportError:
        logger.debug("Enterprise metric/log dispatch unavailable, dropping: case=%s", case)
        return

    # ``context`` may carry an explicit ``tenant_id: None`` (the telemetry
    # facade always sets the key), in which case ``dict.get``'s default is
    # not applied. Coalesce so the storage key and envelope get a string.
    tenant_id = context.get("tenant_id") or ""
    event_id = str(uuid.uuid4())

    payload_for_envelope, payload_ref = _handle_payload_sizing(payload, tenant_id, event_id)

    from enterprise.telemetry.contracts import TelemetryEnvelope

    envelope = TelemetryEnvelope(
        case=case,
        tenant_id=tenant_id,
        event_id=event_id,
        payload=payload_for_envelope,
        metadata={"payload_ref": payload_ref} if payload_ref else None,
    )

    process_enterprise_telemetry.delay(envelope.model_dump_json())
    logger.debug(
        "Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s",
        case,
        tenant_id,
        event_id,
    )
|
||||
@@ -50,6 +50,7 @@ class WorkflowTool(Tool):
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.label = label
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
self.parent_trace_context: dict[str, str] | None = None
|
||||
|
||||
super().__init__(entity=entity, runtime=runtime)
|
||||
|
||||
@@ -90,11 +91,15 @@ class WorkflowTool(Tool):
|
||||
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
|
||||
args: dict[str, Any] = {"inputs": tool_parameters, "files": files}
|
||||
if self.parent_trace_context:
|
||||
args["_parent_trace_context"] = self.parent_trace_context
|
||||
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args={"inputs": tool_parameters, "files": files},
|
||||
args=args,
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
streaming=False,
|
||||
call_depth=self.workflow_call_depth + 1,
|
||||
|
||||
@@ -112,7 +112,7 @@ class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
|
||||
|
||||
class RAGPipelineVariable(BaseModel):
|
||||
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
|
||||
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
|
||||
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
|
||||
label: str = Field(description="label")
|
||||
description: str | None = Field(description="description", default="")
|
||||
variable: str = Field(description="variable key", default="")
|
||||
|
||||
@@ -232,6 +232,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
PROMPT_TOKENS = "prompt_tokens"
|
||||
COMPLETION_TOKENS = "completion_tokens"
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
|
||||
@@ -322,6 +322,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS: usage.prompt_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS: usage.completion_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
@@ -15,13 +16,12 @@ if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
|
||||
|
||||
class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
_template_renderer: Jinja2TemplateRenderer
|
||||
_max_output_length: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -31,7 +31,6 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
max_output_length: int | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -41,10 +40,6 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
|
||||
if max_output_length is not None and max_output_length <= 0:
|
||||
raise ValueError("max_output_length must be a positive integer")
|
||||
self._max_output_length = max_output_length or DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
@@ -74,11 +69,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
except TemplateRenderError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
||||
if len(rendered) > self._max_output_length:
|
||||
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
return NodeRunResult(
|
||||
inputs=variables,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=f"Output length exceeds {self._max_output_length} characters",
|
||||
error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters",
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
|
||||
@@ -60,7 +60,9 @@ class ToolNode(Node[ToolNodeData]):
|
||||
tool_info = {
|
||||
"provider_type": self.node_data.provider_type.value,
|
||||
"provider_id": self.node_data.provider_id,
|
||||
"tool_name": self.node_data.tool_name,
|
||||
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
|
||||
"credential_id": self.node_data.credential_id,
|
||||
}
|
||||
|
||||
# get tool runtime
|
||||
@@ -105,6 +107,20 @@ class ToolNode(Node[ToolNodeData]):
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
|
||||
if isinstance(tool_runtime, WorkflowTool):
|
||||
workflow_run_id_var = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID]
|
||||
)
|
||||
tool_runtime.parent_trace_context = {
|
||||
"trace_id": str(workflow_run_id_var.text) if workflow_run_id_var else "",
|
||||
"parent_node_execution_id": self.execution_id,
|
||||
"parent_workflow_run_id": str(workflow_run_id_var.text) if workflow_run_id_var else "",
|
||||
"parent_app_id": self.app_id,
|
||||
"parent_conversation_id": conversation_id.text if conversation_id else None,
|
||||
}
|
||||
|
||||
try:
|
||||
message_stream = ToolEngine.generic_invoke(
|
||||
tool=tool_runtime,
|
||||
@@ -431,6 +447,8 @@ class ToolNode(Node[ToolNodeData]):
|
||||
}
|
||||
if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
|
||||
metadata[WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS] = usage.prompt_tokens
|
||||
metadata[WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS] = usage.completion_tokens
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
|
||||
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
|
||||
|
||||
|
||||
@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
|
||||
if [[ -z "${CELERY_QUEUES}" ]]; then
|
||||
if [[ "${EDITION}" == "CLOUD" ]]; then
|
||||
# Cloud edition: separate queues for dataset and trigger tasks
|
||||
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
else
|
||||
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
|
||||
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
fi
|
||||
else
|
||||
DEFAULT_QUEUES="${CELERY_QUEUES}"
|
||||
|
||||
0
api/enterprise/__init__.py
Normal file
0
api/enterprise/__init__.py
Normal file
522
api/enterprise/telemetry/DATA_DICTIONARY.md
Normal file
522
api/enterprise/telemetry/DATA_DICTIONARY.md
Normal file
@@ -0,0 +1,522 @@
|
||||
# Dify Enterprise Telemetry Data Dictionary
|
||||
|
||||
Quick reference for all telemetry signals emitted by Dify Enterprise. For configuration and architecture details, see [README.md](./README.md).
|
||||
|
||||
## Resource Attributes
|
||||
|
||||
Attached to every signal (Span, Metric, Log).
|
||||
|
||||
| Attribute | Type | Example |
|
||||
|-----------|------|---------|
|
||||
| `service.name` | string | `dify` |
|
||||
| `host.name` | string | `dify-api-7f8b` |
|
||||
|
||||
## Traces (Spans)
|
||||
|
||||
### `dify.workflow.run`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.trace_id` | string | Business trace ID (Workflow Run ID) |
|
||||
| `dify.tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.workflow.id` | string | Workflow definition ID |
|
||||
| `dify.workflow.run_id` | string | Unique ID for this run |
|
||||
| `dify.workflow.status` | string | `succeeded`, `failed`, `stopped`, etc. |
|
||||
| `dify.workflow.error` | string | Error message if failed |
|
||||
| `dify.workflow.elapsed_time` | float | Total execution time (seconds) |
|
||||
| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` |
|
||||
| `dify.conversation.id` | string | Conversation ID (optional) |
|
||||
| `dify.message.id` | string | Message ID (optional) |
|
||||
| `dify.invoked_by` | string | User ID who triggered the run |
|
||||
| `gen_ai.usage.total_tokens` | int | Total tokens across all nodes (optional) |
|
||||
| `gen_ai.user.id` | string | End-user identifier (optional) |
|
||||
| `dify.parent.trace_id` | string | Parent workflow trace ID (optional) |
|
||||
| `dify.parent.workflow.run_id` | string | Parent workflow run ID (optional) |
|
||||
| `dify.parent.node.execution_id` | string | Parent node execution ID (optional) |
|
||||
| `dify.parent.app.id` | string | Parent app ID (optional) |
|
||||
|
||||
### `dify.node.execution`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.trace_id` | string | Business trace ID |
|
||||
| `dify.tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.workflow.id` | string | Workflow definition ID |
|
||||
| `dify.workflow.run_id` | string | Workflow Run ID |
|
||||
| `dify.message.id` | string | Message ID (optional) |
|
||||
| `dify.conversation.id` | string | Conversation ID (optional) |
|
||||
| `dify.node.execution_id` | string | Unique node execution ID |
|
||||
| `dify.node.id` | string | Node ID in workflow graph |
|
||||
| `dify.node.type` | string | Node type (see appendix) |
|
||||
| `dify.node.title` | string | Display title |
|
||||
| `dify.node.status` | string | `succeeded`, `failed` |
|
||||
| `dify.node.error` | string | Error message if failed |
|
||||
| `dify.node.elapsed_time` | float | Execution time (seconds) |
|
||||
| `dify.node.index` | int | Execution order index |
|
||||
| `dify.node.predecessor_node_id` | string | Triggering node ID |
|
||||
| `dify.node.iteration_id` | string | Iteration ID (optional) |
|
||||
| `dify.node.loop_id` | string | Loop ID (optional) |
|
||||
| `dify.node.parallel_id` | string | Parallel branch ID (optional) |
|
||||
| `dify.node.invoked_by` | string | User ID who triggered execution |
|
||||
| `gen_ai.usage.input_tokens` | int | Prompt tokens (LLM nodes only) |
|
||||
| `gen_ai.usage.output_tokens` | int | Completion tokens (LLM nodes only) |
|
||||
| `gen_ai.usage.total_tokens` | int | Total tokens (LLM nodes only) |
|
||||
| `gen_ai.request.model` | string | LLM model name (LLM nodes only) |
|
||||
| `gen_ai.provider.name` | string | LLM provider name (LLM nodes only) |
|
||||
| `gen_ai.user.id` | string | End-user identifier (optional) |
|
||||
|
||||
### `dify.node.execution.draft`
|
||||
|
||||
Same attributes as `dify.node.execution`. Emitted during Preview/Debug runs.
|
||||
|
||||
## Counters
|
||||
|
||||
All counters are cumulative and emitted at 100% accuracy.
|
||||
|
||||
### Token Counters
|
||||
|
||||
| Metric | Unit | Description |
|
||||
|--------|------|-------------|
|
||||
| `dify.tokens.total` | `{token}` | Total tokens consumed |
|
||||
| `dify.tokens.input` | `{token}` | Input (prompt) tokens |
|
||||
| `dify.tokens.output` | `{token}` | Output (completion) tokens |
|
||||
|
||||
**Labels:**
|
||||
- `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type` (if node_execution)
|
||||
|
||||
⚠️ **Warning:** `dify.tokens.total` at workflow level includes all node tokens. Filter by `operation_type` to avoid double-counting.
|
||||
|
||||
#### Token Hierarchy & Query Patterns
|
||||
|
||||
Token metrics are emitted at multiple layers. Understanding the hierarchy prevents double-counting:
|
||||
|
||||
```
|
||||
App-level total
|
||||
├── workflow ← sum of all node_execution tokens (DO NOT add both)
|
||||
│ └── node_execution ← per-node breakdown
|
||||
├── message ← independent (non-workflow chat apps only)
|
||||
├── rule_generate ← independent helper LLM call
|
||||
├── code_generate ← independent helper LLM call
|
||||
├── structured_output ← independent helper LLM call
|
||||
└── instruction_modify ← independent helper LLM call
|
||||
```
|
||||
|
||||
**Key rule:** `workflow` tokens already include all `node_execution` tokens. Never sum both.
|
||||
|
||||
**Available labels on token metrics:** `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type`.
|
||||
App name is only available on span attributes (`dify.app.name`), not metric labels — use `app_id` for metric queries.
|
||||
|
||||
**Common queries** (PromQL):
|
||||
|
||||
```promql
|
||||
# ── Totals ──────────────────────────────────────────────────
|
||||
# App-level total (exclude node_execution to avoid double-counting)
|
||||
sum by (app_id) (dify_tokens_total{operation_type!="node_execution"})
|
||||
|
||||
# Single app total
|
||||
sum (dify_tokens_total{app_id="<app_id>", operation_type!="node_execution"})
|
||||
|
||||
# Per-tenant totals
|
||||
sum by (tenant_id) (dify_tokens_total{operation_type!="node_execution"})
|
||||
|
||||
# ── Drill-down ──────────────────────────────────────────────
|
||||
# Workflow-level tokens for an app
|
||||
sum (dify_tokens_total{app_id="<app_id>", operation_type="workflow"})
|
||||
|
||||
# Node-level breakdown within an app
|
||||
sum by (node_type) (dify_tokens_total{app_id="<app_id>", operation_type="node_execution"})
|
||||
|
||||
# Model breakdown for an app
|
||||
sum by (model_provider, model_name) (dify_tokens_total{app_id="<app_id>"})
|
||||
|
||||
# Input vs output per model
|
||||
sum by (model_name) (dify_tokens_input_total{app_id="<app_id>"})
|
||||
sum by (model_name) (dify_tokens_output_total{app_id="<app_id>"})
|
||||
|
||||
# ── Rates ───────────────────────────────────────────────────
|
||||
# Token consumption rate (per hour)
|
||||
sum(rate(dify_tokens_total{operation_type!="node_execution"}[1h]))
|
||||
|
||||
# Per-app consumption rate
|
||||
sum by (app_id) (rate(dify_tokens_total{operation_type!="node_execution"}[1h]))
|
||||
```
|
||||
|
||||
**Finding `app_id` from app name** (trace query — Tempo / Jaeger):
|
||||
|
||||
```
|
||||
{ span.dify.app.name = "My Chatbot" } | select(span.dify.app_id)
|
||||
```
|
||||
|
||||
### Request Counters
|
||||
|
||||
| Metric | Unit | Description |
|
||||
|--------|------|-------------|
|
||||
| `dify.requests.total` | `{request}` | Total operations count |
|
||||
|
||||
**Labels by type:**
|
||||
|
||||
| `type` | Additional Labels |
|
||||
|--------|-------------------|
|
||||
| `workflow` | `tenant_id`, `app_id`, `status`, `invoke_from` |
|
||||
| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` |
|
||||
| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` |
|
||||
| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name`, `status`, `invoke_from` |
|
||||
| `tool` | `tenant_id`, `app_id`, `tool_name` |
|
||||
| `moderation` | `tenant_id`, `app_id` |
|
||||
| `suggested_question` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
|
||||
| `dataset_retrieval` | `tenant_id`, `app_id` |
|
||||
| `generate_name` | `tenant_id`, `app_id` |
|
||||
| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `status` |
|
||||
|
||||
### Error Counters
|
||||
|
||||
| Metric | Unit | Description |
|
||||
|--------|------|-------------|
|
||||
| `dify.errors.total` | `{error}` | Total failed operations |
|
||||
|
||||
**Labels by type:**
|
||||
|
||||
| `type` | Additional Labels |
|
||||
|--------|-------------------|
|
||||
| `workflow` | `tenant_id`, `app_id` |
|
||||
| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` |
|
||||
| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` |
|
||||
| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
|
||||
| `tool` | `tenant_id`, `app_id`, `tool_name` |
|
||||
| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` |
|
||||
|
||||
### Other Counters
|
||||
|
||||
| Metric | Unit | Labels |
|
||||
|--------|------|--------|
|
||||
| `dify.feedback.total` | `{feedback}` | `tenant_id`, `app_id`, `rating` |
|
||||
| `dify.dataset.retrievals.total` | `{retrieval}` | `tenant_id`, `app_id`, `dataset_id`, `embedding_model_provider`, `embedding_model`, `rerank_model_provider`, `rerank_model` |
|
||||
| `dify.app.created.total` | `{app}` | `tenant_id`, `app_id`, `mode` |
|
||||
| `dify.app.updated.total` | `{app}` | `tenant_id`, `app_id` |
|
||||
| `dify.app.deleted.total` | `{app}` | `tenant_id`, `app_id` |
|
||||
|
||||
## Histograms
|
||||
|
||||
| Metric | Unit | Labels |
|
||||
|--------|------|--------|
|
||||
| `dify.workflow.duration` | `s` | `tenant_id`, `app_id`, `status` |
|
||||
| `dify.node.duration` | `s` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `plugin_name` |
|
||||
| `dify.message.duration` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
|
||||
| `dify.message.time_to_first_token` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
|
||||
| `dify.tool.duration` | `s` | `tenant_id`, `app_id`, `tool_name` |
|
||||
| `dify.prompt_generation.duration` | `s` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` |
|
||||
|
||||
## Structured Logs
|
||||
|
||||
### Span Companion Logs
|
||||
|
||||
Logs that accompany spans. Signal type: `span_detail`
|
||||
|
||||
#### `dify.workflow.run` Companion Log
|
||||
|
||||
**Common attributes:** All span attributes (see Traces section) plus:
|
||||
|
||||
| Additional Attribute | Type | Always Present | Description |
|
||||
|---------------------|------|----------------|-------------|
|
||||
| `dify.app.name` | string | No | Application display name |
|
||||
| `dify.workspace.name` | string | No | Workspace display name |
|
||||
| `dify.workflow.version` | string | Yes | Workflow definition version |
|
||||
| `dify.workflow.inputs` | string/JSON | Yes | Input parameters (content-gated) |
|
||||
| `dify.workflow.outputs` | string/JSON | Yes | Output results (content-gated) |
|
||||
| `dify.workflow.query` | string | No | User query text (content-gated) |
|
||||
|
||||
**Event attributes:**
|
||||
- `dify.event.name`: `"dify.workflow.run"`
|
||||
- `dify.event.signal`: `"span_detail"`
|
||||
- `trace_id`, `span_id`, `tenant_id`, `user_id`
|
||||
|
||||
#### `dify.node.execution` and `dify.node.execution.draft` Companion Logs
|
||||
|
||||
**Common attributes:** All span attributes (see Traces section) plus:
|
||||
|
||||
| Additional Attribute | Type | Always Present | Description |
|
||||
|---------------------|------|----------------|-------------|
|
||||
| `dify.app.name` | string | No | Application display name |
|
||||
| `dify.workspace.name` | string | No | Workspace display name |
|
||||
| `dify.invoke_from` | string | No | Invocation source |
|
||||
| `gen_ai.tool.name` | string | No | Tool name (tool nodes only) |
|
||||
| `dify.node.total_price` | float | No | Cost (LLM nodes only) |
|
||||
| `dify.node.currency` | string | No | Currency code (LLM nodes only) |
|
||||
| `dify.node.iteration_index` | int | No | Iteration index (iteration nodes) |
|
||||
| `dify.node.loop_index` | int | No | Loop index (loop nodes) |
|
||||
| `dify.plugin.name` | string | No | Plugin name (tool/knowledge nodes) |
|
||||
| `dify.credential.name` | string | No | Credential name (plugin nodes) |
|
||||
| `dify.credential.id` | string | No | Credential ID (plugin nodes) |
|
||||
| `dify.dataset.ids` | JSON array | No | Dataset IDs (knowledge nodes) |
|
||||
| `dify.dataset.names` | JSON array | No | Dataset names (knowledge nodes) |
|
||||
| `dify.node.inputs` | string/JSON | Yes | Node inputs (content-gated) |
|
||||
| `dify.node.outputs` | string/JSON | Yes | Node outputs (content-gated) |
|
||||
| `dify.node.process_data` | string/JSON | No | Processing data (content-gated) |
|
||||
|
||||
**Event attributes:**
|
||||
- `dify.event.name`: `"dify.node.execution"` or `"dify.node.execution.draft"`
|
||||
- `dify.event.signal`: `"span_detail"`
|
||||
- `trace_id`, `span_id`, `tenant_id`, `user_id`
|
||||
|
||||
### Standalone Logs
|
||||
|
||||
Logs without structural spans. Signal type: `metric_only`
|
||||
|
||||
#### `dify.message.run`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.message.run"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID (32-char hex) |
|
||||
| `span_id` | string | OTEL span ID (16-char hex) |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `user_id` | string | User identifier (optional) |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.message.id` | string | Message identifier |
|
||||
| `dify.conversation.id` | string | Conversation ID (optional) |
|
||||
| `dify.workflow.run_id` | string | Workflow run ID (optional) |
|
||||
| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` |
|
||||
| `gen_ai.provider.name` | string | LLM provider |
|
||||
| `gen_ai.request.model` | string | LLM model |
|
||||
| `gen_ai.usage.input_tokens` | int | Input tokens |
|
||||
| `gen_ai.usage.output_tokens` | int | Output tokens |
|
||||
| `gen_ai.usage.total_tokens` | int | Total tokens |
|
||||
| `dify.message.status` | string | `succeeded`, `failed` |
|
||||
| `dify.message.error` | string | Error message (if failed) |
|
||||
| `dify.message.duration` | float | Duration (seconds) |
|
||||
| `dify.message.time_to_first_token` | float | TTFT (seconds) |
|
||||
| `dify.message.inputs` | string/JSON | Inputs (content-gated) |
|
||||
| `dify.message.outputs` | string/JSON | Outputs (content-gated) |
|
||||
|
||||
#### `dify.tool.execution`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.tool.execution"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.message.id` | string | Message identifier |
|
||||
| `dify.tool.name` | string | Tool name |
|
||||
| `dify.tool.duration` | float | Duration (seconds) |
|
||||
| `dify.tool.status` | string | `succeeded`, `failed` |
|
||||
| `dify.tool.error` | string | Error message (if failed) |
|
||||
| `dify.tool.inputs` | string/JSON | Inputs (content-gated) |
|
||||
| `dify.tool.outputs` | string/JSON | Outputs (content-gated) |
|
||||
| `dify.tool.parameters` | string/JSON | Parameters (content-gated) |
|
||||
| `dify.tool.config` | string/JSON | Configuration (content-gated) |
|
||||
|
||||
#### `dify.moderation.check`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.moderation.check"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.message.id` | string | Message identifier |
|
||||
| `dify.moderation.type` | string | `input`, `output` |
|
||||
| `dify.moderation.action` | string | `pass`, `block`, `flag` |
|
||||
| `dify.moderation.flagged` | boolean | Whether flagged |
|
||||
| `dify.moderation.categories` | JSON array | Flagged categories |
|
||||
| `dify.moderation.query` | string | Content (content-gated) |
|
||||
|
||||
#### `dify.suggested_question.generation`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.suggested_question.generation"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.message.id` | string | Message identifier |
|
||||
| `dify.suggested_question.count` | int | Number of questions |
|
||||
| `dify.suggested_question.duration` | float | Duration (seconds) |
|
||||
| `dify.suggested_question.status` | string | `succeeded`, `failed` |
|
||||
| `dify.suggested_question.error` | string | Error message (if failed) |
|
||||
| `dify.suggested_question.questions` | JSON array | Questions (content-gated) |
|
||||
|
||||
#### `dify.dataset.retrieval`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.dataset.retrieval"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.message.id` | string | Message identifier |
|
||||
| `dify.dataset.id` | string | Dataset identifier |
|
||||
| `dify.dataset.name` | string | Dataset name |
|
||||
| `dify.dataset.embedding_providers` | JSON array | Embedding model providers (one per dataset) |
|
||||
| `dify.dataset.embedding_models` | JSON array | Embedding models (one per dataset) |
|
||||
| `dify.retrieval.rerank_provider` | string | Rerank model provider |
|
||||
| `dify.retrieval.rerank_model` | string | Rerank model name |
|
||||
| `dify.retrieval.query` | string | Search query (content-gated) |
|
||||
| `dify.retrieval.document_count` | int | Documents retrieved |
|
||||
| `dify.retrieval.duration` | float | Duration (seconds) |
|
||||
| `dify.retrieval.status` | string | `succeeded`, `failed` |
|
||||
| `dify.retrieval.error` | string | Error message (if failed) |
|
||||
| `dify.dataset.documents` | JSON array | Documents (content-gated) |
|
||||
|
||||
#### `dify.generate_name.execution`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.generate_name.execution"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.conversation.id` | string | Conversation identifier |
|
||||
| `dify.generate_name.duration` | float | Duration (seconds) |
|
||||
| `dify.generate_name.status` | string | `succeeded`, `failed` |
|
||||
| `dify.generate_name.error` | string | Error message (if failed) |
|
||||
| `dify.generate_name.inputs` | string/JSON | Inputs (content-gated) |
|
||||
| `dify.generate_name.outputs` | string | Generated name (content-gated) |
|
||||
|
||||
#### `dify.prompt_generation.execution`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.prompt_generation.execution"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.prompt_generation.operation_type` | string | Operation type (see appendix) |
|
||||
| `gen_ai.provider.name` | string | LLM provider |
|
||||
| `gen_ai.request.model` | string | LLM model |
|
||||
| `gen_ai.usage.input_tokens` | int | Input tokens |
|
||||
| `gen_ai.usage.output_tokens` | int | Output tokens |
|
||||
| `gen_ai.usage.total_tokens` | int | Total tokens |
|
||||
| `dify.prompt_generation.duration` | float | Duration (seconds) |
|
||||
| `dify.prompt_generation.status` | string | `succeeded`, `failed` |
|
||||
| `dify.prompt_generation.error` | string | Error message (if failed) |
|
||||
| `dify.prompt_generation.instruction` | string | Instruction (content-gated) |
|
||||
| `dify.prompt_generation.output` | string/JSON | Output (content-gated) |
|
||||
|
||||
#### `dify.app.created`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.app.created"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.app.mode` | string | `chat`, `completion`, `agent-chat`, `workflow` |
|
||||
| `dify.app.created_at` | string | Timestamp (ISO 8601) |
|
||||
|
||||
#### `dify.app.updated`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.app.updated"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.app.updated_at` | string | Timestamp (ISO 8601) |
|
||||
|
||||
#### `dify.app.deleted`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.app.deleted"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.app.deleted_at` | string | Timestamp (ISO 8601) |
|
||||
|
||||
#### `dify.feedback.created`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.feedback.created"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.message.id` | string | Message identifier |
|
||||
| `dify.feedback.rating` | string | `like`, `dislike`, `null` |
|
||||
| `dify.feedback.content` | string | Feedback text (content-gated) |
|
||||
| `dify.feedback.created_at` | string | Timestamp (ISO 8601) |
|
||||
|
||||
#### `dify.telemetry.rehydration_failed`
|
||||
|
||||
Diagnostic event for telemetry system health monitoring.
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.telemetry.rehydration_failed"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.telemetry.error` | string | Error message |
|
||||
| `dify.telemetry.payload_type` | string | Payload type (see appendix) |
|
||||
| `dify.telemetry.correlation_id` | string | Correlation ID |
|
||||
|
||||
## Content-Gated Attributes
|
||||
|
||||
When `ENTERPRISE_INCLUDE_CONTENT=false`, these attributes are replaced with reference strings (`ref:{id_type}={uuid}`).
|
||||
|
||||
| Attribute | Signal |
|
||||
|-----------|--------|
|
||||
| `dify.workflow.inputs` | `dify.workflow.run` |
|
||||
| `dify.workflow.outputs` | `dify.workflow.run` |
|
||||
| `dify.workflow.query` | `dify.workflow.run` |
|
||||
| `dify.node.inputs` | `dify.node.execution` |
|
||||
| `dify.node.outputs` | `dify.node.execution` |
|
||||
| `dify.node.process_data` | `dify.node.execution` |
|
||||
| `dify.message.inputs` | `dify.message.run` |
|
||||
| `dify.message.outputs` | `dify.message.run` |
|
||||
| `dify.tool.inputs` | `dify.tool.execution` |
|
||||
| `dify.tool.outputs` | `dify.tool.execution` |
|
||||
| `dify.tool.parameters` | `dify.tool.execution` |
|
||||
| `dify.tool.config` | `dify.tool.execution` |
|
||||
| `dify.moderation.query` | `dify.moderation.check` |
|
||||
| `dify.suggested_question.questions` | `dify.suggested_question.generation` |
|
||||
| `dify.retrieval.query` | `dify.dataset.retrieval` |
|
||||
| `dify.dataset.documents` | `dify.dataset.retrieval` |
|
||||
| `dify.generate_name.inputs` | `dify.generate_name.execution` |
|
||||
| `dify.generate_name.outputs` | `dify.generate_name.execution` |
|
||||
| `dify.prompt_generation.instruction` | `dify.prompt_generation.execution` |
|
||||
| `dify.prompt_generation.output` | `dify.prompt_generation.execution` |
|
||||
| `dify.feedback.content` | `dify.feedback.created` |
|
||||
|
||||
## Appendix
|
||||
|
||||
### Operation Types
|
||||
|
||||
- `workflow`, `node_execution`, `message`, `rule_generate`, `code_generate`, `structured_output`, `instruction_modify`
|
||||
|
||||
### Node Types
|
||||
|
||||
- `start`, `end`, `answer`, `llm`, `knowledge-retrieval`, `knowledge-index`, `if-else`, `code`, `template-transform`, `question-classifier`, `http-request`, `tool`, `datasource`, `variable-aggregator`, `loop`, `iteration`, `parameter-extractor`, `assigner`, `document-extractor`, `list-operator`, `agent`, `trigger-webhook`, `trigger-schedule`, `trigger-plugin`, `human-input`
|
||||
|
||||
### Workflow Statuses
|
||||
|
||||
- `running`, `succeeded`, `failed`, `stopped`, `partial-succeeded`, `paused`
|
||||
|
||||
### Payload Types
|
||||
|
||||
- `workflow`, `node`, `message`, `tool`, `moderation`, `suggested_question`, `dataset_retrieval`, `generate_name`, `prompt_generation`, `app`, `feedback`
|
||||
|
||||
### Null Value Behavior
|
||||
|
||||
**Spans:** Attributes with `null` values are omitted.
|
||||
|
||||
**Logs:** Attributes with `null` values appear as `null` in JSON.
|
||||
|
||||
**Content-Gated:** Replaced with reference strings, not set to `null`.
|
||||
116
api/enterprise/telemetry/README.md
Normal file
116
api/enterprise/telemetry/README.md
Normal file
@@ -0,0 +1,116 @@
|
||||
# Dify Enterprise Telemetry
|
||||
|
||||
This document provides an overview of the Dify Enterprise OpenTelemetry (OTEL) exporter and how to configure it for integration with observability stacks like Prometheus, Grafana, Jaeger, or Honeycomb.
|
||||
|
||||
## Overview
|
||||
|
||||
Dify Enterprise uses a "slim span + rich companion log" architecture to provide high-fidelity observability without overwhelming trace storage.
|
||||
|
||||
- **Traces (Spans)**: Capture the structure, identity, and timing of high-level operations (Workflows and Nodes).
|
||||
- **Structured Logs**: Provide deep context (inputs, outputs, metadata) for every event, correlated to spans via `trace_id` and `span_id`.
|
||||
- **Metrics**: Provide 100% accurate counters and histograms for usage, performance, and error tracking.
|
||||
|
||||
### Signal Architecture
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Workflow Run] -->|Span| B(dify.workflow.run)
|
||||
A -->|Log| C(dify.workflow.run detail)
|
||||
B ---|trace_id| C
|
||||
|
||||
D[Node Execution] -->|Span| E(dify.node.execution)
|
||||
D -->|Log| F(dify.node.execution detail)
|
||||
E ---|span_id| F
|
||||
|
||||
G[Message/Tool/etc] -->|Log| H(dify.* event)
|
||||
G -->|Metric| I(dify.* counter/histogram)
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
The Enterprise OTEL exporter is configured via environment variables.
|
||||
|
||||
| Variable | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| `ENTERPRISE_ENABLED` | Master switch for all enterprise features. | `false` |
|
||||
| `ENTERPRISE_TELEMETRY_ENABLED` | Master switch for enterprise telemetry. | `false` |
|
||||
| `ENTERPRISE_OTLP_ENDPOINT` | OTLP collector endpoint (e.g., `http://otel-collector:4318`). | - |
|
||||
| `ENTERPRISE_OTLP_HEADERS` | Custom headers for OTLP requests (e.g., `x-scope-orgid=tenant1`). | - |
|
||||
| `ENTERPRISE_OTLP_PROTOCOL` | OTLP transport protocol (`http` or `grpc`). | `http` |
|
||||
| `ENTERPRISE_OTLP_API_KEY` | Bearer token for authentication. | - |
|
||||
| `ENTERPRISE_INCLUDE_CONTENT` | Whether to include sensitive content (inputs/outputs) in logs. | `true` |
|
||||
| `ENTERPRISE_SERVICE_NAME` | Service name reported to OTEL. | `dify` |
|
||||
| `ENTERPRISE_OTEL_SAMPLING_RATE` | Sampling rate for traces (0.0 to 1.0). Metrics are always 100%. | `1.0` |
|
||||
|
||||
## Correlation Model
|
||||
|
||||
Dify uses deterministic ID generation to ensure signals are correlated across different services and asynchronous tasks.
|
||||
|
||||
### ID Generation Rules
|
||||
- `trace_id`: Derived from the correlation ID (workflow_run_id or node_execution_id for drafts) using `int(UUID(correlation_id))`
|
||||
- `span_id`: Derived from the source ID using `SHA256(source_id)[:8]`
|
||||
|
||||
### Scenario A: Simple Workflow
|
||||
A single workflow run with multiple nodes. All spans and logs share the same `trace_id` (derived from `workflow_run_id`).
|
||||
|
||||
```
|
||||
trace_id = UUID(workflow_run_id)
|
||||
├── [root span] dify.workflow.run (span_id = hash(workflow_run_id))
|
||||
│ ├── [child] dify.node.execution - "Start" (span_id = hash(node_exec_id_1))
|
||||
│ ├── [child] dify.node.execution - "LLM" (span_id = hash(node_exec_id_2))
|
||||
│ └── [child] dify.node.execution - "End" (span_id = hash(node_exec_id_3))
|
||||
```
|
||||
|
||||
### Scenario B: Nested Sub-Workflow
|
||||
A workflow calling another workflow via a Tool or Sub-workflow node. The child workflow's spans are linked to the parent via `parent_span_id`. Both workflows share the same `trace_id`.
|
||||
|
||||
```
|
||||
trace_id = UUID(outer_workflow_run_id) ← shared across both workflows
|
||||
├── [root] dify.workflow.run (outer) (span_id = hash(outer_workflow_run_id))
|
||||
│ ├── dify.node.execution - "Start Node"
|
||||
│ ├── dify.node.execution - "Tool Node" (triggers sub-workflow)
|
||||
│ │ └── [child] dify.workflow.run (inner) (span_id = hash(inner_workflow_run_id))
|
||||
│ │ ├── dify.node.execution - "Inner Start"
|
||||
│ │ └── dify.node.execution - "Inner End"
|
||||
│ └── dify.node.execution - "End Node"
|
||||
```
|
||||
|
||||
**Key attributes for nested workflows:**
|
||||
- Inner workflow's `dify.parent.trace_id` = outer `workflow_run_id`
|
||||
- Inner workflow's `dify.parent.node.execution_id` = tool node's `execution_id`
|
||||
- Inner workflow's `dify.parent.workflow.run_id` = outer `workflow_run_id`
|
||||
- Inner workflow's `dify.parent.app.id` = outer `app_id`
|
||||
|
||||
### Scenario C: Draft Node Execution
|
||||
A single node run in isolation (debugger/preview mode). It creates its own trace where the node span is the root.
|
||||
|
||||
```
|
||||
trace_id = UUID(node_execution_id) ← own trace, NOT part of any workflow
|
||||
└── dify.node.execution.draft (span_id = hash(node_execution_id))
|
||||
```
|
||||
|
||||
**Key difference:** Draft executions use `node_execution_id` as the `correlation_id`, so they are NOT children of any workflow trace.
|
||||
|
||||
## Content Gating
|
||||
|
||||
When `ENTERPRISE_INCLUDE_CONTENT` is set to `false`, sensitive content attributes (inputs, outputs, queries) are replaced with reference strings (e.g., `ref:workflow_run_id=...`) to prevent data leakage to the OTEL collector.
|
||||
|
||||
**Reference String Format:**
|
||||
|
||||
```
|
||||
ref:{id_type}={uuid}
|
||||
```
|
||||
|
||||
**Examples:**
|
||||
|
||||
```
|
||||
ref:workflow_run_id=550e8400-e29b-41d4-a716-446655440000
|
||||
ref:node_execution_id=660e8400-e29b-41d4-a716-446655440001
|
||||
ref:message_id=770e8400-e29b-41d4-a716-446655440002
|
||||
```
|
||||
|
||||
To retrieve actual content when gating is enabled, query the Dify database using the provided UUID.
|
||||
|
||||
## Reference
|
||||
|
||||
For a complete list of telemetry signals, attributes, and data structures, see [DATA_DICTIONARY.md](./DATA_DICTIONARY.md).
|
||||
0
api/enterprise/telemetry/__init__.py
Normal file
0
api/enterprise/telemetry/__init__.py
Normal file
73
api/enterprise/telemetry/contracts.py
Normal file
73
api/enterprise/telemetry/contracts.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Telemetry gateway contracts and data structures.
|
||||
|
||||
This module defines the envelope format for telemetry events and the routing
|
||||
configuration that determines how each event type is processed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class TelemetryCase(StrEnum):
|
||||
"""Enumeration of all known telemetry event cases."""
|
||||
|
||||
WORKFLOW_RUN = "workflow_run"
|
||||
NODE_EXECUTION = "node_execution"
|
||||
DRAFT_NODE_EXECUTION = "draft_node_execution"
|
||||
MESSAGE_RUN = "message_run"
|
||||
TOOL_EXECUTION = "tool_execution"
|
||||
MODERATION_CHECK = "moderation_check"
|
||||
SUGGESTED_QUESTION = "suggested_question"
|
||||
DATASET_RETRIEVAL = "dataset_retrieval"
|
||||
GENERATE_NAME = "generate_name"
|
||||
PROMPT_GENERATION = "prompt_generation"
|
||||
APP_CREATED = "app_created"
|
||||
APP_UPDATED = "app_updated"
|
||||
APP_DELETED = "app_deleted"
|
||||
FEEDBACK_CREATED = "feedback_created"
|
||||
|
||||
|
||||
class SignalType(StrEnum):
|
||||
"""Signal routing type for telemetry cases."""
|
||||
|
||||
TRACE = "trace"
|
||||
METRIC_LOG = "metric_log"
|
||||
|
||||
|
||||
class CaseRoute(BaseModel):
    """Describe how one telemetry case should be routed.

    Both fields are required; there are no defaults.
    """

    # Which signal family the case belongs to (trace or metric_log).
    signal_type: SignalType
    # True when the case may also be handled by community-edition tracing.
    ce_eligible: bool
|
||||
|
||||
|
||||
class TelemetryEnvelope(BaseModel):
    """Wrapper carrying a single telemetry event.

    Fields:
        case: Which telemetry case this event represents.
        tenant_id: Identifier of the tenant the event belongs to.
        event_id: Unique identifier of the event, used for deduplication.
        payload: Event body. Carried inline for small payloads; left empty
            when the body has been offloaded to storage (see ``metadata``).
        metadata: Optional side-channel dictionary. When the gateway
            offloads a large payload to object storage it holds
            ``{"payload_ref": "<storage_key>"}``.
    """

    # Unknown fields are rejected, and ``case`` is kept as an enum member
    # rather than being collapsed to its raw string value.
    model_config = ConfigDict(extra="forbid", use_enum_values=False)

    case: TelemetryCase
    tenant_id: str
    event_id: str
    payload: dict[str, Any]
    metadata: dict[str, Any] | None = None
|
||||
77
api/enterprise/telemetry/draft_trace.py
Normal file
77
api/enterprise/telemetry/draft_trace.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
|
||||
def enqueue_draft_node_execution_trace(
    *,
    execution: WorkflowNodeExecutionModel,
    outputs: Mapping[str, Any] | None,
    workflow_execution_id: str | None,
    user_id: str,
) -> None:
    """Emit a draft-node-execution telemetry event for ``execution``.

    Args:
        execution: Persisted node execution record to report.
        outputs: Outputs to report; ``None`` falls back to the outputs
            stored on ``execution``.
        workflow_execution_id: Preferred workflow execution identifier;
            ``None`` falls back to identifiers stored on ``execution``.
        user_id: Identifier of the user placed in the telemetry context.
    """
    # Flatten the execution record first; it becomes the event payload.
    event_payload = {
        "node_execution_data": _build_node_execution_data(
            execution=execution,
            outputs=outputs,
            workflow_execution_id=workflow_execution_id,
        )
    }
    event_context = TelemetryContext(
        tenant_id=execution.tenant_id,
        user_id=user_id,
        app_id=execution.app_id,
    )
    draft_event = TelemetryEvent(
        name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
        context=event_context,
        payload=event_payload,
    )
    telemetry_emit(draft_event)
|
||||
|
||||
|
||||
def _build_node_execution_data(
    *,
    execution: WorkflowNodeExecutionModel,
    outputs: Mapping[str, Any] | None,
    workflow_execution_id: str | None,
) -> dict[str, Any]:
    """Flatten a node execution record into a plain telemetry payload dict.

    Args:
        execution: Node execution record to read from.
        outputs: Explicit outputs; when ``None`` the record's
            ``outputs_dict`` is used instead.
        workflow_execution_id: Preferred workflow execution identifier.

    Returns:
        A dictionary of identity, status, timing, token-usage, loop and
        iteration fields plus the node's inputs, outputs and process data.
    """
    meta = execution.execution_metadata_dict
    keys = WorkflowNodeExecutionMetadataKey

    # Explicit outputs win, even when they are an empty mapping.
    resolved_outputs = execution.outputs_dict if outputs is None else outputs

    # First truthy identifier wins: caller-supplied id, then the record's
    # workflow run id, then the node execution's own id.
    resolved_execution_id = workflow_execution_id or execution.workflow_run_id or execution.id

    # The tool name is only reported when the tool-info entry is a dict.
    tool_info = meta.get(keys.TOOL_INFO)
    tool_name = tool_info.get("tool_name") if isinstance(tool_info, dict) else None

    return {
        "workflow_id": execution.workflow_id,
        "workflow_execution_id": resolved_execution_id,
        "tenant_id": execution.tenant_id,
        "app_id": execution.app_id,
        "node_execution_id": execution.id,
        "node_id": execution.node_id,
        "node_type": execution.node_type,
        "title": execution.title,
        "status": execution.status,
        "error": execution.error,
        "elapsed_time": execution.elapsed_time,
        "index": execution.index,
        "predecessor_node_id": execution.predecessor_node_id,
        "created_at": execution.created_at,
        "finished_at": execution.finished_at,
        "total_tokens": meta.get(keys.TOTAL_TOKENS, 0),
        "total_price": meta.get(keys.TOTAL_PRICE, 0.0),
        "currency": meta.get(keys.CURRENCY),
        "tool_name": tool_name,
        "iteration_id": meta.get(keys.ITERATION_ID),
        "iteration_index": meta.get(keys.ITERATION_INDEX),
        "loop_id": meta.get(keys.LOOP_ID),
        "loop_index": meta.get(keys.LOOP_INDEX),
        "parallel_id": meta.get(keys.PARALLEL_ID),
        "node_inputs": execution.inputs_dict,
        "node_outputs": resolved_outputs,
        "process_data": execution.process_data_dict,
    }
|
||||
938
api/enterprise/telemetry/enterprise_trace.py
Normal file
938
api/enterprise/telemetry/enterprise_trace.py
Normal file
@@ -0,0 +1,938 @@
|
||||
"""Enterprise trace handler — duck-typed, NOT a BaseTraceInstance subclass.
|
||||
|
||||
Invoked directly in the Celery task, not through OpsTraceManager dispatch.
|
||||
Only requires a matching ``trace(trace_info)`` method signature.
|
||||
|
||||
Signal strategy:
|
||||
- **Traces (spans)**: workflow run, node execution, draft node execution only.
|
||||
- **Metrics + structured logs**: all other event types.
|
||||
|
||||
Token metric labels (unified structure):
|
||||
All token metrics (dify.tokens.input, dify.tokens.output, dify.tokens.total) use the
|
||||
same label set for consistent filtering and aggregation:
|
||||
- tenant_id: Tenant identifier
|
||||
- app_id: Application identifier
|
||||
- operation_type: Source of token usage (workflow | node_execution | message | rule_generate | etc.)
|
||||
- model_provider: LLM provider name (empty string if not applicable)
|
||||
- model_name: LLM model name (empty string if not applicable)
|
||||
- node_type: Workflow node type (empty string if not node_execution)
|
||||
|
||||
This unified structure allows filtering by operation_type to separate:
|
||||
- Workflow-level aggregates (operation_type=workflow)
|
||||
- Individual node executions (operation_type=node_execution)
|
||||
- Direct message calls (operation_type=message)
|
||||
- Prompt generation operations (operation_type=rule_generate, code_generate, etc.)
|
||||
|
||||
Without this, tokens are double-counted when querying totals (workflow totals include
|
||||
node totals, since workflow.total_tokens is the sum of all node tokens).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
DraftNodeExecutionTrace,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
OperationType,
|
||||
PromptGenerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowNodeTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from enterprise.telemetry.entities import (
|
||||
EnterpriseTelemetryCounter,
|
||||
EnterpriseTelemetryEvent,
|
||||
EnterpriseTelemetryHistogram,
|
||||
EnterpriseTelemetrySpan,
|
||||
TokenMetricLabels,
|
||||
)
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnterpriseOtelTrace:
|
||||
"""Duck-typed enterprise trace handler.
|
||||
|
||||
``*_trace`` methods emit spans (workflow/node only) or structured logs
|
||||
(all other events), plus metrics at 100 % accuracy.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
    """Bind the handler to the enterprise exporter.

    Raises:
        RuntimeError: If no enterprise exporter is available.
    """
    # Imported inside the method, as in the original implementation.
    from extensions.ext_enterprise_telemetry import get_enterprise_exporter

    resolved_exporter = get_enterprise_exporter()
    if resolved_exporter is not None:
        self._exporter = resolved_exporter
        return
    raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized")
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo) -> None:
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self._workflow_trace(trace_info)
|
||||
elif isinstance(trace_info, MessageTraceInfo):
|
||||
self._message_trace(trace_info)
|
||||
elif isinstance(trace_info, ToolTraceInfo):
|
||||
self._tool_trace(trace_info)
|
||||
elif isinstance(trace_info, DraftNodeExecutionTrace):
|
||||
self._draft_node_execution_trace(trace_info)
|
||||
elif isinstance(trace_info, WorkflowNodeTraceInfo):
|
||||
self._node_execution_trace(trace_info)
|
||||
elif isinstance(trace_info, ModerationTraceInfo):
|
||||
self._moderation_trace(trace_info)
|
||||
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self._suggested_question_trace(trace_info)
|
||||
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self._dataset_retrieval_trace(trace_info)
|
||||
elif isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self._generate_name_trace(trace_info)
|
||||
elif isinstance(trace_info, PromptGenerationTraceInfo):
|
||||
self._prompt_generation_trace(trace_info)
|
||||
|
||||
def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
|
||||
metadata = self._metadata(trace_info)
|
||||
tenant_id, app_id, user_id = self._context_ids(trace_info, metadata)
|
||||
return {
|
||||
"dify.trace_id": trace_info.resolved_trace_id,
|
||||
"dify.tenant_id": tenant_id,
|
||||
"dify.app_id": app_id,
|
||||
"dify.app.name": metadata.get("app_name"),
|
||||
"dify.workspace.name": metadata.get("workspace_name"),
|
||||
"gen_ai.user.id": user_id,
|
||||
"dify.message.id": trace_info.message_id,
|
||||
}
|
||||
|
||||
def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
|
||||
return trace_info.metadata
|
||||
|
||||
def _context_ids(
|
||||
self,
|
||||
trace_info: BaseTraceInfo,
|
||||
metadata: dict[str, Any],
|
||||
) -> tuple[str | None, str | None, str | None]:
|
||||
tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id")
|
||||
app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id")
|
||||
user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id")
|
||||
return tenant_id, app_id, user_id
|
||||
|
||||
def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]:
|
||||
return dict(values)
|
||||
|
||||
def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
return cast(dict[str, Any], value)
|
||||
if isinstance(value, list):
|
||||
items: list[object] = []
|
||||
for item in cast(list[object], value):
|
||||
items.append(item)
|
||||
return items
|
||||
return None
|
||||
|
||||
def _content_or_ref(self, value: Any, ref: str) -> Any:
|
||||
if self._exporter.include_content:
|
||||
return self._maybe_json(value)
|
||||
return ref
|
||||
|
||||
def _maybe_json(self, value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.dumps(value, default=str)
|
||||
except (TypeError, ValueError):
|
||||
return str(value)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SPAN-emitting handlers (workflow, node execution, draft node)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _workflow_trace(self, info: WorkflowTraceInfo) -> None:
    """Emit the span, companion log and metrics for one workflow run.

    Steps, in order:
      1. Export a ``WORKFLOW_RUN`` span with identity, status and timing
         attributes, plus parent-workflow attributes when the metadata
         carries a ``parent_trace_context`` dict.
      2. Emit a ``span_detail`` log with every span attribute plus app
         and workspace names, workflow version, and the content-gated
         inputs, outputs and query.
      3. Record token counters, a request counter, a duration histogram,
         and an error counter when ``info.error`` is set.

    Args:
        info: Trace information describing the finished workflow run.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)

    # -- Span attrs: identity + structure + status + timing + gen_ai scalars --
    span_attrs: dict[str, Any] = {
        "dify.trace_id": info.resolved_trace_id,
        "dify.tenant_id": tenant_id,
        "dify.app_id": app_id,
        "dify.workflow.id": info.workflow_id,
        "dify.workflow.run_id": info.workflow_run_id,
        "dify.workflow.status": info.workflow_run_status,
        "dify.workflow.error": info.error,
        "dify.workflow.elapsed_time": info.workflow_run_elapsed_time,
        "dify.invoke_from": metadata.get("triggered_from"),
        "dify.conversation.id": info.conversation_id,
        "dify.message.id": info.message_id,
        "dify.invoked_by": info.invoked_by,
        "gen_ai.usage.total_tokens": info.total_tokens,
        "gen_ai.user.id": user_id,
    }

    trace_correlation_override, parent_span_id_source = info.resolved_parent_context

    parent_ctx = metadata.get("parent_trace_context")
    if isinstance(parent_ctx, dict):
        parent_ctx_dict = cast(dict[str, Any], parent_ctx)
        span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id")
        span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id")
        span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id")
        span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id")

    self._exporter.export_span(
        EnterpriseTelemetrySpan.WORKFLOW_RUN,
        span_attrs,
        correlation_id=info.workflow_run_id,
        span_id_source=info.workflow_run_id,
        start_time=info.start_time,
        end_time=info.end_time,
        trace_correlation_override=trace_correlation_override,
        parent_span_id_source=parent_span_id_source,
    )

    # -- Companion log: ALL attrs (span + detail) for full picture --
    # ``gen_ai.user.id`` and ``gen_ai.usage.total_tokens`` are already in
    # ``span_attrs``, so only the log-only details are added here.
    log_attrs: dict[str, Any] = {**span_attrs}
    log_attrs.update(
        {
            "dify.app.name": metadata.get("app_name"),
            "dify.workspace.name": metadata.get("workspace_name"),
            "dify.workflow.version": info.workflow_run_version,
        }
    )

    ref = f"ref:workflow_run_id={info.workflow_run_id}"
    log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref)
    log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref)
    log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref)

    # NOTE(review): the span above may take its trace id from
    # ``trace_correlation_override``, while this log always passes
    # ``workflow_run_id`` as ``trace_id_source`` — confirm the two stay
    # correlated for nested workflows.
    emit_telemetry_log(
        event_name=EnterpriseTelemetryEvent.WORKFLOW_RUN,
        attributes=log_attrs,
        signal="span_detail",
        trace_id_source=info.workflow_run_id,
        span_id_source=info.workflow_run_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    # -- Metrics --
    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
    )
    token_labels = TokenMetricLabels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        operation_type=OperationType.WORKFLOW,
        model_provider="",
        model_name="",
        node_type="",
    ).to_dict()
    self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
    if info.prompt_tokens is not None and info.prompt_tokens > 0:
        self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels)
    if info.completion_tokens is not None and info.completion_tokens > 0:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
        )

    # ``.get(key, "")`` still returns None when the key is present with a
    # None value; fall back to "" like the tenant/app labels so the metric
    # label is never None.
    invoke_from = metadata.get("triggered_from") or ""
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type="workflow",
            status=info.workflow_run_status,
            invoke_from=invoke_from,
        ),
    )

    # Prefer wall-clock timestamps over the elapsed_time field: elapsed_time defaults
    # to 0 in the DB and can be stale if the Celery write races with the trace task.
    # start_time = workflow_run.created_at, end_time = workflow_run.finished_at.
    if info.start_time and info.end_time:
        workflow_duration = (info.end_time - info.start_time).total_seconds()
    elif info.workflow_run_elapsed_time:
        workflow_duration = float(info.workflow_run_elapsed_time)
    else:
        workflow_duration = 0.0
    self._exporter.record_histogram(
        EnterpriseTelemetryHistogram.WORKFLOW_DURATION,
        workflow_duration,
        self._labels(
            **labels,
            status=info.workflow_run_status,
        ),
    )

    if info.error:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.ERRORS,
            1,
            self._labels(
                **labels,
                type="workflow",
            ),
        )
|
||||
|
||||
def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None:
    """Emit telemetry for a node execution through the shared node emitter."""
    self._emit_node_execution_trace(
        info,
        span_name=EnterpriseTelemetrySpan.NODE_EXECUTION,
        request_type="node",
    )
|
||||
|
||||
def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None:
    """Emit telemetry for a draft node execution through the shared node emitter.

    The node's own execution id is passed as the correlation id override,
    and the workflow run id as the trace-correlation override.
    """
    self._emit_node_execution_trace(
        info,
        span_name=EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION,
        request_type="draft_node",
        correlation_id_override=info.node_execution_id,
        trace_correlation_override_param=info.workflow_run_id,
    )
|
||||
|
||||
def _emit_node_execution_trace(
    self,
    info: WorkflowNodeTraceInfo,
    span_name: EnterpriseTelemetrySpan,
    request_type: str,
    correlation_id_override: str | None = None,
    trace_correlation_override_param: str | None = None,
) -> None:
    """Export the span, companion log and metrics for one node execution.

    Args:
        info: Trace info describing the node execution.
        span_name: Span name to export; its ``.value`` is also used as the
            companion log's event name.
        request_type: Value of the ``type`` label on the request and error
            counters.
        correlation_id_override: Correlation id for the span. When falsy,
            ``info.workflow_run_id`` is used instead.
        trace_correlation_override_param: Trace correlation override for the
            span. When falsy, the first element of
            ``info.resolved_parent_context`` is used instead.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)

    # -- Span attrs: identity + structure + status + timing + gen_ai scalars --
    span_attrs: dict[str, Any] = {
        "dify.trace_id": info.resolved_trace_id,
        "dify.tenant_id": tenant_id,
        "dify.app_id": app_id,
        "dify.workflow.id": info.workflow_id,
        "dify.workflow.run_id": info.workflow_run_id,
        "dify.message.id": info.message_id,
        "dify.conversation.id": metadata.get("conversation_id"),
        "dify.node.execution_id": info.node_execution_id,
        "dify.node.id": info.node_id,
        "dify.node.type": info.node_type,
        "dify.node.title": info.title,
        "dify.node.status": info.status,
        "dify.node.error": info.error,
        "dify.node.elapsed_time": info.elapsed_time,
        "dify.node.index": info.index,
        "dify.node.predecessor_node_id": info.predecessor_node_id,
        "dify.node.iteration_id": info.iteration_id,
        "dify.node.loop_id": info.loop_id,
        "dify.node.parallel_id": info.parallel_id,
        "dify.node.invoked_by": info.invoked_by,
        "gen_ai.usage.input_tokens": info.prompt_tokens,
        "gen_ai.usage.output_tokens": info.completion_tokens,
        "gen_ai.usage.total_tokens": info.total_tokens,
        "gen_ai.request.model": info.model_name,
        "gen_ai.provider.name": info.model_provider,
        "gen_ai.user.id": user_id,
    }

    # Explicit overrides win; otherwise fall back to the values carried by ``info``.
    parent_trace_override, _ = info.resolved_parent_context
    self._exporter.export_span(
        span_name,
        span_attrs,
        correlation_id=correlation_id_override or info.workflow_run_id,
        span_id_source=info.node_execution_id,
        start_time=info.start_time,
        end_time=info.end_time,
        trace_correlation_override=trace_correlation_override_param or parent_trace_override,
    )

    # -- Companion log: ALL attrs (span + detail) --
    # A few keys below repeat values already copied from ``span_attrs``.
    detail_ref = f"ref:node_execution_id={info.node_execution_id}"
    log_attrs: dict[str, Any] = {
        **span_attrs,
        "dify.app.name": metadata.get("app_name"),
        "dify.workspace.name": metadata.get("workspace_name"),
        "dify.invoke_from": metadata.get("invoke_from"),
        "gen_ai.user.id": user_id,
        "gen_ai.usage.total_tokens": info.total_tokens,
        "dify.node.total_price": info.total_price,
        "dify.node.currency": info.currency,
        "gen_ai.provider.name": info.model_provider,
        "gen_ai.request.model": info.model_name,
        "gen_ai.tool.name": info.tool_name,
        "dify.node.iteration_index": info.iteration_index,
        "dify.node.loop_index": info.loop_index,
        "dify.plugin.name": metadata.get("plugin_name"),
        "dify.credential.name": metadata.get("credential_name"),
        "dify.credential.id": metadata.get("credential_id"),
        "dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")),
        "dify.dataset.names": self._maybe_json(metadata.get("dataset_names")),
        "dify.node.inputs": self._content_or_ref(info.node_inputs, detail_ref),
        "dify.node.outputs": self._content_or_ref(info.node_outputs, detail_ref),
        "dify.node.process_data": self._content_or_ref(info.process_data, detail_ref),
    }

    emit_telemetry_log(
        event_name=span_name.value,
        attributes=log_attrs,
        signal="span_detail",
        trace_id_source=info.workflow_run_id,
        span_id_source=info.node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    # -- Metrics --
    base_labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        node_type=info.node_type,
        model_provider=info.model_provider or "",
    )

    # Token counters are only touched when a non-zero total was reported.
    if info.total_tokens:
        token_labels = TokenMetricLabels(
            tenant_id=tenant_id or "",
            app_id=app_id or "",
            operation_type=OperationType.NODE_EXECUTION,
            model_provider=info.model_provider or "",
            model_name=info.model_name or "",
            node_type=info.node_type,
        ).to_dict()
        self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
        for counter, amount in (
            (EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens),
            (EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens),
        ):
            if amount is not None and amount > 0:
                self._exporter.increment_counter(counter, amount, token_labels)

    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **base_labels,
            type=request_type,
            status=info.status,
            model_name=info.model_name or "",
        ),
    )

    # The duration histogram carries the model name and, for tool and
    # knowledge-retrieval nodes, the plugin name when one is present.
    duration_labels = {**base_labels, "model_name": info.model_name or ""}
    plugin_name = metadata.get("plugin_name")
    if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}:
        duration_labels["plugin_name"] = plugin_name
    self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels)

    if info.error:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.ERRORS,
            1,
            self._labels(
                **base_labels,
                type=request_type,
                model_name=info.model_name or "",
            ),
        )
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# METRIC-ONLY handlers (structured log + counters/histograms)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _message_trace(self, info: MessageTraceInfo) -> None:
    """Emit the metric-only event and metrics for a message run.

    Builds the event attributes from the common attributes plus message
    details, emits a ``MESSAGE_RUN`` event, then records token counters, the
    request counter, duration / time-to-first-token histograms and, when
    ``info.error`` is truthy, the error counter.

    Args:
        info: Trace info describing the message run.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    attrs = self._common_attrs(info)
    attrs.update(
        {
            "dify.invoke_from": metadata.get("from_source"),
            "dify.conversation.id": metadata.get("conversation_id"),
            "dify.conversation.mode": info.conversation_mode,
            "gen_ai.provider.name": metadata.get("ls_provider"),
            "gen_ai.request.model": metadata.get("ls_model_name"),
            "gen_ai.usage.input_tokens": info.message_tokens,
            "gen_ai.usage.output_tokens": info.answer_tokens,
            "gen_ai.usage.total_tokens": info.total_tokens,
            "dify.message.status": metadata.get("status"),
            "dify.message.error": info.error,
            "dify.message.from_source": metadata.get("from_source"),
            "dify.message.from_end_user_id": metadata.get("from_end_user_id"),
            "dify.message.from_account_id": metadata.get("from_account_id"),
            "dify.streaming": info.is_streaming_request,
            "dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token,
            "dify.message.streaming_duration": info.llm_streaming_time_to_generate,
            "dify.workflow.run_id": metadata.get("workflow_run_id"),
        }
    )
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        attrs["dify.node.execution_id"] = node_execution_id

    ref = f"ref:message_id={info.message_id}"
    inputs = self._safe_payload_value(info.inputs)
    outputs = self._safe_payload_value(info.outputs)
    attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref)
    attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref)

    # Prefer the workflow run id and fall back to the message id. The
    # parentheses matter: without them the conditional expression wraps the
    # whole ``or`` and a present workflow run id is discarded whenever
    # ``message_id`` is missing.
    trace_id_source = metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None)

    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.MESSAGE_RUN,
        attributes=attrs,
        trace_id_source=trace_id_source,
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        model_provider=metadata.get("ls_provider") or "",
        model_name=metadata.get("ls_model_name") or "",
    )
    token_labels = TokenMetricLabels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        operation_type=OperationType.MESSAGE,
        model_provider=metadata.get("ls_provider") or "",
        model_name=metadata.get("ls_model_name") or "",
        node_type="",
    ).to_dict()
    self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
    if info.message_tokens > 0:
        self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.message_tokens, token_labels)
    if info.answer_tokens > 0:
        self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.answer_tokens, token_labels)

    invoke_from = metadata.get("from_source", "")
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type="message",
            status=metadata.get("status", ""),
            invoke_from=invoke_from,
        ),
    )

    # Duration is only recorded when both timestamps are available.
    if info.start_time and info.end_time:
        duration = (info.end_time - info.start_time).total_seconds()
        self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels)

    if info.gen_ai_server_time_to_first_token is not None:
        self._exporter.record_histogram(
            EnterpriseTelemetryHistogram.MESSAGE_TTFT, info.gen_ai_server_time_to_first_token, labels
        )

    if info.error:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.ERRORS,
            1,
            self._labels(
                **labels,
                type="message",
            ),
        )
|
||||
|
||||
def _tool_trace(self, info: ToolTraceInfo) -> None:
    """Emit the metric-only event and metrics for a tool execution.

    Emits a ``TOOL_EXECUTION`` event, increments the request counter, records
    the tool duration histogram and, when ``info.error`` is truthy, increments
    the error counter.

    Args:
        info: Trace info describing the tool execution.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)

    event_attrs = self._common_attrs(info)
    event_attrs.update(
        {
            "gen_ai.tool.name": info.tool_name,
            "dify.tool.time_cost": info.time_cost,
            "dify.tool.error": info.error,
            "dify.workflow.run_id": metadata.get("workflow_run_id"),
        }
    )
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        event_attrs["dify.node.execution_id"] = node_execution_id

    # All four payload attributes share the same message reference.
    message_ref = f"ref:message_id={info.message_id}"
    event_attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, message_ref)
    event_attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, message_ref)
    event_attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, message_ref)
    event_attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, message_ref)

    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.TOOL_EXECUTION,
        attributes=event_attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    tool_labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        tool_name=info.tool_name,
    )
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(**tool_labels, type="tool"),
    )
    self._exporter.record_histogram(
        EnterpriseTelemetryHistogram.TOOL_DURATION,
        float(info.time_cost),
        tool_labels,
    )

    if not info.error:
        return
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.ERRORS,
        1,
        self._labels(**tool_labels, type="tool"),
    )
|
||||
|
||||
def _moderation_trace(self, info: ModerationTraceInfo) -> None:
    """Emit the metric-only event and request counter for a moderation check.

    Args:
        info: Trace info describing the moderation check.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)

    event_attrs = self._common_attrs(info)
    event_attrs.update(
        {
            "dify.moderation.flagged": info.flagged,
            "dify.moderation.action": info.action,
            "dify.moderation.preset_response": info.preset_response,
            "dify.workflow.run_id": metadata.get("workflow_run_id"),
        }
    )
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        event_attrs["dify.node.execution_id"] = node_execution_id

    message_ref = f"ref:message_id={info.message_id}"
    event_attrs["dify.moderation.query"] = self._content_or_ref(info.query, message_ref)

    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.MODERATION_CHECK,
        attributes=event_attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    base_labels = self._labels(tenant_id=tenant_id or "", app_id=app_id or "")
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(**base_labels, type="moderation"),
    )
|
||||
|
||||
def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None:
    """Emit the metric-only event and request counter for suggested questions.

    Args:
        info: Trace info describing the suggested-question generation.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)

    event_attrs = self._common_attrs(info)
    event_attrs.update(
        {
            "gen_ai.usage.total_tokens": info.total_tokens,
            "dify.suggested_question.status": info.status,
            "dify.suggested_question.error": info.error,
            "gen_ai.provider.name": info.model_provider,
            "gen_ai.request.model": info.model_id,
            "dify.suggested_question.count": len(info.suggested_question),
            "dify.workflow.run_id": metadata.get("workflow_run_id"),
        }
    )
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        event_attrs["dify.node.execution_id"] = node_execution_id

    message_ref = f"ref:message_id={info.message_id}"
    event_attrs["dify.suggested_question.questions"] = self._content_or_ref(info.suggested_question, message_ref)

    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION,
        attributes=event_attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    base_labels = self._labels(tenant_id=tenant_id or "", app_id=app_id or "")
    request_labels = self._labels(
        **base_labels,
        type="suggested_question",
        model_provider=info.model_provider or "",
        model_name=info.model_id or "",
    )
    self._exporter.increment_counter(EnterpriseTelemetryCounter.REQUESTS, 1, request_labels)
|
||||
|
||||
def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None:
    """Emit the metric-only event and metrics for a dataset retrieval.

    Collects the distinct dataset ids/names and a reduced per-document record
    from ``info.documents``, adds embedding and rerank model details from the
    metadata, emits a ``DATASET_RETRIEVAL`` event, then increments the request
    counter once and the dataset-retrieval counter once per distinct dataset.

    Args:
        info: Trace info describing the retrieval.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    attrs = self._common_attrs(info)
    attrs["dify.dataset.error"] = info.error
    attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id")
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        attrs["dify.node.execution_id"] = node_execution_id

    # Only dict entries of a list-valued ``documents`` are considered.
    documents_any: Any = info.documents
    docs: list[dict[str, Any]] = (
        [cast(dict[str, Any], entry) for entry in documents_any if isinstance(entry, dict)]
        if isinstance(documents_any, list)
        else []
    )

    dataset_ids: list[str] = []
    dataset_names: list[str] = []
    structured_docs: list[dict[str, Any]] = []
    for doc in docs:
        meta_raw = doc.get("metadata")
        meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {}
        did = meta.get("dataset_id")
        dname = meta.get("dataset_name")
        # Keep first-seen order while de-duplicating.
        if did and did not in dataset_ids:
            dataset_ids.append(did)
        if dname and dname not in dataset_names:
            dataset_names.append(dname)
        structured_docs.append(
            {
                "dataset_id": did,
                "document_id": meta.get("document_id"),
                "segment_id": meta.get("segment_id"),
                "score": meta.get("score"),
            }
        )

    attrs["dify.dataset.ids"] = self._maybe_json(dataset_ids)
    attrs["dify.dataset.names"] = self._maybe_json(dataset_names)
    attrs["dify.retrieval.document_count"] = len(docs)

    embedding_models_raw: Any = metadata.get("embedding_models")
    embedding_models: dict[str, Any] = (
        cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {}
    )
    if embedding_models:
        providers: list[str] = []
        models: list[str] = []
        for ds_info in embedding_models.values():
            if isinstance(ds_info, dict):
                ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info)
                p = ds_info_dict.get("embedding_model_provider", "")
                m = ds_info_dict.get("embedding_model", "")
                if p and p not in providers:
                    providers.append(p)
                if m and m not in models:
                    models.append(m)
        attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers)
        attrs["dify.dataset.embedding_models"] = self._maybe_json(models)

    # Add rerank model to logs. The same values are reused for the
    # per-dataset counter below, since they apply to the whole retrieval.
    rerank_provider = metadata.get("rerank_model_provider", "")
    rerank_model = metadata.get("rerank_model_name", "")
    if rerank_provider or rerank_model:
        attrs["dify.retrieval.rerank_provider"] = rerank_provider
        attrs["dify.retrieval.rerank_model"] = rerank_model

    ref = f"ref:message_id={info.message_id}"
    retrieval_inputs = self._safe_payload_value(info.inputs)
    attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref)
    attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref)

    # Prefer the workflow run id and fall back to the message id. The
    # fallback is computed separately so that a present workflow run id is
    # not discarded when ``message_id`` is missing (a bare
    # ``a or b if c else None`` parses as ``(a or b) if c else None``).
    message_id_str = str(info.message_id) if info.message_id else None
    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.DATASET_RETRIEVAL,
        attributes=attrs,
        trace_id_source=metadata.get("workflow_run_id") or message_id_str,
        span_id_source=node_execution_id or message_id_str,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
    )
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type="dataset_retrieval",
        ),
    )

    for did in dataset_ids:
        # Get embedding model for this specific dataset. Entries that are not
        # dicts are treated as "no embedding info", matching the isinstance
        # guard used when building the log attributes above.
        ds_embedding_raw = embedding_models.get(did)
        ds_embedding_info: dict[str, Any] = (
            cast(dict[str, Any], ds_embedding_raw) if isinstance(ds_embedding_raw, dict) else {}
        )

        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.DATASET_RETRIEVALS,
            1,
            self._labels(
                **labels,
                dataset_id=did,
                embedding_model_provider=ds_embedding_info.get("embedding_model_provider", ""),
                embedding_model=ds_embedding_info.get("embedding_model", ""),
                rerank_model_provider=rerank_provider,
                rerank_model=rerank_model,
            ),
        )
|
||||
|
||||
def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None:
    """Emit the metric-only event and request counter for name generation.

    Args:
        info: Trace info describing the conversation-name generation.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)

    event_attrs = self._common_attrs(info)
    event_attrs["dify.conversation.id"] = info.conversation_id
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        event_attrs["dify.node.execution_id"] = node_execution_id

    # Payloads are sanitised first, then stored as content or as a reference.
    conversation_ref = f"ref:conversation_id={info.conversation_id}"
    safe_inputs = self._safe_payload_value(info.inputs)
    safe_outputs = self._safe_payload_value(info.outputs)
    event_attrs["dify.generate_name.inputs"] = self._content_or_ref(safe_inputs, conversation_ref)
    event_attrs["dify.generate_name.outputs"] = self._content_or_ref(safe_outputs, conversation_ref)

    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION,
        attributes=event_attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    base_labels = self._labels(tenant_id=tenant_id or "", app_id=app_id or "")
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(**base_labels, type="generate_name"),
    )
|
||||
|
||||
def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None:
    """Emit the metric-only event and metrics for a prompt generation.

    Emits a ``PROMPT_GENERATION_EXECUTION`` event, then records token
    counters, the request counter (with a ``status`` of ``"failed"`` or
    ``"success"`` derived from ``info.error``), the duration histogram and,
    when ``info.error`` is truthy, the error counter.

    Args:
        info: Trace info describing the prompt generation.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)

    event_attrs: dict[str, Any] = {
        "dify.trace_id": info.resolved_trace_id,
        "dify.tenant_id": tenant_id,
        "dify.user.id": user_id,
        "dify.app.id": app_id or "",
        "dify.app.name": metadata.get("app_name"),
        "dify.workspace.name": metadata.get("workspace_name"),
        "dify.operation.type": info.operation_type,
        "gen_ai.provider.name": info.model_provider,
        "gen_ai.request.model": info.model_name,
        "gen_ai.usage.input_tokens": info.prompt_tokens,
        "gen_ai.usage.output_tokens": info.completion_tokens,
        "gen_ai.usage.total_tokens": info.total_tokens,
        "dify.prompt_generation.latency": info.latency,
        "dify.prompt_generation.error": info.error,
    }
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        event_attrs["dify.node.execution_id"] = node_execution_id

    # Price and currency are only attached when a price was reported.
    if info.total_price is not None:
        event_attrs["dify.prompt_generation.total_price"] = info.total_price
        event_attrs["dify.prompt_generation.currency"] = info.currency

    trace_ref = f"ref:trace_id={info.trace_id}"
    safe_outputs = self._safe_payload_value(info.outputs)
    event_attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, trace_ref)
    event_attrs["dify.prompt_generation.output"] = self._content_or_ref(safe_outputs, trace_ref)

    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION,
        attributes=event_attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    token_labels = TokenMetricLabels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        operation_type=info.operation_type,
        model_provider=info.model_provider,
        model_name=info.model_name,
        node_type="",
    ).to_dict()

    metric_labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        operation_type=info.operation_type,
        model_provider=info.model_provider,
        model_name=info.model_name,
    )

    self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
    for counter, amount in (
        (EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens),
        (EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens),
    ):
        if amount > 0:
            self._exporter.increment_counter(counter, amount, token_labels)

    if info.error:
        status = "failed"
    else:
        status = "success"
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(**metric_labels, type="prompt_generation", status=status),
    )

    self._exporter.record_histogram(
        EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION,
        info.latency,
        metric_labels,
    )

    if info.error:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.ERRORS,
            1,
            self._labels(**metric_labels, type="prompt_generation"),
        )
|
||||
121
api/enterprise/telemetry/entities/__init__.py
Normal file
121
api/enterprise/telemetry/entities/__init__.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from enum import StrEnum
|
||||
from typing import cast
|
||||
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class EnterpriseTelemetrySpan(StrEnum):
|
||||
WORKFLOW_RUN = "dify.workflow.run"
|
||||
NODE_EXECUTION = "dify.node.execution"
|
||||
DRAFT_NODE_EXECUTION = "dify.node.execution.draft"
|
||||
|
||||
|
||||
class EnterpriseTelemetryEvent(StrEnum):
|
||||
"""Event names for enterprise telemetry logs."""
|
||||
|
||||
APP_CREATED = "dify.app.created"
|
||||
APP_UPDATED = "dify.app.updated"
|
||||
APP_DELETED = "dify.app.deleted"
|
||||
FEEDBACK_CREATED = "dify.feedback.created"
|
||||
WORKFLOW_RUN = "dify.workflow.run"
|
||||
MESSAGE_RUN = "dify.message.run"
|
||||
TOOL_EXECUTION = "dify.tool.execution"
|
||||
MODERATION_CHECK = "dify.moderation.check"
|
||||
SUGGESTED_QUESTION_GENERATION = "dify.suggested_question.generation"
|
||||
DATASET_RETRIEVAL = "dify.dataset.retrieval"
|
||||
GENERATE_NAME_EXECUTION = "dify.generate_name.execution"
|
||||
PROMPT_GENERATION_EXECUTION = "dify.prompt_generation.execution"
|
||||
REHYDRATION_FAILED = "dify.telemetry.rehydration_failed"
|
||||
|
||||
|
||||
class EnterpriseTelemetryCounter(StrEnum):
|
||||
TOKENS = "tokens"
|
||||
INPUT_TOKENS = "input_tokens"
|
||||
OUTPUT_TOKENS = "output_tokens"
|
||||
REQUESTS = "requests"
|
||||
ERRORS = "errors"
|
||||
FEEDBACK = "feedback"
|
||||
DATASET_RETRIEVALS = "dataset_retrievals"
|
||||
APP_CREATED = "app_created"
|
||||
APP_UPDATED = "app_updated"
|
||||
APP_DELETED = "app_deleted"
|
||||
|
||||
|
||||
class EnterpriseTelemetryHistogram(StrEnum):
|
||||
WORKFLOW_DURATION = "workflow_duration"
|
||||
NODE_DURATION = "node_duration"
|
||||
MESSAGE_DURATION = "message_duration"
|
||||
MESSAGE_TTFT = "message_ttft"
|
||||
TOOL_DURATION = "tool_duration"
|
||||
PROMPT_GENERATION_DURATION = "prompt_generation_duration"
|
||||
|
||||
|
||||
class TokenMetricLabels(BaseModel):
    """Single, fixed label set shared by every token counter.

    The token counters (``EnterpriseTelemetryCounter.TOKENS``,
    ``INPUT_TOKENS`` and ``OUTPUT_TOKENS``) must all be emitted with exactly
    these six labels so they can be filtered and aggregated the same way
    regardless of which operation produced them.

    Attributes:
        tenant_id: Tenant identifier.
        app_id: Application identifier.
        operation_type: Source of the token usage (workflow | node_execution |
            message | rule_generate | code_generate | structured_output |
            instruction_modify).
        model_provider: LLM provider name; empty string when not applicable
            (e.g. workflow-level).
        model_name: LLM model name; empty string when not applicable
            (e.g. workflow-level).
        node_type: Workflow node type; empty string unless
            ``operation_type`` is node_execution.

    Usage:
        labels = TokenMetricLabels(
            tenant_id="tenant-123",
            app_id="app-456",
            operation_type=OperationType.WORKFLOW,
            model_provider="",
            model_name="",
            node_type="",
        )
        exporter.increment_counter(
            EnterpriseTelemetryCounter.INPUT_TOKENS,
            100,
            labels.to_dict(),
        )

    Design rationale:
        A workflow's total tokens are already the sum of its nodes' tokens, so
        summing every token sample double-counts. The ``operation_type`` label
        lets queries separate workflow-level aggregates from node-level
        detail, while the identical label set keeps cardinality consistent
        across queries.
    """

    # Unknown fields are rejected and instances are immutable.
    model_config = ConfigDict(extra="forbid", frozen=True)

    tenant_id: str
    app_id: str
    operation_type: str
    model_provider: str
    model_name: str
    node_type: str

    def to_dict(self) -> dict[str, AttributeValue]:
        """Return the six labels as a plain dict keyed by label name."""
        label_names = (
            "tenant_id",
            "app_id",
            "operation_type",
            "model_provider",
            "model_name",
            "node_type",
        )
        return cast(
            dict[str, AttributeValue],
            {name: getattr(self, name) for name in label_names},
        )
|
||||
|
||||
|
||||
# Public names re-exported by this package.
__all__ = [
    "EnterpriseTelemetryCounter",
    "EnterpriseTelemetryEvent",
    "EnterpriseTelemetryHistogram",
    "EnterpriseTelemetrySpan",
    "TokenMetricLabels",
]
|
||||
99
api/enterprise/telemetry/event_handlers.py
Normal file
99
api/enterprise/telemetry/event_handlers.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Blinker signal handlers for enterprise telemetry.
|
||||
|
||||
Registered at import time via ``@signal.connect`` decorators.
|
||||
Import must happen during ``ext_enterprise_telemetry.init_app()`` to
|
||||
ensure handlers fire. Each handler delegates to ``core.telemetry.gateway``
|
||||
which handles routing, EE-gating, and dispatch.
|
||||
|
||||
All handlers are best-effort: exceptions are caught and logged so that
|
||||
telemetry failures never break user-facing operations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from events.app_event import app_was_created, app_was_deleted, app_was_updated
|
||||
from events.feedback_event import feedback_was_created
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The signal handlers defined below are listed explicitly even though they
# are underscore-prefixed.
# NOTE(review): presumably so tooling treats the decorator-registered
# handlers as used — confirm.
__all__ = [
    "_handle_app_created",
    "_handle_app_deleted",
    "_handle_app_updated",
    "_handle_feedback_created",
]
|
||||
|
||||
|
||||
@app_was_created.connect
def _handle_app_created(sender: object, **kwargs: object) -> None:
    """Forward an ``app_was_created`` signal to the telemetry gateway.

    Best-effort: any exception is logged as a warning and swallowed so the
    caller that sent the signal is never interrupted.

    Args:
        sender: Object the signal was sent with; ``tenant_id``, ``id`` and
            ``mode`` are read from it when present.
        **kwargs: Extra signal arguments; unused.
    """
    try:
        # Imports are deferred to call time.
        from core.telemetry.gateway import emit as gateway_emit
        from enterprise.telemetry.contracts import TelemetryCase

        tenant = str(getattr(sender, "tenant_id", "") or "")
        app_payload = {
            "app_id": getattr(sender, "id", None),
            "mode": getattr(sender, "mode", None),
        }
        gateway_emit(
            case=TelemetryCase.APP_CREATED,
            context={"tenant_id": tenant},
            payload=app_payload,
        )
    except Exception:
        logger.warning("Failed to emit app_created telemetry", exc_info=True)
|
||||
|
||||
|
||||
@app_was_deleted.connect
def _handle_app_deleted(sender: object, **kwargs: object) -> None:
    """Forward an ``app_was_deleted`` signal to the telemetry gateway.

    Best-effort: any exception is logged as a warning and swallowed so the
    caller that sent the signal is never interrupted.

    Args:
        sender: Object the signal was sent with; ``tenant_id`` and ``id`` are
            read from it when present.
        **kwargs: Extra signal arguments; unused.
    """
    try:
        # Imports are deferred to call time.
        from core.telemetry.gateway import emit as gateway_emit
        from enterprise.telemetry.contracts import TelemetryCase

        tenant = str(getattr(sender, "tenant_id", "") or "")
        app_payload = {"app_id": getattr(sender, "id", None)}
        gateway_emit(
            case=TelemetryCase.APP_DELETED,
            context={"tenant_id": tenant},
            payload=app_payload,
        )
    except Exception:
        logger.warning("Failed to emit app_deleted telemetry", exc_info=True)
|
||||
|
||||
|
||||
@app_was_updated.connect
def _handle_app_updated(sender: object, **kwargs: object) -> None:
    """Forward an ``app_was_updated`` signal to the telemetry gateway.

    Best-effort: any exception is logged as a warning and swallowed so the
    caller that sent the signal is never interrupted.

    Args:
        sender: Object the signal was sent with; ``tenant_id`` and ``id`` are
            read from it when present.
        **kwargs: Extra signal arguments; unused.
    """
    try:
        # Imports are deferred to call time.
        from core.telemetry.gateway import emit as gateway_emit
        from enterprise.telemetry.contracts import TelemetryCase

        tenant = str(getattr(sender, "tenant_id", "") or "")
        app_payload = {"app_id": getattr(sender, "id", None)}
        gateway_emit(
            case=TelemetryCase.APP_UPDATED,
            context={"tenant_id": tenant},
            payload=app_payload,
        )
    except Exception:
        logger.warning("Failed to emit app_updated telemetry", exc_info=True)
|
||||
|
||||
|
||||
@feedback_was_created.connect
def _handle_feedback_created(sender: object, **kwargs: object) -> None:
    """Forward a ``feedback_was_created`` signal to the telemetry gateway.

    Unlike the app handlers, the tenant id is taken from the signal keyword
    arguments rather than from *sender*. The feedback fields are copied from
    *sender* with ``getattr`` (missing attributes become ``None``). Any
    ``Exception`` is logged as a warning and not propagated.

    Args:
        sender: Object that fired the signal; its feedback fields are copied.
        **kwargs: Signal arguments; ``tenant_id`` is read from here.
    """
    try:
        # Imported inside the try block so that an import failure is also
        # downgraded to a warning.
        from core.telemetry.gateway import emit as gateway_emit
        from enterprise.telemetry.contracts import TelemetryCase

        tenant_id = str(kwargs.get("tenant_id", "") or "")
        copied_fields = (
            "message_id",
            "app_id",
            "conversation_id",
            "from_end_user_id",
            "from_account_id",
            "rating",
            "from_source",
            "content",
        )
        feedback_payload = {field: getattr(sender, field, None) for field in copied_fields}
        gateway_emit(
            case=TelemetryCase.FEEDBACK_CREATED,
            context={"tenant_id": tenant_id},
            payload=feedback_payload,
        )
    except Exception:
        logger.warning("Failed to emit feedback_created telemetry", exc_info=True)
|
||||
284
api/enterprise/telemetry/exporter.py
Normal file
284
api/enterprise/telemetry/exporter.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation.
|
||||
|
||||
Uses dedicated TracerProvider and MeterProvider instances (configurable sampling,
|
||||
independent from ext_otel.py infrastructure).
|
||||
|
||||
Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py).
|
||||
Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, cast
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.context import Context
|
||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.trace import SpanContext, TraceFlags
|
||||
from opentelemetry.util.types import Attributes, AttributeValue
|
||||
|
||||
from configs import dify_config
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram
|
||||
from enterprise.telemetry.id_generator import (
|
||||
CorrelationIdGenerator,
|
||||
compute_deterministic_span_id,
|
||||
set_correlation_id,
|
||||
set_span_id_source,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_enterprise_telemetry_enabled() -> bool:
    """Return True only when both enterprise mode and enterprise telemetry are enabled in ``dify_config``."""
    if not dify_config.ENTERPRISE_ENABLED:
        return False
    return bool(dify_config.ENTERPRISE_TELEMETRY_ENABLED)
|
||||
|
||||
|
||||
def _parse_otlp_headers(raw: str) -> dict[str, str]:
|
||||
"""Parse ``key=value,key2=value2`` into a dict."""
|
||||
if not raw:
|
||||
return {}
|
||||
headers: dict[str, str] = {}
|
||||
for pair in raw.split(","):
|
||||
if "=" not in pair:
|
||||
continue
|
||||
k, v = pair.split("=", 1)
|
||||
headers[k.strip()] = v.strip()
|
||||
return headers
|
||||
|
||||
|
||||
def _datetime_to_ns(dt: datetime) -> int:
|
||||
"""Convert a datetime to nanoseconds since epoch (OTEL convention)."""
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
class _ExporterFactory:
    """Build OTLP trace and metric exporters for one endpoint/protocol pair.

    With protocol ``"grpc"`` the gRPC exporters are used and the endpoint is
    passed through unchanged. Any other protocol value falls back to the HTTP
    exporters, for which the signal path (``/v1/traces`` or ``/v1/metrics``)
    is appended to the endpoint.
    """

    def __init__(self, protocol: str, endpoint: str, headers: dict[str, str], insecure: bool):
        """Store the connection settings.

        Args:
            protocol: ``"grpc"`` selects gRPC; anything else selects HTTP.
            endpoint: Collector base endpoint; empty lets the exporter use its own default.
            headers: Extra headers sent with every export request.
            insecure: Passed to the gRPC exporters only.
        """
        self._protocol = protocol
        self._endpoint = endpoint
        self._headers = headers
        # The gRPC exporters receive headers as a tuple of pairs, the HTTP
        # exporters as a dict; an empty mapping is passed as None.
        self._grpc_headers = tuple(headers.items()) if headers else None
        self._http_headers = headers or None
        self._insecure = insecure

    def _http_signal_endpoint(self, path: str) -> str | None:
        """Return the HTTP endpoint for *path*, or None when no endpoint is configured.

        Trailing slashes on the base endpoint are dropped so that a value such
        as ``http://collector:4318/`` does not produce ``//v1/traces``.
        """
        if not self._endpoint:
            return None
        return f"{self._endpoint.rstrip('/')}{path}"

    def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter:
        """Create the span exporter matching the configured protocol."""
        if self._protocol == "grpc":
            return GRPCSpanExporter(
                endpoint=self._endpoint or None,
                headers=self._grpc_headers,
                insecure=self._insecure,
            )
        return HTTPSpanExporter(endpoint=self._http_signal_endpoint("/v1/traces"), headers=self._http_headers)

    def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter:
        """Create the metric exporter matching the configured protocol."""
        if self._protocol == "grpc":
            return GRPCMetricExporter(
                endpoint=self._endpoint or None,
                headers=self._grpc_headers,
                insecure=self._insecure,
            )
        return HTTPMetricExporter(endpoint=self._http_signal_endpoint("/v1/metrics"), headers=self._http_headers)
|
||||
|
||||
|
||||
class EnterpriseExporter:
    """Shared OTEL exporter for all enterprise telemetry.

    ``export_span`` creates spans with optional real timestamps, deterministic
    span/trace IDs, and cross-workflow parent linking.
    ``increment_counter`` / ``record_histogram`` emit OTEL metrics; the sampler
    is attached to the tracer provider only, so metrics are not sampled.
    """

    def __init__(self, config: object) -> None:
        """Create the tracer/meter providers and all metric instruments.

        Args:
            config: Settings object; the ``ENTERPRISE_*`` attributes are read
                with ``getattr`` so missing attributes fall back to defaults.
        """
        # ``or ""`` also covers attributes that exist but are set to None.
        endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "") or ""
        headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "") or ""
        protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower()
        service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify")
        sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0)
        self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True)
        api_key: str = getattr(config, "ENTERPRISE_OTLP_API_KEY", "")

        # Auto-detect TLS: https:// uses secure, everything else is insecure
        insecure = not endpoint.startswith("https://")

        resource = Resource(
            attributes={
                ResourceAttributes.SERVICE_NAME: service_name,
                ResourceAttributes.HOST_NAME: socket.gethostname(),
            }
        )
        # NOTE(review): export_span() builds parent contexts flagged SAMPLED, and a
        # parent-based sampler presumably honours that flag, so the ratio would only
        # apply to root spans -- confirm this is the intended sampling behaviour.
        sampler = ParentBasedTraceIdRatio(sampling_rate)
        id_generator = CorrelationIdGenerator()
        self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator)

        headers = _parse_otlp_headers(headers_raw)
        if api_key:
            # Header names are case-insensitive, so match any spelling and drop the
            # configured value(s); otherwise both credentials would be sent.
            clashing = [key for key in headers if key.lower() == "authorization"]
            if clashing:
                logger.warning(
                    "ENTERPRISE_OTLP_API_KEY is set but ENTERPRISE_OTLP_HEADERS also contains "
                    "'authorization'; the API key will take precedence."
                )
                for key in clashing:
                    del headers[key]
            headers["authorization"] = f"Bearer {api_key}"
        factory = _ExporterFactory(protocol, endpoint, headers, insecure=insecure)

        trace_exporter = factory.create_trace_exporter()
        self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
        self._tracer = self._tracer_provider.get_tracer("dify.enterprise")

        metric_exporter = factory.create_metric_exporter()
        self._meter_provider = MeterProvider(
            resource=resource,
            metric_readers=[PeriodicExportingMetricReader(metric_exporter)],
        )
        meter = self._meter_provider.get_meter("dify.enterprise")
        self._counters = {
            EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"),
            EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"),
            EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"),
            EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"),
            EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"),
            EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"),
            EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter(
                "dify.dataset.retrievals.total", unit="{retrieval}"
            ),
            EnterpriseTelemetryCounter.APP_CREATED: meter.create_counter("dify.app.created.total", unit="{app}"),
            EnterpriseTelemetryCounter.APP_UPDATED: meter.create_counter("dify.app.updated.total", unit="{app}"),
            EnterpriseTelemetryCounter.APP_DELETED: meter.create_counter("dify.app.deleted.total", unit="{app}"),
        }
        self._histograms = {
            EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"),
            EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"),
            EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"),
            EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram(
                "dify.message.time_to_first_token", unit="s"
            ),
            EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"),
            EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram(
                "dify.prompt_generation.duration", unit="s"
            ),
        }

    @staticmethod
    def _remote_parent_context(
        trace_source: str | None,
        parent_span_source: str,
        link_kind: str,
        span_name: str,
    ) -> Context | None:
        """Build a context whose current span is a synthetic remote parent.

        Both IDs are derived from UUID strings. When either one cannot be
        parsed, a warning is logged and None is returned, so the span is still
        exported, just without a parent link.

        Args:
            trace_source: UUID string the parent's trace_id is derived from.
            parent_span_source: UUID string the parent's span_id is derived from.
            link_kind: Short label used in the warning (e.g. ``"child span"``).
            span_name: Name of the span being exported, for the warning.
        """
        try:
            parent_span_id = compute_deterministic_span_id(parent_span_source)
        except (ValueError, AttributeError):
            logger.warning(
                "Invalid parent span source for %s link: %s, span=%s",
                link_kind,
                parent_span_source,
                span_name,
            )
            return None
        try:
            parent_trace_id = int(uuid.UUID(trace_source)) if trace_source else 0
        except (ValueError, AttributeError):
            logger.warning(
                "Invalid trace correlation UUID for %s link: %s, span=%s",
                link_kind,
                trace_source,
                span_name,
            )
            return None
        if not parent_trace_id:
            # A zero trace_id is not a valid parent; export without a link.
            return None
        parent_span_context = SpanContext(
            trace_id=parent_trace_id,
            span_id=parent_span_id,
            is_remote=True,
            trace_flags=TraceFlags(TraceFlags.SAMPLED),
        )
        return trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))

    def export_span(
        self,
        name: str,
        attributes: dict[str, Any],
        correlation_id: str | None = None,
        span_id_source: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        trace_correlation_override: str | None = None,
        parent_span_id_source: str | None = None,
    ) -> None:
        """Export an OTEL span with optional deterministic IDs and real timestamps.

        Failures are logged with ``logger.exception`` and never raised to the caller.

        Args:
            name: Span operation name.
            attributes: Span attributes dict; entries whose value is None are skipped.
            correlation_id: Source for trace_id derivation (groups spans in one trace).
            span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id).
            start_time: Real span start time. When None, uses current time.
            end_time: Real span end time. When None, span ends immediately.
            trace_correlation_override: Override trace_id source (for cross-workflow linking).
                When set, trace_id is derived from this instead of ``correlation_id``.
            parent_span_id_source: Override parent span_id source (for cross-workflow linking).
                When set, parent span_id is derived from this value. When None and
                ``correlation_id`` is set, parent is the workflow root span.
        """
        effective_trace_correlation = trace_correlation_override or correlation_id
        # The ID generator reads these context variables while the span is created.
        set_correlation_id(effective_trace_correlation)
        set_span_id_source(span_id_source)

        try:
            parent_context: Context | None = None
            # A span is the "root" of its correlation group when span_id_source == correlation_id
            # (i.e. a workflow root span). All other spans are children.
            if parent_span_id_source:
                # Cross-workflow linking: parent is an explicit span (e.g. tool node in outer workflow)
                parent_context = self._remote_parent_context(
                    effective_trace_correlation, parent_span_id_source, "cross-workflow", name
                )
            elif correlation_id and correlation_id != span_id_source:
                # Child span: parent is the correlation-group root (workflow root span)
                parent_context = self._remote_parent_context(
                    effective_trace_correlation, correlation_id, "child span", name
                )

            # Convert both timestamps up front so a conversion error cannot
            # leave an already-started span open.
            span_start_ns = _datetime_to_ns(start_time) if start_time is not None else None
            span_end_ns = _datetime_to_ns(end_time) if end_time is not None else None

            with self._tracer.start_as_current_span(
                name,
                context=parent_context,
                start_time=span_start_ns,
                end_on_exit=span_end_ns is None,
            ) as span:
                try:
                    for key, value in attributes.items():
                        if value is not None:
                            span.set_attribute(key, value)
                finally:
                    # With an explicit end time the context manager does not end
                    # the span, so end it here even when setting attributes failed.
                    if span_end_ns is not None:
                        span.end(end_time=span_end_ns)
        except Exception:
            logger.exception("Failed to export span %s", name)
        finally:
            set_correlation_id(None)
            set_span_id_source(None)

    def increment_counter(
        self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue]
    ) -> None:
        """Add *value* to the counter registered under *name*; unknown names are ignored."""
        counter = self._counters.get(name)
        if counter is not None:
            counter.add(value, cast(Attributes, labels))

    def record_histogram(
        self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue]
    ) -> None:
        """Record *value* in the histogram registered under *name*; unknown names are ignored."""
        histogram = self._histograms.get(name)
        if histogram is not None:
            histogram.record(value, cast(Attributes, labels))

    def shutdown(self) -> None:
        """Shut down both providers; the meter provider is shut down even if the tracer provider fails."""
        try:
            self._tracer_provider.shutdown()
        finally:
            self._meter_provider.shutdown()
|
||||
76
api/enterprise/telemetry/id_generator.py
Normal file
76
api/enterprise/telemetry/id_generator.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Custom OTEL ID Generator for correlation-based trace/span ID derivation.
|
||||
|
||||
Uses contextvars for thread-safe correlation_id -> trace_id mapping.
|
||||
When a span_id_source is set, the span_id is derived deterministically
|
||||
from that value, enabling any span to reference another as parent
|
||||
without depending on span creation order.
|
||||
"""
|
||||
|
||||
import random
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import cast
|
||||
|
||||
from opentelemetry.sdk.trace.id_generator import IdGenerator
|
||||
|
||||
_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None)
|
||||
_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None)
|
||||
|
||||
|
||||
def set_correlation_id(correlation_id: str | None) -> None:
    """Store *correlation_id* in the current context; pass None to clear it.

    ``CorrelationIdGenerator.generate_trace_id`` reads this value.
    """
    _correlation_id_context.set(correlation_id)
|
||||
|
||||
|
||||
def get_correlation_id() -> str | None:
    """Return the correlation id stored in the current context, or None when unset."""
    current = _correlation_id_context.get()
    return current
|
||||
|
||||
|
||||
def set_span_id_source(source_id: str | None) -> None:
    """Store the value from which the next span_id is derived; pass None to clear it.

    While a source is set, ``CorrelationIdGenerator.generate_span_id()``
    derives the span_id from it (lower 64 bits of the UUID) instead of drawing
    a random one. Callers are expected to pass the ``workflow_run_id`` for
    workflow root spans or the ``node_execution_id`` for node spans.
    """
    _span_id_source_context.set(source_id)
|
||||
|
||||
|
||||
def compute_deterministic_span_id(source_id: str) -> int:
    """Map a UUID string to a stable, non-zero 64-bit span_id.

    The span_id is the low 64 bits of the UUID's integer value. If those bits
    are all zero, 1 is returned instead, because OTEL requires span_id != 0.

    Raises:
        ValueError: If *source_id* is not a valid UUID string.
    """
    low_64_bits_mask = (1 << 64) - 1
    candidate = cast(int, uuid.UUID(source_id).int) & low_64_bits_mask
    if candidate == 0:
        return 1
    return candidate
|
||||
|
||||
|
||||
class CorrelationIdGenerator(IdGenerator):
    """ID generator that derives trace_id and optionally span_id from context.

    - trace_id: derived from the correlation_id context variable when it holds a
      parseable, non-nil UUID (groups all spans in one trace), otherwise random
    - span_id: derived from span_id_source when set (enables deterministic
      parent-child linking), otherwise random

    Both methods always return a non-zero value, since an all-zero trace_id or
    span_id is invalid in OTEL.
    """

    def generate_trace_id(self) -> int:
        """Return the 128-bit trace_id for the span being created."""
        correlation_id = _correlation_id_context.get()
        if correlation_id:
            try:
                derived = cast(int, uuid.UUID(correlation_id).int)
            except (ValueError, AttributeError):
                derived = 0
            # A nil UUID maps to 0, which is not a valid trace_id; fall through
            # to a random one instead of emitting an invalid span context.
            if derived:
                return derived

        trace_id = random.getrandbits(128)
        while trace_id == 0:
            trace_id = random.getrandbits(128)
        return trace_id

    def generate_span_id(self) -> int:
        """Return the 64-bit span_id for the span being created."""
        source = _span_id_source_context.get()
        if source:
            try:
                return compute_deterministic_span_id(source)
            except (ValueError, AttributeError):
                # Unparseable source: fall back to a random span_id below.
                pass

        span_id = random.getrandbits(64)
        while span_id == 0:
            span_id = random.getrandbits(64)
        return span_id
|
||||
381
api/enterprise/telemetry/metric_handler.py
Normal file
381
api/enterprise/telemetry/metric_handler.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""Enterprise metric/log event handler.
|
||||
|
||||
This module processes metric and log telemetry events after they've been
|
||||
dequeued from the enterprise_telemetry Celery queue. It handles case routing,
|
||||
idempotency checking, and payload rehydration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnterpriseMetricHandler:
|
||||
"""Handler for enterprise metric and log telemetry events.
|
||||
|
||||
Processes envelopes from the enterprise_telemetry queue, routing each
|
||||
case to the appropriate handler method. Implements idempotency checking
|
||||
and payload rehydration with fallback.
|
||||
"""
|
||||
|
||||
def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None:
|
||||
"""Increment a diagnostic counter for operational monitoring.
|
||||
|
||||
Args:
|
||||
counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total').
|
||||
labels: Optional labels for the counter.
|
||||
"""
|
||||
try:
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
return
|
||||
|
||||
full_counter_name = f"enterprise_telemetry.handler.{counter_name}"
|
||||
logger.debug(
|
||||
"Diagnostic counter: %s, labels=%s",
|
||||
full_counter_name,
|
||||
labels or {},
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True)
|
||||
|
||||
def handle(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Main entry point for processing telemetry envelopes.
|
||||
|
||||
Args:
|
||||
envelope: The telemetry envelope to process.
|
||||
"""
|
||||
# Check for duplicate events
|
||||
if self._is_duplicate(envelope):
|
||||
logger.debug(
|
||||
"Skipping duplicate event: tenant_id=%s, event_id=%s",
|
||||
envelope.tenant_id,
|
||||
envelope.event_id,
|
||||
)
|
||||
self._increment_diagnostic_counter("deduped_total")
|
||||
return
|
||||
|
||||
# Route to appropriate handler based on case
|
||||
case = envelope.case
|
||||
if case == TelemetryCase.APP_CREATED:
|
||||
self._on_app_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
|
||||
elif case == TelemetryCase.APP_UPDATED:
|
||||
self._on_app_updated(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
|
||||
elif case == TelemetryCase.APP_DELETED:
|
||||
self._on_app_deleted(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
|
||||
elif case == TelemetryCase.FEEDBACK_CREATED:
|
||||
self._on_feedback_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
|
||||
elif case == TelemetryCase.MESSAGE_RUN:
|
||||
self._on_message_run(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
|
||||
elif case == TelemetryCase.TOOL_EXECUTION:
|
||||
self._on_tool_execution(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
|
||||
elif case == TelemetryCase.MODERATION_CHECK:
|
||||
self._on_moderation_check(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
|
||||
elif case == TelemetryCase.SUGGESTED_QUESTION:
|
||||
self._on_suggested_question(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
|
||||
elif case == TelemetryCase.DATASET_RETRIEVAL:
|
||||
self._on_dataset_retrieval(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
|
||||
elif case == TelemetryCase.GENERATE_NAME:
|
||||
self._on_generate_name(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
|
||||
elif case == TelemetryCase.PROMPT_GENERATION:
|
||||
self._on_prompt_generation(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
|
||||
case,
|
||||
envelope.tenant_id,
|
||||
envelope.event_id,
|
||||
)
|
||||
|
||||
def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool:
|
||||
"""Check if this event has already been processed.
|
||||
|
||||
Uses Redis with TTL for deduplication. Returns True if duplicate,
|
||||
False if first time seeing this event.
|
||||
|
||||
Args:
|
||||
envelope: The telemetry envelope to check.
|
||||
|
||||
Returns:
|
||||
True if this event_id has been seen before, False otherwise.
|
||||
"""
|
||||
dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}"
|
||||
|
||||
try:
|
||||
# Atomic set-if-not-exists with 1h TTL
|
||||
# Returns True if key was set (first time), None if already exists (duplicate)
|
||||
was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600)
|
||||
return was_set is None
|
||||
except Exception:
|
||||
# Fail open: if Redis is unavailable, process the event
|
||||
# (prefer occasional duplicate over lost data)
|
||||
logger.warning(
|
||||
"Redis unavailable for deduplication check, processing event anyway: %s",
|
||||
envelope.event_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return False
|
||||
|
||||
def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]:
|
||||
"""Rehydrate payload from storage reference or inline data.
|
||||
|
||||
If the envelope payload is empty and metadata contains a
|
||||
``payload_ref``, the full payload is loaded from object storage
|
||||
(where the gateway wrote it as JSON). When both the inline
|
||||
payload and storage resolution fail, a degraded-event marker
|
||||
is emitted so the gap is observable.
|
||||
|
||||
Args:
|
||||
envelope: The telemetry envelope containing payload data.
|
||||
|
||||
Returns:
|
||||
The rehydrated payload dictionary, or ``{}`` on total failure.
|
||||
"""
|
||||
payload = envelope.payload
|
||||
|
||||
# Resolve from object storage when the gateway offloaded a large payload.
|
||||
if not payload and envelope.metadata:
|
||||
payload_ref = envelope.metadata.get("payload_ref")
|
||||
if payload_ref:
|
||||
try:
|
||||
payload_bytes = storage.load(payload_ref)
|
||||
payload = json.loads(payload_bytes.decode("utf-8"))
|
||||
logger.debug("Loaded payload from storage: key=%s", payload_ref)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to load payload from storage: key=%s, event_id=%s",
|
||||
payload_ref,
|
||||
envelope.event_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if not payload:
|
||||
# Storage resolution failed or no data available — emit degraded event.
|
||||
logger.error(
|
||||
"Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s",
|
||||
envelope.event_id,
|
||||
envelope.tenant_id,
|
||||
envelope.case,
|
||||
)
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.REHYDRATION_FAILED,
|
||||
attributes={
|
||||
"dify.tenant_id": envelope.tenant_id,
|
||||
"dify.event_id": envelope.event_id,
|
||||
"dify.case": envelope.case,
|
||||
"rehydration_failed": True,
|
||||
},
|
||||
tenant_id=envelope.tenant_id,
|
||||
)
|
||||
self._increment_diagnostic_counter("rehydration_failed_total")
|
||||
return {}
|
||||
|
||||
return payload
|
||||
|
||||
# Stub methods for each metric/log case
|
||||
# These will be implemented in later tasks with actual emission logic
|
||||
|
||||
def _on_app_created(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle app created event."""
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id)
|
||||
return
|
||||
|
||||
payload = self._rehydrate(envelope)
|
||||
if not payload:
|
||||
return
|
||||
|
||||
attrs = {
|
||||
"dify.app.id": payload.get("app_id"),
|
||||
"dify.tenant_id": envelope.tenant_id,
|
||||
"dify.event.id": envelope.event_id,
|
||||
"dify.app.mode": payload.get("mode"),
|
||||
}
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.APP_CREATED,
|
||||
attributes=attrs,
|
||||
tenant_id=envelope.tenant_id,
|
||||
)
|
||||
exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.APP_CREATED,
|
||||
1,
|
||||
{
|
||||
"tenant_id": envelope.tenant_id,
|
||||
"app_id": str(payload.get("app_id", "")),
|
||||
"mode": str(payload.get("mode", "")),
|
||||
},
|
||||
)
|
||||
|
||||
def _on_app_updated(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle app updated event."""
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id)
|
||||
return
|
||||
|
||||
payload = self._rehydrate(envelope)
|
||||
if not payload:
|
||||
return
|
||||
|
||||
attrs = {
|
||||
"dify.app.id": payload.get("app_id"),
|
||||
"dify.tenant_id": envelope.tenant_id,
|
||||
"dify.event.id": envelope.event_id,
|
||||
}
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.APP_UPDATED,
|
||||
attributes=attrs,
|
||||
tenant_id=envelope.tenant_id,
|
||||
)
|
||||
exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.APP_UPDATED,
|
||||
1,
|
||||
{
|
||||
"tenant_id": envelope.tenant_id,
|
||||
"app_id": str(payload.get("app_id", "")),
|
||||
},
|
||||
)
|
||||
|
||||
def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle app deleted event."""
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
logger.debug("No exporter available for APP_DELETED: event_id=%s", envelope.event_id)
|
||||
return
|
||||
|
||||
payload = self._rehydrate(envelope)
|
||||
if not payload:
|
||||
return
|
||||
|
||||
attrs = {
|
||||
"dify.app.id": payload.get("app_id"),
|
||||
"dify.tenant_id": envelope.tenant_id,
|
||||
"dify.event.id": envelope.event_id,
|
||||
}
|
||||
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.APP_DELETED,
|
||||
attributes=attrs,
|
||||
tenant_id=envelope.tenant_id,
|
||||
)
|
||||
exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.APP_DELETED,
|
||||
1,
|
||||
{
|
||||
"tenant_id": envelope.tenant_id,
|
||||
"app_id": str(payload.get("app_id", "")),
|
||||
},
|
||||
)
|
||||
|
||||
def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle feedback created event."""
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event
|
||||
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
|
||||
|
||||
exporter = get_enterprise_exporter()
|
||||
if not exporter:
|
||||
logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id)
|
||||
return
|
||||
|
||||
payload = self._rehydrate(envelope)
|
||||
if not payload:
|
||||
return
|
||||
|
||||
include_content = exporter.include_content
|
||||
attrs: dict = {
|
||||
"dify.message.id": payload.get("message_id"),
|
||||
"dify.tenant_id": envelope.tenant_id,
|
||||
"dify.event.id": envelope.event_id,
|
||||
"dify.app_id": payload.get("app_id"),
|
||||
"dify.conversation.id": payload.get("conversation_id"),
|
||||
"gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"),
|
||||
"dify.feedback.rating": payload.get("rating"),
|
||||
"dify.feedback.from_source": payload.get("from_source"),
|
||||
}
|
||||
if include_content:
|
||||
attrs["dify.feedback.content"] = payload.get("content")
|
||||
|
||||
user_id = payload.get("from_end_user_id") or payload.get("from_account_id")
|
||||
emit_metric_only_event(
|
||||
event_name=EnterpriseTelemetryEvent.FEEDBACK_CREATED,
|
||||
attributes=attrs,
|
||||
tenant_id=envelope.tenant_id,
|
||||
user_id=str(user_id or ""),
|
||||
)
|
||||
exporter.increment_counter(
|
||||
EnterpriseTelemetryCounter.FEEDBACK,
|
||||
1,
|
||||
{
|
||||
"tenant_id": envelope.tenant_id,
|
||||
"app_id": str(payload.get("app_id", "")),
|
||||
"rating": str(payload.get("rating", "")),
|
||||
},
|
||||
)
|
||||
|
||||
def _on_message_run(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle message run event (stub)."""
|
||||
logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id)
|
||||
|
||||
def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle tool execution event (stub)."""
|
||||
logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id)
|
||||
|
||||
def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle moderation check event (stub)."""
|
||||
logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id)
|
||||
|
||||
def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None:
|
||||
"""Handle suggested question event (stub)."""
|
||||
logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id)
|
||||
|
||||
def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None:
    """Stub handler for DATASET_RETRIEVAL envelopes.

    Only writes a debug log line carrying the envelope's event id.
    """
    event_id = envelope.event_id
    logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", event_id)
|
||||
|
||||
def _on_generate_name(self, envelope: TelemetryEnvelope) -> None:
    """Stub handler for GENERATE_NAME envelopes.

    Only writes a debug log line carrying the envelope's event id.
    """
    event_id = envelope.event_id
    logger.debug("Processing GENERATE_NAME: event_id=%s", event_id)
|
||||
|
||||
def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None:
    """Stub handler for PROMPT_GENERATION envelopes.

    Only writes a debug log line carrying the envelope's event id.
    """
    event_id = envelope.event_id
    logger.debug("Processing PROMPT_GENERATION: event_id=%s", event_id)
|
||||
122
api/enterprise/telemetry/telemetry_log.py
Normal file
122
api/enterprise/telemetry/telemetry_log.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Structured-log emitter for enterprise telemetry events.
|
||||
|
||||
Emits structured JSON log lines correlated with OTEL traces via trace_id.
|
||||
Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
|
||||
logger = logging.getLogger("dify.telemetry")
|
||||
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def compute_trace_id_hex(uuid_str: str | None) -> str:
|
||||
"""Convert a business UUID string to a 32-hex OTEL-compatible trace_id.
|
||||
|
||||
Returns empty string when *uuid_str* is ``None`` or invalid.
|
||||
"""
|
||||
if not uuid_str:
|
||||
return ""
|
||||
normalized = uuid_str.strip().lower()
|
||||
if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized):
|
||||
return normalized
|
||||
try:
|
||||
return f"{uuid.UUID(normalized).int:032x}"
|
||||
except (ValueError, AttributeError):
|
||||
return ""
|
||||
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def compute_span_id_hex(uuid_str: str | None) -> str:
|
||||
if not uuid_str:
|
||||
return ""
|
||||
normalized = uuid_str.strip().lower()
|
||||
if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized):
|
||||
return normalized
|
||||
try:
|
||||
from enterprise.telemetry.id_generator import compute_deterministic_span_id
|
||||
|
||||
return f"{compute_deterministic_span_id(normalized):016x}"
|
||||
except (ValueError, AttributeError):
|
||||
return ""
|
||||
|
||||
|
||||
def emit_telemetry_log(
    *,
    event_name: str | EnterpriseTelemetryEvent,
    attributes: dict[str, Any],
    signal: str = "metric_only",
    trace_id_source: str | None = None,
    span_id_source: str | None = None,
    tenant_id: str | None = None,
    user_id: str | None = None,
) -> None:
    """Write one structured INFO log record describing a telemetry event.

    Nothing is built or emitted when the telemetry logger has INFO disabled.

    Parameters
    ----------
    event_name:
        Canonical event name, e.g. ``"dify.workflow.run"``.
    attributes:
        Event-specific attributes prepared by the caller. They are merged
        after the event name/signal keys, so they win on a key clash.
    signal:
        ``"metric_only"`` for events with no span, ``"span_detail"``
        for detail logs accompanying a slim span.
    trace_id_source:
        A UUID string (e.g. ``workflow_run_id``) from which a 32-hex
        trace_id is derived; omitted from the record when it yields "".
    span_id_source:
        Identifier string from which a 16-hex span_id is derived; omitted
        from the record when it yields "".
    tenant_id:
        Tenant identifier, attached to the record only when truthy.
    user_id:
        User identifier, attached to the record only when truthy.
    """
    if not logger.isEnabledFor(logging.INFO):
        return

    payload: dict[str, Any] = {
        "dify.event.name": event_name,
        "dify.event.signal": signal,
    }
    payload.update(attributes)

    record_fields: dict[str, Any] = {"attributes": payload}

    # Optional correlation / identity fields, in the order they are attached.
    optional_fields = (
        ("trace_id", compute_trace_id_hex(trace_id_source)),
        ("span_id", compute_span_id_hex(span_id_source)),
        ("tenant_id", tenant_id),
        ("user_id", user_id),
    )
    for field_name, field_value in optional_fields:
        if field_value:
            record_fields[field_name] = field_value

    logger.info("telemetry.%s", signal, extra=record_fields)
|
||||
|
||||
|
||||
def emit_metric_only_event(
    *,
    event_name: str | EnterpriseTelemetryEvent,
    attributes: dict[str, Any],
    trace_id_source: str | None = None,
    span_id_source: str | None = None,
    tenant_id: str | None = None,
    user_id: str | None = None,
) -> None:
    """Emit a telemetry log with the signal fixed to ``"metric_only"``.

    Convenience wrapper around ``emit_telemetry_log``; every argument is
    forwarded unchanged.
    """
    forwarded: dict[str, Any] = {
        "event_name": event_name,
        "attributes": attributes,
        "trace_id_source": trace_id_source,
        "span_id_source": span_id_source,
        "tenant_id": tenant_id,
        "user_id": user_id,
    }
    emit_telemetry_log(signal="metric_only", **forwarded)
|
||||
@@ -3,6 +3,12 @@ from blinker import signal
|
||||
# sender: app
|
||||
app_was_created = signal("app-was-created")
|
||||
|
||||
# sender: app
|
||||
app_was_deleted = signal("app-was-deleted")
|
||||
|
||||
# sender: app
|
||||
app_was_updated = signal("app-was-updated")
|
||||
|
||||
# sender: app, kwargs: app_model_config
|
||||
app_model_config_was_updated = signal("app-model-config-was-updated")
|
||||
|
||||
|
||||
4
api/events/feedback_event.py
Normal file
4
api/events/feedback_event.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from blinker import signal
|
||||
|
||||
# sender: MessageFeedback, kwargs: tenant_id
|
||||
feedback_was_created = signal("feedback-was-created")
|
||||
@@ -184,14 +184,8 @@ def init_app(app: DifyApp) -> Celery:
|
||||
"task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh",
|
||||
"schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL),
|
||||
}
|
||||
|
||||
if dify_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK:
|
||||
imports.append("schedule.update_api_token_last_used_task")
|
||||
beat_schedule["batch_update_api_token_last_used"] = {
|
||||
"task": "schedule.update_api_token_last_used_task.batch_update_api_token_last_used",
|
||||
"schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL),
|
||||
}
|
||||
|
||||
if dify_config.ENTERPRISE_TELEMETRY_ENABLED:
|
||||
imports.append("tasks.enterprise_telemetry_task")
|
||||
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
|
||||
|
||||
return celery_app
|
||||
|
||||
50
api/extensions/ext_enterprise_telemetry.py
Normal file
50
api/extensions/ext_enterprise_telemetry.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Flask extension for enterprise telemetry lifecycle management.
|
||||
|
||||
Initializes the EnterpriseExporter singleton during ``create_app()``
|
||||
(single-threaded), registers blinker event handlers, and hooks atexit
|
||||
for graceful shutdown.
|
||||
|
||||
Skipped entirely unless both ``ENTERPRISE_ENABLED`` and ``ENTERPRISE_TELEMETRY_ENABLED``
|
||||
are true (``is_enabled()`` gate).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_app import DifyApp
|
||||
from enterprise.telemetry.exporter import EnterpriseExporter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_exporter: EnterpriseExporter | None = None
|
||||
|
||||
|
||||
def is_enabled() -> bool:
    """Return True only when both enterprise mode and enterprise telemetry are switched on."""
    if not dify_config.ENTERPRISE_ENABLED:
        return False
    return bool(dify_config.ENTERPRISE_TELEMETRY_ENABLED)
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
    """Set up enterprise telemetry for the application, if it is enabled.

    When ``is_enabled()`` is true this creates the module-level
    ``EnterpriseExporter``, registers its ``shutdown`` with ``atexit`` and
    imports the event-handler module so its signal receivers get connected.
    Otherwise it does nothing.
    """
    global _exporter

    if is_enabled():
        # Deferred import: the exporter module is only loaded when telemetry is on.
        from enterprise.telemetry.exporter import EnterpriseExporter

        exporter = EnterpriseExporter(dify_config)
        _exporter = exporter
        atexit.register(exporter.shutdown)

        # Import to trigger @signal.connect decorator registration
        import enterprise.telemetry.event_handlers  # noqa: F401  # type: ignore[reportUnusedImport]

        logger.info("Enterprise telemetry initialized")
|
||||
|
||||
|
||||
def get_enterprise_exporter() -> EnterpriseExporter | None:
    """Return the module-level exporter, or ``None`` if ``init_app`` has not created one."""
    current_exporter = _exporter
    return current_exporter
|
||||
@@ -59,16 +59,24 @@ def init_app(app: DifyApp):
|
||||
protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower()
|
||||
if dify_config.OTEL_EXPORTER_TYPE == "otlp":
|
||||
if protocol == "grpc":
|
||||
# Auto-detect TLS: https:// uses secure, everything else is insecure
|
||||
endpoint = dify_config.OTLP_BASE_ENDPOINT
|
||||
insecure = not endpoint.startswith("https://")
|
||||
|
||||
exporter = GRPCSpanExporter(
|
||||
endpoint=dify_config.OTLP_BASE_ENDPOINT,
|
||||
endpoint=endpoint,
|
||||
# Header field names must consist of lowercase letters, check RFC7540
|
||||
headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
|
||||
insecure=True,
|
||||
headers=(
|
||||
(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else None
|
||||
),
|
||||
insecure=insecure,
|
||||
)
|
||||
metric_exporter = GRPCMetricExporter(
|
||||
endpoint=dify_config.OTLP_BASE_ENDPOINT,
|
||||
headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
|
||||
insecure=True,
|
||||
endpoint=endpoint,
|
||||
headers=(
|
||||
(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else None
|
||||
),
|
||||
insecure=insecure,
|
||||
)
|
||||
else:
|
||||
headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None
|
||||
|
||||
@@ -5,7 +5,7 @@ This module provides parsers that extract node-specific metadata and set
|
||||
OpenTelemetry span attributes according to semantic conventions.
|
||||
"""
|
||||
|
||||
from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps
|
||||
from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps, should_include_content
|
||||
from extensions.otel.parser.llm import LLMNodeOTelParser
|
||||
from extensions.otel.parser.retrieval import RetrievalNodeOTelParser
|
||||
from extensions.otel.parser.tool import ToolNodeOTelParser
|
||||
@@ -17,4 +17,5 @@ __all__ = [
|
||||
"RetrievalNodeOTelParser",
|
||||
"ToolNodeOTelParser",
|
||||
"safe_json_dumps",
|
||||
"should_include_content",
|
||||
]
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
"""
|
||||
Base parser interface and utilities for OpenTelemetry node parsers.
|
||||
|
||||
Content gating: ``should_include_content()`` controls whether content-bearing
|
||||
span attributes (inputs, outputs, prompts, completions, documents) are written.
|
||||
Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when
|
||||
``ENTERPRISE_INCLUDE_CONTENT=False``; CE behaviour is unchanged.
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -9,6 +14,7 @@ from opentelemetry.trace import Span
|
||||
from opentelemetry.trace.status import Status, StatusCode
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.file.models import File
|
||||
from core.variables import Segment
|
||||
from core.workflow.enums import NodeType
|
||||
@@ -17,6 +23,17 @@ from core.workflow.nodes.base.node import Node
|
||||
from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes
|
||||
|
||||
|
||||
def should_include_content() -> bool:
    """Tell span writers whether content-bearing attributes may be recorded.

    With ``ENTERPRISE_ENABLED`` off (CE) the answer is always True, so
    behaviour is unchanged. With it on (EE) the answer is whatever
    ``ENTERPRISE_INCLUDE_CONTENT`` is set to.
    """
    if dify_config.ENTERPRISE_ENABLED:
        return dify_config.ENTERPRISE_INCLUDE_CONTENT
    return True
|
||||
|
||||
|
||||
def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
|
||||
"""
|
||||
Safely serialize objects to JSON, handling non-serializable types.
|
||||
@@ -105,10 +122,11 @@ class DefaultNodeOTelParser:
|
||||
# Extract inputs and outputs from result_event
|
||||
if result_event and result_event.node_run_result:
|
||||
node_run_result = result_event.node_run_result
|
||||
if node_run_result.inputs:
|
||||
span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
|
||||
if node_run_result.outputs:
|
||||
span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
|
||||
if should_include_content():
|
||||
if node_run_result.inputs:
|
||||
span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
|
||||
if node_run_result.outputs:
|
||||
span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
|
||||
|
||||
if error:
|
||||
span.record_exception(error)
|
||||
|
||||
@@ -10,7 +10,7 @@ from opentelemetry.trace import Span
|
||||
|
||||
from core.workflow.graph_events import GraphNodeEventBase
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
|
||||
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps, should_include_content
|
||||
from extensions.otel.semconv.gen_ai import LLMAttributes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -132,24 +132,19 @@ class LLMNodeOTelParser:
|
||||
span.set_attribute(LLMAttributes.USAGE_OUTPUT_TOKENS, completion_tokens)
|
||||
span.set_attribute(LLMAttributes.USAGE_TOTAL_TOKENS, total_tokens)
|
||||
|
||||
# Prompts and completion
|
||||
prompts = process_data.get("prompts", [])
|
||||
if prompts:
|
||||
prompts_json = safe_json_dumps(prompts)
|
||||
span.set_attribute(LLMAttributes.PROMPT, prompts_json)
|
||||
# Prompts and completion — gated by content policy
|
||||
if should_include_content():
|
||||
prompts = process_data.get("prompts", [])
|
||||
if prompts:
|
||||
prompts_json = safe_json_dumps(prompts)
|
||||
span.set_attribute(LLMAttributes.PROMPT, prompts_json)
|
||||
|
||||
text_output = str(outputs.get("text", ""))
|
||||
if text_output:
|
||||
span.set_attribute(LLMAttributes.COMPLETION, text_output)
|
||||
text_output = str(outputs.get("text", ""))
|
||||
if text_output:
|
||||
span.set_attribute(LLMAttributes.COMPLETION, text_output)
|
||||
|
||||
# Finish reason
|
||||
finish_reason = outputs.get("finish_reason") or ""
|
||||
if finish_reason:
|
||||
span.set_attribute(LLMAttributes.RESPONSE_FINISH_REASON, finish_reason)
|
||||
|
||||
# Structured input/output messages
|
||||
gen_ai_input_message = _format_input_messages(process_data)
|
||||
gen_ai_output_message = _format_output_messages(outputs)
|
||||
|
||||
span.set_attribute(LLMAttributes.INPUT_MESSAGE, gen_ai_input_message)
|
||||
span.set_attribute(LLMAttributes.OUTPUT_MESSAGE, gen_ai_output_message)
|
||||
# Structured input/output messages
|
||||
gen_ai_input_message = _format_input_messages(process_data)
|
||||
gen_ai_output_message = _format_output_messages(outputs)
|
||||
span.set_attribute(LLMAttributes.INPUT_MESSAGE, gen_ai_input_message)
|
||||
span.set_attribute(LLMAttributes.OUTPUT_MESSAGE, gen_ai_output_message)
|
||||
|
||||
@@ -11,7 +11,7 @@ from opentelemetry.trace import Span
|
||||
from core.variables import Segment
|
||||
from core.workflow.graph_events import GraphNodeEventBase
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
|
||||
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps, should_include_content
|
||||
from extensions.otel.semconv.gen_ai import RetrieverAttributes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -83,23 +83,21 @@ class RetrievalNodeOTelParser:
|
||||
inputs = node_run_result.inputs or {}
|
||||
outputs = node_run_result.outputs or {}
|
||||
|
||||
# Extract query from inputs
|
||||
query = str(inputs.get("query", "")) if inputs else ""
|
||||
if query:
|
||||
span.set_attribute(RetrieverAttributes.QUERY, query)
|
||||
# Query and documents — gated by content policy
|
||||
if should_include_content():
|
||||
query = str(inputs.get("query", "")) if inputs else ""
|
||||
if query:
|
||||
span.set_attribute(RetrieverAttributes.QUERY, query)
|
||||
|
||||
# Extract and format retrieval documents from outputs
|
||||
result_value = outputs.get("result") if outputs else None
|
||||
retrieval_documents: list[Any] = []
|
||||
if result_value:
|
||||
value_to_check = result_value
|
||||
if isinstance(result_value, Segment):
|
||||
value_to_check = result_value.value
|
||||
|
||||
if isinstance(value_to_check, (list, Sequence)):
|
||||
retrieval_documents = list(value_to_check)
|
||||
|
||||
if retrieval_documents:
|
||||
semantic_retrieval_documents = _format_retrieval_documents(retrieval_documents)
|
||||
semantic_retrieval_documents_json = safe_json_dumps(semantic_retrieval_documents)
|
||||
span.set_attribute(RetrieverAttributes.DOCUMENT, semantic_retrieval_documents_json)
|
||||
result_value = outputs.get("result") if outputs else None
|
||||
retrieval_documents: list[Any] = []
|
||||
if result_value:
|
||||
value_to_check = result_value
|
||||
if isinstance(result_value, Segment):
|
||||
value_to_check = result_value.value
|
||||
if isinstance(value_to_check, (list, Sequence)):
|
||||
retrieval_documents = list(value_to_check)
|
||||
if retrieval_documents:
|
||||
semantic_retrieval_documents = _format_retrieval_documents(retrieval_documents)
|
||||
semantic_retrieval_documents_json = safe_json_dumps(semantic_retrieval_documents)
|
||||
span.set_attribute(RetrieverAttributes.DOCUMENT, semantic_retrieval_documents_json)
|
||||
|
||||
@@ -8,7 +8,7 @@ from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_events import GraphNodeEventBase
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
|
||||
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps, should_include_content
|
||||
from extensions.otel.semconv.gen_ai import ToolAttributes
|
||||
|
||||
|
||||
@@ -40,8 +40,14 @@ class ToolNodeOTelParser:
|
||||
if tool_info:
|
||||
span.set_attribute(ToolAttributes.TOOL_DESCRIPTION, safe_json_dumps(tool_info))
|
||||
|
||||
if result_event and result_event.node_run_result and result_event.node_run_result.inputs:
|
||||
span.set_attribute(ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs))
|
||||
# Tool call arguments and result — gated by content policy
|
||||
if should_include_content():
|
||||
if result_event and result_event.node_run_result and result_event.node_run_result.inputs:
|
||||
span.set_attribute(
|
||||
ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs)
|
||||
)
|
||||
|
||||
if result_event and result_event.node_run_result and result_event.node_run_result.outputs:
|
||||
span.set_attribute(ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs))
|
||||
if result_event and result_event.node_run_result and result_event.node_run_result.outputs:
|
||||
span.set_attribute(
|
||||
ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs)
|
||||
)
|
||||
|
||||
@@ -21,3 +21,15 @@ class DifySpanAttributes:
|
||||
|
||||
INVOKE_FROM = "dify.invoke_from"
|
||||
"""Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER."""
|
||||
|
||||
INVOKED_BY = "dify.invoked_by"
|
||||
"""Invoked by, e.g. end_user, account, user."""
|
||||
|
||||
USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
|
||||
"""Number of input tokens (prompt tokens) used."""
|
||||
|
||||
USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
|
||||
"""Number of output tokens (completion tokens) generated."""
|
||||
|
||||
USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
|
||||
"""Total number of tokens used."""
|
||||
|
||||
213
api/libs/db_migration_lock.py
Normal file
213
api/libs/db_migration_lock.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
DB migration Redis lock with heartbeat renewal.
|
||||
|
||||
This is intentionally migration-specific. Background renewal is a trade-off that makes sense
|
||||
for unbounded, blocking operations like DB migrations (DDL/DML) where the main thread cannot
|
||||
periodically refresh the lock TTL.
|
||||
|
||||
Do NOT use this as a general-purpose lock primitive for normal application code. Prefer explicit
|
||||
lock lifecycle management (e.g. redis-py Lock context manager + `extend()` / `reacquire()` from
|
||||
the same thread) when execution flow is under control.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from redis.exceptions import LockNotOwnedError, RedisError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MIN_RENEW_INTERVAL_SECONDS = 0.1
|
||||
DEFAULT_RENEW_INTERVAL_DIVISOR = 3
|
||||
MIN_JOIN_TIMEOUT_SECONDS = 0.5
|
||||
MAX_JOIN_TIMEOUT_SECONDS = 5.0
|
||||
JOIN_TIMEOUT_MULTIPLIER = 2.0
|
||||
|
||||
|
||||
class DbMigrationAutoRenewLock:
|
||||
"""
|
||||
Redis lock wrapper that automatically renews TTL while held (migration-only).
|
||||
|
||||
Notes:
|
||||
- We force `thread_local=False` when creating the underlying redis-py lock, because the
|
||||
lock token must be accessible from the heartbeat thread for `reacquire()` to work.
|
||||
- `release_safely()` is best-effort: it never raises, so it won't mask the caller's
|
||||
primary error/exit code.
|
||||
"""
|
||||
|
||||
_redis_client: Any
|
||||
_name: str
|
||||
_ttl_seconds: float
|
||||
_renew_interval_seconds: float
|
||||
_log_context: str | None
|
||||
_logger: logging.Logger
|
||||
|
||||
_lock: Any
|
||||
_stop_event: threading.Event | None
|
||||
_thread: threading.Thread | None
|
||||
_acquired: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Any,
|
||||
name: str,
|
||||
ttl_seconds: float = 60,
|
||||
renew_interval_seconds: float | None = None,
|
||||
*,
|
||||
logger: logging.Logger | None = None,
|
||||
log_context: str | None = None,
|
||||
) -> None:
|
||||
self._redis_client = redis_client
|
||||
self._name = name
|
||||
self._ttl_seconds = float(ttl_seconds)
|
||||
self._renew_interval_seconds = (
|
||||
float(renew_interval_seconds)
|
||||
if renew_interval_seconds is not None
|
||||
else max(MIN_RENEW_INTERVAL_SECONDS, self._ttl_seconds / DEFAULT_RENEW_INTERVAL_DIVISOR)
|
||||
)
|
||||
self._logger = logger or logging.getLogger(__name__)
|
||||
self._log_context = log_context
|
||||
|
||||
self._lock = None
|
||||
self._stop_event = None
|
||||
self._thread = None
|
||||
self._acquired = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def acquire(self, *args: Any, **kwargs: Any) -> bool:
|
||||
"""
|
||||
Acquire the lock and start heartbeat renewal on success.
|
||||
|
||||
Accepts the same args/kwargs as redis-py `Lock.acquire()`.
|
||||
"""
|
||||
# Prevent accidental double-acquire which could leave the previous heartbeat thread running.
|
||||
if self._acquired:
|
||||
raise RuntimeError("DB migration lock is already acquired; call release_safely() before acquiring again.")
|
||||
|
||||
# Reuse the lock object if we already created one.
|
||||
if self._lock is None:
|
||||
self._lock = self._redis_client.lock(
|
||||
name=self._name,
|
||||
timeout=self._ttl_seconds,
|
||||
thread_local=False,
|
||||
)
|
||||
acquired = bool(self._lock.acquire(*args, **kwargs))
|
||||
self._acquired = acquired
|
||||
if acquired:
|
||||
self._start_heartbeat()
|
||||
return acquired
|
||||
|
||||
def owned(self) -> bool:
|
||||
if self._lock is None:
|
||||
return False
|
||||
try:
|
||||
return bool(self._lock.owned())
|
||||
except Exception:
|
||||
# Ownership checks are best-effort and must not break callers.
|
||||
return False
|
||||
|
||||
def _start_heartbeat(self) -> None:
|
||||
if self._lock is None:
|
||||
return
|
||||
if self._stop_event is not None:
|
||||
return
|
||||
|
||||
self._stop_event = threading.Event()
|
||||
self._thread = threading.Thread(
|
||||
target=self._heartbeat_loop,
|
||||
args=(self._lock, self._stop_event),
|
||||
daemon=True,
|
||||
name=f"DbMigrationAutoRenewLock({self._name})",
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
|
||||
while not stop_event.wait(self._renew_interval_seconds):
|
||||
try:
|
||||
lock.reacquire()
|
||||
except LockNotOwnedError:
|
||||
self._logger.warning(
|
||||
"DB migration lock is no longer owned during heartbeat; stop renewing. log_context=%s",
|
||||
self._log_context,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
except RedisError:
|
||||
self._logger.warning(
|
||||
"Failed to renew DB migration lock due to Redis error; will retry. log_context=%s",
|
||||
self._log_context,
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception:
|
||||
self._logger.warning(
|
||||
"Unexpected error while renewing DB migration lock; will retry. log_context=%s",
|
||||
self._log_context,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def release_safely(self, *, status: str | None = None) -> None:
|
||||
"""
|
||||
Stop heartbeat and release lock. Never raises.
|
||||
|
||||
Args:
|
||||
status: Optional caller-provided status (e.g. 'successful'/'failed') to add context to logs.
|
||||
"""
|
||||
lock = self._lock
|
||||
if lock is None:
|
||||
return
|
||||
|
||||
self._stop_heartbeat()
|
||||
|
||||
# Lock release errors should never mask the real error/exit code.
|
||||
try:
|
||||
lock.release()
|
||||
except LockNotOwnedError:
|
||||
self._logger.warning(
|
||||
"DB migration lock not owned on release; ignoring. status=%s log_context=%s",
|
||||
status,
|
||||
self._log_context,
|
||||
exc_info=True,
|
||||
)
|
||||
except RedisError:
|
||||
self._logger.warning(
|
||||
"Failed to release DB migration lock due to Redis error; ignoring. status=%s log_context=%s",
|
||||
status,
|
||||
self._log_context,
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception:
|
||||
self._logger.warning(
|
||||
"Unexpected error while releasing DB migration lock; ignoring. status=%s log_context=%s",
|
||||
status,
|
||||
self._log_context,
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
self._acquired = False
|
||||
self._lock = None
|
||||
|
||||
def _stop_heartbeat(self) -> None:
|
||||
if self._stop_event is None:
|
||||
return
|
||||
self._stop_event.set()
|
||||
if self._thread is not None:
|
||||
# Best-effort join: if Redis calls are blocked, the daemon thread may remain alive.
|
||||
join_timeout_seconds = max(
|
||||
MIN_JOIN_TIMEOUT_SECONDS,
|
||||
min(MAX_JOIN_TIMEOUT_SECONDS, self._renew_interval_seconds * JOIN_TIMEOUT_MULTIPLIER),
|
||||
)
|
||||
self._thread.join(timeout=join_timeout_seconds)
|
||||
if self._thread.is_alive():
|
||||
self._logger.warning(
|
||||
"DB migration lock heartbeat thread did not stop within %.2fs; ignoring. log_context=%s",
|
||||
join_timeout_seconds,
|
||||
self._log_context,
|
||||
)
|
||||
self._stop_event = None
|
||||
self._thread = None
|
||||
@@ -8,11 +8,6 @@ Create Date: 2025-12-25 10:39:15.139304
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '7df29de0f6be'
|
||||
@@ -23,31 +18,16 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
op.create_table('tenant_credit_pools',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
|
||||
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
|
||||
sa.Column('quota_used', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
|
||||
)
|
||||
else:
|
||||
# For MySQL and other databases, UUID should be generated at application level
|
||||
op.create_table('tenant_credit_pools',
|
||||
sa.Column('id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
|
||||
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
|
||||
sa.Column('quota_used', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
|
||||
)
|
||||
op.create_table('tenant_credit_pools',
|
||||
sa.Column('id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
|
||||
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
|
||||
sa.Column('quota_used', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
|
||||
)
|
||||
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
|
||||
batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
|
||||
batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
|
||||
|
||||
@@ -8,7 +8,6 @@ Create Date: 2026-01-017 11:10:18.079355
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'f9f6d18a37f9'
|
||||
@@ -20,7 +19,7 @@ depends_on = None
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('account_trial_app_records',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('account_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('count', sa.Integer(), nullable=False),
|
||||
@@ -33,17 +32,17 @@ def upgrade():
|
||||
batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
|
||||
|
||||
op.create_table('exporle_banners',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('content', sa.JSON(), nullable=False),
|
||||
sa.Column('link', sa.String(length=255), nullable=False),
|
||||
sa.Column('sort', sa.Integer(), nullable=False),
|
||||
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
|
||||
sa.Column('status', sa.String(length=255), server_default='enabled', nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
|
||||
sa.Column('language', sa.String(length=255), server_default='en-US', nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
|
||||
)
|
||||
op.create_table('trial_apps',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
|
||||
@@ -620,7 +620,7 @@ class TrialApp(Base):
|
||||
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
@@ -640,7 +640,7 @@ class AccountTrialAppRecord(Base):
|
||||
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
|
||||
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
|
||||
)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
count = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
@@ -660,19 +660,15 @@ class AccountTrialAppRecord(Base):
|
||||
class ExporleBanner(TypeBase):
|
||||
__tablename__ = "exporle_banners"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
|
||||
link: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
status: Mapped[str] = mapped_column(
|
||||
sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
|
||||
)
|
||||
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default="enabled", default="enabled")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
language: Mapped[str] = mapped_column(
|
||||
String(255), nullable=False, server_default=sa.text("'en-US'::character varying"), default="en-US"
|
||||
)
|
||||
language: Mapped[str] = mapped_column(String(255), nullable=False, server_default="en-US", default="en-US")
|
||||
|
||||
|
||||
class OAuthProviderApp(TypeBase):
|
||||
@@ -2166,9 +2162,7 @@ class TenantCreditPool(TypeBase):
|
||||
sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
|
||||
quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
|
||||
@@ -22,14 +22,14 @@ dependencies = [
|
||||
"flask-sqlalchemy~=3.1.1",
|
||||
"gevent~=25.9.1",
|
||||
"gmpy2~=2.2.1",
|
||||
"google-api-core==2.18.0",
|
||||
"google-api-core>=2.19.1",
|
||||
"google-api-python-client==2.90.0",
|
||||
"google-auth==2.29.0",
|
||||
"google-auth>=2.47.0",
|
||||
"google-auth-httplib2==0.2.0",
|
||||
"google-cloud-aiplatform==1.49.0",
|
||||
"googleapis-common-protos==1.63.0",
|
||||
"google-cloud-aiplatform>=1.123.0",
|
||||
"googleapis-common-protos>=1.65.0",
|
||||
"gunicorn~=23.0.0",
|
||||
"httpx[socks]~=0.27.0",
|
||||
"httpx[socks]~=0.28.0",
|
||||
"jieba==0.42.1",
|
||||
"json-repair>=0.55.1",
|
||||
"jsonschema>=4.25.1",
|
||||
@@ -41,26 +41,23 @@ dependencies = [
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.8.72",
|
||||
"litellm==1.77.1", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.27.0",
|
||||
"opentelemetry-distro==0.48b0",
|
||||
"opentelemetry-exporter-otlp==1.27.0",
|
||||
"opentelemetry-exporter-otlp-proto-common==1.27.0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc==1.27.0",
|
||||
"opentelemetry-exporter-otlp-proto-http==1.27.0",
|
||||
"opentelemetry-instrumentation==0.48b0",
|
||||
"opentelemetry-instrumentation-celery==0.48b0",
|
||||
"opentelemetry-instrumentation-flask==0.48b0",
|
||||
"opentelemetry-instrumentation-httpx==0.48b0",
|
||||
"opentelemetry-instrumentation-redis==0.48b0",
|
||||
"opentelemetry-instrumentation-httpx==0.48b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.48b0",
|
||||
"opentelemetry-propagator-b3==1.27.0",
|
||||
# opentelemetry-proto 1.28.0 depends on protobuf (>=5.0,<6.0),
|
||||
# which conflicts with googleapis-common-protos (1.63.0)
|
||||
"opentelemetry-proto==1.27.0",
|
||||
"opentelemetry-sdk==1.27.0",
|
||||
"opentelemetry-semantic-conventions==0.48b0",
|
||||
"opentelemetry-util-http==0.48b0",
|
||||
"opentelemetry-api==1.28.0",
|
||||
"opentelemetry-distro==0.49b0",
|
||||
"opentelemetry-exporter-otlp==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-common==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-http==1.28.0",
|
||||
"opentelemetry-instrumentation==0.49b0",
|
||||
"opentelemetry-instrumentation-celery==0.49b0",
|
||||
"opentelemetry-instrumentation-flask==0.49b0",
|
||||
"opentelemetry-instrumentation-httpx==0.49b0",
|
||||
"opentelemetry-instrumentation-redis==0.49b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.49b0",
|
||||
"opentelemetry-propagator-b3==1.28.0",
|
||||
"opentelemetry-proto==1.28.0",
|
||||
"opentelemetry-sdk==1.28.0",
|
||||
"opentelemetry-semantic-conventions==0.49b0",
|
||||
"opentelemetry-util-http==0.49b0",
|
||||
"pandas[excel,output-formatting,performance]~=2.2.2",
|
||||
"psycogreen~=1.0.2",
|
||||
"psycopg2-binary~=2.9.6",
|
||||
@@ -81,7 +78,7 @@ dependencies = [
|
||||
"starlette==0.49.1",
|
||||
"tiktoken~=0.9.0",
|
||||
"transformers~=4.56.1",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
|
||||
"yarl~=1.18.3",
|
||||
"webvtt-py~=0.5.1",
|
||||
"sseclient-py~=1.8.0",
|
||||
|
||||
@@ -1,24 +1,16 @@
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
|
||||
import click
|
||||
|
||||
import app
|
||||
from core.helper.marketplace import fetch_global_plugin_manifest
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
from tasks import process_tenant_plugin_autoupgrade_check_task as check_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL = 15 * 60 # 15 minutes
|
||||
MAX_CONCURRENT_CHECK_TASKS = 20
|
||||
|
||||
# Import cache constants from the task module
|
||||
CACHE_REDIS_KEY_PREFIX = check_task.CACHE_REDIS_KEY_PREFIX
|
||||
CACHE_REDIS_TTL = check_task.CACHE_REDIS_TTL
|
||||
|
||||
|
||||
@app.celery.task(queue="plugin")
|
||||
def check_upgradable_plugin_task():
|
||||
@@ -48,22 +40,6 @@ def check_upgradable_plugin_task():
|
||||
) # make sure all strategies are checked in this interval
|
||||
batch_interval_time = (AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL / batch_chunk_count) if batch_chunk_count > 0 else 0
|
||||
|
||||
if total_strategies == 0:
|
||||
click.echo(click.style("no strategies to process, skipping plugin manifest fetch.", fg="green"))
|
||||
return
|
||||
|
||||
# Fetch and cache all plugin manifests before processing tenants
|
||||
# This reduces load on marketplace from 300k requests to 1 request per check cycle
|
||||
logger.info("fetching global plugin manifest from marketplace")
|
||||
try:
|
||||
fetch_global_plugin_manifest(CACHE_REDIS_KEY_PREFIX, CACHE_REDIS_TTL)
|
||||
logger.info("successfully fetched and cached global plugin manifest")
|
||||
except Exception as e:
|
||||
logger.exception("failed to fetch global plugin manifest")
|
||||
click.echo(click.style(f"failed to fetch global plugin manifest: {e}", fg="red"))
|
||||
click.echo(click.style("skipping plugin upgrade check for this cycle", fg="yellow"))
|
||||
return
|
||||
|
||||
for i in range(0, total_strategies, MAX_CONCURRENT_CHECK_TASKS):
|
||||
batch_strategies = strategies[i : i + MAX_CONCURRENT_CHECK_TASKS]
|
||||
for strategy in batch_strategies:
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
"""
|
||||
Scheduled task to batch-update API token last_used_at timestamps.
|
||||
|
||||
Instead of updating the database on every request, token usage is recorded
|
||||
in Redis as lightweight SET keys (api_token_active:{scope}:{token}).
|
||||
This task runs periodically (default every 30 minutes) to flush those
|
||||
records into the database in a single batch operation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import click
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import ApiToken
|
||||
from services.api_token_service import ACTIVE_TOKEN_KEY_PREFIX
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.celery.task(queue="api_token")
def batch_update_api_token_last_used():
    """
    Flush API token usage timestamps from Redis into ``ApiToken.last_used_at``.

    Walks every Redis key that starts with ``ACTIVE_TOKEN_KEY_PREFIX``. Each key has
    the form ``api_token_active:{scope}:{token}`` and its value is an ISO-8601
    timestamp. For every well-formed entry the matching ``ApiToken`` row is updated,
    but only when its stored ``last_used_at`` is NULL or older than the recorded
    timestamp. After the updates, every scanned key (including empty, unparsable and
    malformed ones) is deleted from Redis.

    Any exception is logged and swallowed, so the task itself does not raise. The
    summary line is printed either way, except when no usable entry was found: in
    that case the task returns right after the "no active tokens found" notice.
    """
    click.echo(click.style("batch_update_api_token_last_used: start.", fg="green"))
    started = time.perf_counter()

    scanned = 0
    updated = 0

    try:
        pending: list[tuple[str, str | None, datetime]] = []  # (token, scope, usage_time)
        seen_keys: list[str | bytes] = []

        for redis_key in redis_client.scan_iter(match=f"{ACTIVE_TOKEN_KEY_PREFIX}*", count=200):
            if isinstance(redis_key, bytes):
                redis_key = redis_key.decode("utf-8")
            scanned += 1

            # Every scanned key is removed at the end, whether or not it yields an update.
            seen_keys.append(redis_key)

            # The value is the ISO timestamp written when the token was used.
            raw_value = redis_client.get(redis_key)
            if not raw_value:
                continue
            if isinstance(raw_value, bytes):
                raw_value = raw_value.decode("utf-8")

            try:
                used_at = datetime.fromisoformat(raw_value)
            except (ValueError, TypeError):
                logger.warning("Invalid timestamp in key %s: %s", redis_key, raw_value)
                continue

            # Key layout: api_token_active:{scope}:{token}
            scope_part, separator, token_part = redis_key[len(ACTIVE_TOKEN_KEY_PREFIX) :].partition(":")
            if separator:
                # The literal text "None" in the scope position stands for "no scope".
                if scope_part == "None":
                    pending.append((token_part, None, used_at))
                else:
                    pending.append((token_part, scope_part, used_at))

        if not pending:
            click.echo(click.style("batch_update_api_token_last_used: no active tokens found.", fg="yellow"))
            # Nothing to write, but unusable keys still get cleaned up.
            if seen_keys:
                redis_client.delete(*seen_keys)
            return

        # One short transaction per token, to avoid a single long-running transaction.
        for token, scope, used_at in pending:
            # Only move last_used_at forward; never overwrite a newer value.
            db_value_is_older = ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < used_at)
            touch_stmt = (
                update(ApiToken)
                .where(ApiToken.token == token, ApiToken.type == scope, db_value_is_older)
                .values(last_used_at=used_at)
            )
            with Session(db.engine, expire_on_commit=False) as session, session.begin():
                outcome = session.execute(touch_stmt)
                if getattr(outcome, "rowcount", 0) > 0:
                    updated += 1

        # All entries have been written; drop the processed keys from Redis.
        if seen_keys:
            redis_client.delete(*seen_keys)

    except Exception:
        logger.exception("batch_update_api_token_last_used failed")

    elapsed = time.perf_counter() - started
    click.echo(
        click.style(
            f"batch_update_api_token_last_used: done. "
            f"scanned={scanned}, updated={updated}, elapsed={elapsed:.2f}s",
            fg="green",
        )
    )
|
||||
@@ -74,6 +74,16 @@ from tasks.mail_reset_password_task import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _try_join_enterprise_default_workspace(account_id: str) -> None:
    """
    Ask the enterprise service to add the account to the default workspace.

    Does nothing unless ``dify_config.ENTERPRISE_ENABLED`` is truthy. Otherwise the
    call is delegated to ``try_join_default_workspace`` with the given account id.
    """
    if dify_config.ENTERPRISE_ENABLED:
        # Imported at call time rather than at module load
        # (presumably to avoid an import cycle; confirm before hoisting).
        from services.enterprise.enterprise_service import try_join_default_workspace

        try_join_default_workspace(account_id)
|
||||
|
||||
|
||||
class TokenPair(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
@@ -287,7 +297,14 @@ class AccountService:
|
||||
email=email, name=name, interface_language=interface_language, password=password
|
||||
)
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account)
|
||||
try:
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account)
|
||||
except Exception:
|
||||
# Enterprise-only side-effect should run independently from personal workspace creation.
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
raise
|
||||
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
|
||||
return account
|
||||
|
||||
@@ -330,12 +347,7 @@ class AccountService:
|
||||
# Queue account deletion sync tasks for all workspaces BEFORE account deletion (enterprise only)
|
||||
from services.enterprise.account_deletion_sync import sync_account_deletion
|
||||
|
||||
sync_success = sync_account_deletion(account_id=account.id, source="account_deleted")
|
||||
if not sync_success:
|
||||
logger.warning(
|
||||
"Enterprise account deletion sync failed for account %s; proceeding with local deletion.",
|
||||
account.id,
|
||||
)
|
||||
sync_account_deletion(account_id=account.id, source="account_deleted")
|
||||
|
||||
# Now proceed with async account deletion
|
||||
delete_account_task.delay(account.id)
|
||||
@@ -1244,15 +1256,7 @@ class TenantService:
|
||||
# Queue account deletion sync task for enterprise backend to reassign resources (enterprise only)
|
||||
from services.enterprise.account_deletion_sync import sync_workspace_member_removal
|
||||
|
||||
sync_success = sync_workspace_member_removal(
|
||||
workspace_id=tenant.id, member_id=account.id, source="workspace_member_removed"
|
||||
)
|
||||
if not sync_success:
|
||||
logger.warning(
|
||||
"Enterprise workspace member removal sync failed: workspace_id=%s, member_id=%s",
|
||||
tenant.id,
|
||||
account.id,
|
||||
)
|
||||
sync_workspace_member_removal(workspace_id=tenant.id, member_id=account.id, source="workspace_member_removed")
|
||||
|
||||
@staticmethod
|
||||
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
|
||||
@@ -1374,12 +1378,18 @@ class RegisterService:
|
||||
and create_workspace_required
|
||||
and FeatureService.get_system_features().license.workspaces.is_available()
|
||||
):
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
try:
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
except Exception:
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
raise
|
||||
|
||||
db.session.commit()
|
||||
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
db.session.rollback()
|
||||
logger.exception("Register failed")
|
||||
|
||||
@@ -1,330 +0,0 @@
|
||||
"""
|
||||
API Token Service
|
||||
|
||||
Handles all API token caching, validation, and usage recording.
|
||||
Includes Redis cache operations, database queries, and single-flight concurrency control.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import ApiToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Pydantic DTO
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
class CachedApiToken(BaseModel):
    """
    Plain-data snapshot of an API token as kept in the cache.

    This is a Pydantic model, not a SQLAlchemy instance. It exposes the same
    attribute names that are read from an ``ApiToken`` row, so code that only
    reads those attributes can use either object.
    """

    # Identifier of the token row, stored as a string.
    id: str
    # Owning app and tenant, when present.
    app_id: str | None
    tenant_id: str | None
    # Token type; used as the lookup scope alongside the token value.
    type: str
    # The token value itself.
    token: str
    # Timestamps copied from the source row at caching time.
    last_used_at: datetime | None
    created_at: datetime | None

    def __repr__(self) -> str:
        # Only id and type are shown; the token value is left out of the repr.
        token_id, token_type = self.id, self.type
        return f"<CachedApiToken id={token_id} type={token_type}>"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Cache configuration
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
CACHE_KEY_PREFIX = "api_token"
|
||||
CACHE_TTL_SECONDS = 600 # 10 minutes
|
||||
CACHE_NULL_TTL_SECONDS = 60 # 1 minute for non-existent tokens
|
||||
ACTIVE_TOKEN_KEY_PREFIX = "api_token_active:"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Cache class
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
class ApiTokenCache:
    """
    Redis cache wrapper for API tokens.

    Token entries live under ``{CACHE_KEY_PREFIX}:{scope}:{token}`` and hold the JSON
    form of a ``CachedApiToken``; the literal ``null`` is stored for tokens that were
    looked up and not found. A per-tenant Redis set (``tenant_tokens:{tenant_id}``)
    lists the cache keys of that tenant so they can be invalidated together.
    """

    @staticmethod
    def make_active_key(token: str, scope: str | None = None) -> str:
        """Generate the Redis key under which a token's latest usage is recorded."""
        # A scope of None is rendered as the literal text "None" by the f-string.
        return f"{ACTIVE_TOKEN_KEY_PREFIX}{scope}:{token}"

    @staticmethod
    def _make_tenant_index_key(tenant_id: str) -> str:
        """Generate the Redis key of the set that indexes a tenant's cache keys."""
        return f"tenant_tokens:{tenant_id}"

    @staticmethod
    def _make_cache_key(token: str, scope: str | None = None) -> str:
        """Generate the cache key for the given token and scope ("any" when no scope)."""
        scope_str = scope or "any"
        return f"{CACHE_KEY_PREFIX}:{scope_str}:{token}"

    @staticmethod
    def _serialize_token(api_token: Any) -> bytes:
        """
        Serialize a token object to UTF-8 encoded JSON.

        A ``CachedApiToken`` is dumped as-is. Any other object is first copied into
        a ``CachedApiToken`` by reading its ``id``, ``app_id``, ``tenant_id``,
        ``type``, ``token``, ``last_used_at`` and ``created_at`` attributes.
        """
        if isinstance(api_token, CachedApiToken):
            cached = api_token
        else:
            cached = CachedApiToken(
                id=str(api_token.id),
                app_id=str(api_token.app_id) if api_token.app_id else None,
                tenant_id=str(api_token.tenant_id) if api_token.tenant_id else None,
                type=api_token.type,
                token=api_token.token,
                last_used_at=api_token.last_used_at,
                created_at=api_token.created_at,
            )
        return cached.model_dump_json().encode("utf-8")

    @staticmethod
    def _deserialize_token(cached_data: bytes | str) -> Any:
        """
        Deserialize cached JSON back into a ``CachedApiToken``.

        Returns ``None`` for the negative-cache marker (``null`` as bytes or str)
        and also when the payload cannot be decoded or validated; the latter case
        is logged as a warning.
        """
        if cached_data in {b"null", "null"}:
            return None

        try:
            if isinstance(cached_data, bytes):
                cached_data = cached_data.decode("utf-8")
            return CachedApiToken.model_validate_json(cached_data)
        except Exception as e:
            # ValueError is already covered by Exception; listing both was redundant.
            logger.warning("Failed to deserialize token from cache: %s", e)
            return None

    @staticmethod
    @redis_fallback(default_return=None)
    def get(token: str, scope: str | None) -> Any | None:
        """
        Get an API token from the cache.

        Returns ``None`` both when the key is absent and when the stored value is
        the negative-cache marker or is unreadable, so callers cannot tell a cache
        miss from a cached "token does not exist" through this method alone.
        """
        cache_key = ApiTokenCache._make_cache_key(token, scope)
        cached_data = redis_client.get(cache_key)

        if cached_data is None:
            logger.debug("Cache miss for token key: %s", cache_key)
            return None

        logger.debug("Cache hit for token key: %s", cache_key)
        return ApiTokenCache._deserialize_token(cached_data)

    @staticmethod
    def _add_to_tenant_index(tenant_id: str | None, cache_key: str) -> None:
        """Add a cache key to the tenant's index set; failures are logged, not raised."""
        if not tenant_id:
            return

        try:
            index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
            redis_client.sadd(index_key, cache_key)
            # NOTE(review): the index lifetime is derived from CACHE_TTL_SECONDS, not from
            # the ttl passed to set(); an entry cached with a longer ttl can outlive the
            # index and would then be missed by invalidate_by_tenant().
            redis_client.expire(index_key, CACHE_TTL_SECONDS + 60)
        except Exception as e:
            logger.warning("Failed to update tenant index: %s", e)

    @staticmethod
    def _remove_from_tenant_index(tenant_id: str | None, cache_key: str) -> None:
        """Remove a cache key from the tenant's index set; failures are logged, not raised."""
        if not tenant_id:
            return

        try:
            index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
            redis_client.srem(index_key, cache_key)
        except Exception as e:
            logger.warning("Failed to remove from tenant index: %s", e)

    @staticmethod
    @redis_fallback(default_return=False)
    def set(token: str, scope: str | None, api_token: Any | None, ttl: int = CACHE_TTL_SECONDS) -> bool:
        """
        Store an API token in the cache.

        Passing ``api_token=None`` stores the negative-cache marker and forces the
        TTL to ``CACHE_NULL_TTL_SECONDS`` regardless of ``ttl``. A real token that
        has a ``tenant_id`` attribute is also registered in the tenant index.

        Returns:
            ``True`` when the entry was written, ``False`` when writing failed.
        """
        cache_key = ApiTokenCache._make_cache_key(token, scope)

        if api_token is None:
            cached_value = b"null"
            ttl = CACHE_NULL_TTL_SECONDS
        else:
            cached_value = ApiTokenCache._serialize_token(api_token)

        try:
            redis_client.setex(cache_key, ttl, cached_value)

            if api_token is not None and hasattr(api_token, "tenant_id"):
                ApiTokenCache._add_to_tenant_index(api_token.tenant_id, cache_key)

            logger.debug("Cached token with key: %s, ttl: %ss", cache_key, ttl)
            return True
        except Exception as e:
            logger.warning("Failed to cache token: %s", e)
            return False

    @staticmethod
    @redis_fallback(default_return=False)
    def delete(token: str, scope: str | None = None) -> bool:
        """
        Delete an API token from the cache.

        With ``scope=None`` every ``{CACHE_KEY_PREFIX}:*:{token}`` key found by a
        scan is deleted; the tenant index is not touched on this path. With a
        scope, only that one key is deleted and, when the cached entry names a
        tenant, the key is also removed from that tenant's index.

        Returns:
            ``True`` on success, ``False`` when a Redis operation failed.
        """
        if scope is None:
            pattern = f"{CACHE_KEY_PREFIX}:*:{token}"
            try:
                keys_to_delete = list(redis_client.scan_iter(match=pattern))
                if keys_to_delete:
                    redis_client.delete(*keys_to_delete)
                    logger.info("Deleted %d cache entries for token", len(keys_to_delete))
                return True
            except Exception as e:
                logger.warning("Failed to delete token cache with pattern: %s", e)
                return False

        cache_key = ApiTokenCache._make_cache_key(token, scope)
        try:
            # Read the entry before deleting it so the tenant index can be cleaned up too.
            tenant_id = None
            try:
                cached_data = redis_client.get(cache_key)
                # _deserialize_token already maps the "null" marker (bytes or str) to None.
                cached_token = ApiTokenCache._deserialize_token(cached_data) if cached_data else None
                if cached_token:
                    tenant_id = cached_token.tenant_id
            except Exception as e:
                logger.debug("Failed to get tenant_id for cache cleanup: %s", e)

            redis_client.delete(cache_key)

            if tenant_id:
                ApiTokenCache._remove_from_tenant_index(tenant_id, cache_key)

            logger.info("Deleted cache for key: %s", cache_key)
            return True
        except Exception as e:
            logger.warning("Failed to delete token cache: %s", e)
            return False

    @staticmethod
    @redis_fallback(default_return=False)
    def invalidate_by_tenant(tenant_id: str) -> bool:
        """
        Invalidate all cached API tokens of a tenant using the tenant index.

        Every key listed in the tenant's index set is deleted, followed by the
        index itself. When the index is empty or missing nothing is deleted and
        the entries are left to expire on their own.

        Returns:
            ``True`` on success (including the empty-index case), ``False`` when a
            Redis operation failed.
        """
        try:
            index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
            cache_keys = redis_client.smembers(index_key)

            if not cache_keys:
                logger.info(
                    "No tenant index found for %s, relying on TTL expiration",
                    tenant_id,
                )
                return True

            deleted_count = 0
            for cache_key in cache_keys:
                if isinstance(cache_key, bytes):
                    cache_key = cache_key.decode("utf-8")
                redis_client.delete(cache_key)
                deleted_count += 1

            redis_client.delete(index_key)

            logger.info(
                "Invalidated %d token cache entries for tenant: %s",
                deleted_count,
                tenant_id,
            )
            return True

        except Exception as e:
            logger.warning("Failed to invalidate tenant token cache: %s", e)
            return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Token usage recording (for batch update)
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def record_token_usage(auth_token: str, scope: str | None) -> None:
    """
    Note in Redis that a token was just used.

    Writes the current ``naive_utc_now()`` timestamp (ISO format) under the token's
    "active" key with a 3600-second expiry. No database write happens here; the
    stored timestamp is meant to be picked up later by a scheduled batch job that
    updates ``last_used_at``. Failures are logged as warnings and never raised.
    """
    try:
        redis_client.set(
            ApiTokenCache.make_active_key(auth_token, scope),
            naive_utc_now().isoformat(),
            ex=3600,
        )
    except Exception as exc:
        logger.warning("Failed to record token usage: %s", exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Database query + single-flight
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def query_token_from_db(auth_token: str, scope: str | None) -> ApiToken:
    """
    Look up an API token in the database and mirror the outcome into the cache.

    A hit is cached and its usage recorded through ``record_token_usage``. A miss
    is cached as a negative entry before ``Unauthorized`` is raised.

    Raises:
        Unauthorized: no row matches both ``auth_token`` and ``scope``.
    """
    with Session(db.engine, expire_on_commit=False) as session:
        found = session.scalar(
            select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
        )

        if found:
            ApiTokenCache.set(auth_token, scope, found)
            record_token_usage(auth_token, scope)
            return found

        # Cache the miss as well, then reject the request.
        ApiTokenCache.set(auth_token, scope, None)
        raise Unauthorized("Access token is invalid")
|
||||
|
||||
|
||||
def fetch_token_with_single_flight(auth_token: str, scope: str | None) -> ApiToken | Any:
    """
    Fetch a token from the database, letting one caller at a time do the lookup.

    A Redis lock keyed by scope and token serialises concurrent cache misses for the
    same token. The lock holder re-checks the cache first (another request may have
    filled it while this one waited) and only then queries the database. If the lock
    cannot be created or acquired, because of a timeout or a Redis error, the
    database is queried directly without the lock.

    Only lock creation and acquisition trigger that fallback. An error raised by
    the lookup itself propagates unchanged and the lookup is not repeated.

    Returns:
        The cached token object when one appeared in the cache, otherwise the
        ``ApiToken`` returned by ``query_token_from_db``.

    Raises:
        Unauthorized: propagated from ``query_token_from_db`` for an invalid token.
    """
    logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope)

    lock_key = f"api_token_query_lock:{scope}:{auth_token}"

    try:
        lock = redis_client.lock(lock_key, timeout=10, blocking_timeout=5)
        acquired = lock.acquire(blocking=True)
    except Exception as e:
        logger.warning("Redis lock failed for token query: %s, proceeding anyway", e)
        return query_token_from_db(auth_token, scope)

    if not acquired:
        logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10])
        return query_token_from_db(auth_token, scope)

    try:
        cached_token = ApiTokenCache.get(auth_token, scope)
        if cached_token is not None:
            logger.debug("Token cached by concurrent request, using cached version")
            return cached_token

        return query_token_from_db(auth_token, scope)
    finally:
        # Releasing can fail, for instance when the 10-second lock timeout elapsed while
        # the lookup was still running. That must not replace the lookup's result (or
        # its Unauthorized error) nor cause the lookup to be run a second time.
        try:
            lock.release()
        except Exception as e:
            logger.warning("Failed to release token query lock: %s", e)
|
||||
@@ -14,7 +14,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_was_created
|
||||
from events.app_event import app_was_created, app_was_deleted
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
@@ -340,6 +340,8 @@ class AppService:
|
||||
db.session.delete(app)
|
||||
db.session.commit()
|
||||
|
||||
app_was_deleted.send(app)
|
||||
|
||||
# clean up web app settings
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
EnterpriseService.WebAppAuth.cleanup_webapp(app.id)
|
||||
|
||||
@@ -6,6 +6,13 @@ from typing import Any
|
||||
import httpx
|
||||
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
from services.errors.enterprise import (
|
||||
EnterpriseAPIBadRequestError,
|
||||
EnterpriseAPIError,
|
||||
EnterpriseAPIForbiddenError,
|
||||
EnterpriseAPINotFoundError,
|
||||
EnterpriseAPIUnauthorizedError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,6 +46,9 @@ class BaseRequest:
|
||||
endpoint: str,
|
||||
json: Any | None = None,
|
||||
params: Mapping[str, Any] | None = None,
|
||||
*,
|
||||
timeout: float | httpx.Timeout | None = None,
|
||||
raise_for_status: bool = False,
|
||||
) -> Any:
|
||||
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
@@ -53,9 +63,64 @@ class BaseRequest:
|
||||
logger.debug("Failed to generate traceparent header", exc_info=True)
|
||||
|
||||
with httpx.Client(mounts=mounts) as client:
|
||||
response = client.request(method, url, json=json, params=params, headers=headers)
|
||||
# IMPORTANT:
|
||||
# - In httpx, passing timeout=None disables timeouts (infinite) and overrides the library default.
|
||||
# - To preserve httpx's default timeout behavior for existing call sites, only pass the kwarg when set.
|
||||
request_kwargs: dict[str, Any] = {"json": json, "params": params, "headers": headers}
|
||||
if timeout is not None:
|
||||
request_kwargs["timeout"] = timeout
|
||||
|
||||
response = client.request(method, url, **request_kwargs)
|
||||
|
||||
# Always validate HTTP status and raise domain-specific errors
|
||||
if not response.is_success:
|
||||
cls._handle_error_response(response)
|
||||
|
||||
# Legacy support: still respect raise_for_status parameter
|
||||
if raise_for_status:
|
||||
response.raise_for_status()
|
||||
|
||||
return response.json()
|
||||
|
||||
@classmethod
def _handle_error_response(cls, response: httpx.Response) -> None:
    """
    Turn a failed enterprise API response into a domain exception.

    The exception message is taken from the first non-empty ``message``, ``error``
    or ``detail`` field of a JSON object body. When the body is not a JSON object,
    cannot be parsed, or has none of those fields, a generic message built from
    the status code and reason phrase is used instead.

    This method always raises and never returns normally.

    Raises:
        EnterpriseAPIBadRequestError: status 400.
        EnterpriseAPIUnauthorizedError: status 401.
        EnterpriseAPIForbiddenError: status 403.
        EnterpriseAPINotFoundError: status 404.
        EnterpriseAPIError: any other status; the status code is attached.
    """
    error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}"

    # Prefer a message supplied by the server, if the body carries one.
    try:
        payload = response.json()
        if isinstance(payload, dict):
            for field in ("message", "error", "detail"):
                candidate = payload.get(field)
                if candidate:
                    error_message = candidate
                    break
    except Exception:
        # Unparsable body: keep the generic message built above.
        logger.debug(
            "Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True
        )

    # Statuses with a dedicated exception type; everything else gets the base error.
    status_specific_errors = {
        400: EnterpriseAPIBadRequestError,
        401: EnterpriseAPIUnauthorizedError,
        403: EnterpriseAPIForbiddenError,
        404: EnterpriseAPINotFoundError,
    }
    error_cls = status_specific_errors.get(response.status_code)
    if error_cls is None:
        raise EnterpriseAPIError(error_message, status_code=response.status_code)
    raise error_cls(error_message)
|
||||
|
||||
|
||||
class EnterpriseRequest(BaseRequest):
|
||||
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
|
||||
|
||||
@@ -1,9 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
|
||||
ALLOWED_ACCESS_MODES = ["public", "private", "private_all", "sso_verified"]
|
||||
# License status cache configuration
|
||||
LICENSE_STATUS_CACHE_KEY = "enterprise:license:status"
|
||||
LICENSE_STATUS_CACHE_TTL = 600 # 10 minutes
|
||||
|
||||
|
||||
class WebAppSettings(BaseModel):
|
||||
access_mode: str = Field(
|
||||
@@ -30,6 +48,55 @@ class WorkspacePermission(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class DefaultWorkspaceJoinResult(BaseModel):
    """
    Outcome of ensuring an account belongs to the enterprise default workspace.

    - joined=True is idempotent: an account that was already a member also yields True.
    - joined=False means the enterprise default workspace is not configured,
      or is invalid/archived.
    """

    # Unknown keys are rejected; the field may be populated by name or by its alias.
    model_config = ConfigDict(extra="forbid", populate_by_name=True)

    workspace_id: str = Field(default="", alias="workspaceId")
    joined: bool
    message: str

    @model_validator(mode="after")
    def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult:
        """Reject a result that reports joined=True without a workspace id."""
        if not self.joined or self.workspace_id:
            return self
        raise ValueError("workspace_id must be non-empty when joined is True")
|
||||
|
||||
|
||||
def try_join_default_workspace(account_id: str) -> None:
    """
    Best-effort, enterprise-only hook that adds an account to the default workspace.

    Does nothing when ``dify_config.ENTERPRISE_ENABLED`` is falsy. Any exception
    raised while joining is logged as a warning and swallowed, because failures
    must not block user registration.
    """
    if not dify_config.ENTERPRISE_ENABLED:
        return

    try:
        outcome = EnterpriseService.join_default_workspace(account_id=account_id)
        if not outcome.joined:
            logger.info(
                "Skipped joining enterprise default workspace for account %s (message=%s)",
                account_id,
                outcome.message,
            )
            return
        logger.info(
            "Joined enterprise default workspace for account %s (workspace_id=%s)",
            account_id,
            outcome.workspace_id,
        )
    except Exception:
        logger.warning("Failed to join enterprise default workspace for account %s", account_id, exc_info=True)
|
||||
|
||||
|
||||
class EnterpriseService:
|
||||
@classmethod
|
||||
def get_info(cls):
|
||||
@@ -39,6 +106,34 @@ class EnterpriseService:
|
||||
def get_workspace_info(cls, tenant_id: str):
|
||||
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
|
||||
|
||||
@classmethod
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
    """
    Ask the enterprise inner API to add an account to the default workspace.

    NOTE: EnterpriseRequest.base_url is expected to already include the `/inner/api` prefix,
    so the endpoint here is `/default-workspace/members`.

    Raises:
        ValueError: if ``account_id`` is not a valid UUID string, if the response
            is not a dict, or if it lacks the ``joined`` / ``message`` keys.
    """
    # Check that the id is UUID-shaped before sending it (enterprise side validates too).
    try:
        uuid.UUID(account_id)
    except ValueError as exc:
        raise ValueError(f"account_id must be a valid UUID: {account_id}") from exc

    request_body = {"account_id": account_id}
    response_body = EnterpriseRequest.send_request(
        "POST",
        "/default-workspace/members",
        json=request_body,
        timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
        raise_for_status=True,
    )

    if not isinstance(response_body, dict):
        raise ValueError("Invalid response format from enterprise default workspace API")
    if not all(key in response_body for key in ("joined", "message")):
        raise ValueError("Invalid response payload from enterprise default workspace API")
    return DefaultWorkspaceJoinResult.model_validate(response_body)
|
||||
|
||||
@classmethod
|
||||
def get_app_sso_settings_last_update_time(cls) -> datetime:
|
||||
data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time")
|
||||
@@ -123,8 +218,8 @@ class EnterpriseService:
|
||||
def update_app_access_mode(cls, app_id: str, access_mode: str):
|
||||
if not app_id:
|
||||
raise ValueError("app_id must be provided.")
|
||||
if access_mode not in ["public", "private", "private_all"]:
|
||||
raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
|
||||
if access_mode not in ALLOWED_ACCESS_MODES:
|
||||
raise ValueError(f"access_mode must be one of: {', '.join(ALLOWED_ACCESS_MODES)}")
|
||||
|
||||
data = {"appId": app_id, "accessMode": access_mode}
|
||||
|
||||
@@ -139,3 +234,62 @@ class EnterpriseService:
|
||||
|
||||
params = {"appId": app_id}
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||
|
||||
@classmethod
def get_cached_license_status(cls) -> LicenseStatus | None:
    """Return the enterprise license status, preferring the Redis cache over an HTTP call.

    Valid statuses (active/expiring) are cached for 10 minutes. Invalid statuses
    are not cached so license updates are picked up on the next request.

    Returns:
        LicenseStatus enum value, or None if enterprise is disabled / unreachable.
    """
    if dify_config.ENTERPRISE_ENABLED:
        status = cls._read_cached_license_status()
        if status is None:
            # Cache miss (or cache read failure): go to the enterprise API.
            status = cls._fetch_and_cache_license_status()
        return status
    return None
|
||||
|
||||
@classmethod
def _read_cached_license_status(cls) -> LicenseStatus | None:
    """Look up the license status in Redis; return None on a miss or any failure."""
    # Imported at call time; at module level this name is only imported under TYPE_CHECKING.
    from services.feature_service import LicenseStatus

    try:
        cached_value = redis_client.get(LICENSE_STATUS_CACHE_KEY)
        if not cached_value:
            return None
        if isinstance(cached_value, bytes):
            cached_value = cached_value.decode("utf-8")
        return LicenseStatus(cached_value)
    except Exception:
        logger.debug("Failed to read license status from cache", exc_info=True)
        return None
|
||||
|
||||
@classmethod
def _fetch_and_cache_license_status(cls) -> LicenseStatus | None:
    """Query the enterprise API for the license status and cache it when valid.

    Only valid statuses (active/expiring) are written to Redis, so license updates
    for invalid statuses are picked up on the next request. A failure to write
    the cache is logged and does not affect the returned status.

    Returns:
        The fetched LicenseStatus, or None when the response has no "License"
        entry or the fetch fails.
    """
    # Imported at call time; at module level this name is only imported under TYPE_CHECKING.
    from services.feature_service import LicenseStatus

    try:
        license_info = cls.get_info().get("License")
        if not license_info:
            return None

        status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
        if status not in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING):
            return status

        try:
            redis_client.setex(LICENSE_STATUS_CACHE_KEY, LICENSE_STATUS_CACHE_TTL, status)
        except Exception:
            logger.debug("Failed to cache license status", exc_info=True)
        return status
    except Exception:
        logger.debug("Failed to fetch enterprise license status", exc_info=True)
        return None
|
||||
|
||||
@@ -28,6 +28,11 @@ class CheckCredentialPolicyComplianceRequest(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class PreUninstallPluginRequest(BaseModel):
    """Payload for the enterprise plugin manager's pre-uninstall hook."""

    # Both fields are included in the pre-uninstall hook's failure log.
    tenant_id: str
    plugin_unique_identifier: str
|
||||
|
||||
|
||||
class CredentialPolicyViolationError(BaseServiceError):
    """Service error for credential policy violations.

    NOTE(review): presumably raised when a credential policy compliance check
    fails; confirm against the callers.
    """
|
||||
|
||||
@@ -55,3 +60,19 @@ class PluginManagerService:
|
||||
body.dify_credential_id,
|
||||
ret.get("result", False),
|
||||
)
|
||||
|
||||
@classmethod
def try_pre_uninstall_plugin(cls, body: PreUninstallPluginRequest):
    """Call the enterprise pre-uninstall plugin hook on a best-effort basis.

    Posts ``body`` to ``/pre-uninstall-plugin``. Any exception is logged with
    the tenant id and plugin identifier and is not re-raised.
    """
    try:
        hook_payload = body.model_dump()
        # the invocation must be synchronous.
        EnterprisePluginManagerRequest.send_request("POST", "/pre-uninstall-plugin", json=hook_payload)
    except Exception:
        logger.exception(
            "failed to perform pre uninstall plugin hook. tenant_id: %s, plugin_unique_identifier: %s",
            body.tenant_id,
            body.plugin_unique_identifier,
        )
|
||||
|
||||
@@ -7,6 +7,7 @@ from . import (
|
||||
conversation,
|
||||
dataset,
|
||||
document,
|
||||
enterprise,
|
||||
file,
|
||||
index,
|
||||
message,
|
||||
@@ -21,6 +22,7 @@ __all__ = [
|
||||
"conversation",
|
||||
"dataset",
|
||||
"document",
|
||||
"enterprise",
|
||||
"file",
|
||||
"index",
|
||||
"message",
|
||||
|
||||
45
api/services/errors/enterprise.py
Normal file
45
api/services/errors/enterprise.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Enterprise service errors."""
|
||||
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class EnterpriseServiceError(BaseServiceError):
    """Base exception for enterprise service errors.

    Attributes:
        status_code: HTTP status code associated with the error, or None.
    """

    def __init__(self, description: str | None = None, status_code: int | None = None):
        # The description is handed to the base service error unchanged.
        super().__init__(description)
        self.status_code = status_code
|
||||
|
||||
|
||||
class EnterpriseAPIError(EnterpriseServiceError):
    """Generic enterprise API error (non-2xx response).

    Adds nothing to EnterpriseServiceError; callers supply the status code.
    """
|
||||
|
||||
|
||||
class EnterpriseAPINotFoundError(EnterpriseServiceError):
    """Enterprise API returned 404 Not Found."""

    def __init__(self, description: str | None = None):
        # The status code is fixed at 404 for this error type.
        super().__init__(description=description, status_code=404)
|
||||
|
||||
|
||||
class EnterpriseAPIForbiddenError(EnterpriseServiceError):
    """Enterprise API returned 403 Forbidden."""

    def __init__(self, description: str | None = None):
        # The status code is fixed at 403 for this error type.
        super().__init__(description=description, status_code=403)
|
||||
|
||||
|
||||
class EnterpriseAPIUnauthorizedError(EnterpriseServiceError):
    """Enterprise API returned 401 Unauthorized."""

    def __init__(self, description: str | None = None):
        # The status code is fixed at 401 for this error type.
        super().__init__(description=description, status_code=401)
|
||||
|
||||
|
||||
class EnterpriseAPIBadRequestError(EnterpriseServiceError):
    """Enterprise API returned 400 Bad Request."""

    def __init__(self, description: str | None = None):
        # The status code is fixed at 400 for this error type.
        super().__init__(description=description, status_code=400)
|
||||
@@ -361,11 +361,14 @@ class FeatureService:
|
||||
)
|
||||
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
|
||||
|
||||
if is_authenticated and (license_info := enterprise_info.get("License")):
|
||||
# License status and expiry are always exposed so the login page can
|
||||
# show the expiry UI after a force-logout (the user is unauthenticated
|
||||
# at that point). Workspace usage details remain auth-gated.
|
||||
if license_info := enterprise_info.get("License"):
|
||||
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
|
||||
features.license.expired_at = license_info.get("expiredAt", "")
|
||||
|
||||
if workspaces_info := license_info.get("workspaces"):
|
||||
if is_authenticated and (workspaces_info := license_info.get("workspaces")):
|
||||
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
|
||||
features.license.workspaces.limit = workspaces_info.get("limit", 0)
|
||||
features.license.workspaces.size = workspaces_info.get("used", 0)
|
||||
|
||||
@@ -7,9 +7,10 @@ from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from events.feedback_event import feedback_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
@@ -179,6 +180,9 @@ class MessageService:
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if feedback and rating:
|
||||
feedback_was_created.send(feedback, tenant_id=app_model.tenant_id)
|
||||
|
||||
return feedback
|
||||
|
||||
@classmethod
|
||||
@@ -294,10 +298,15 @@ class MessageService:
|
||||
questions: list[str] = list(questions_sequence)
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(app_id=app_model.id)
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
|
||||
context=TelemetryContext(tenant_id=app_model.tenant_id, app_id=app_model.id),
|
||||
payload={
|
||||
"message_id": message_id,
|
||||
"suggested_question": questions,
|
||||
"timer": timer,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
@@ -5,6 +6,8 @@ from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, TraceAppConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpsService:
|
||||
@classmethod
|
||||
@@ -135,12 +138,13 @@ class OpsService:
|
||||
return trace_config_data.to_dict()
|
||||
|
||||
@classmethod
|
||||
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
|
||||
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict, account_id: str):
|
||||
"""
|
||||
Create tracing app config
|
||||
:param app_id: app id
|
||||
:param tracing_provider: tracing provider
|
||||
:param tracing_config: tracing config
|
||||
:param account_id: account id of the user creating the config
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
@@ -207,15 +211,19 @@ class OpsService:
|
||||
db.session.add(trace_config_data)
|
||||
db.session.commit()
|
||||
|
||||
# Log the creation with modifier information
|
||||
logger.info("Trace config created: app_id=%s, provider=%s, created_by=%s", app_id, tracing_provider, account_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
|
||||
def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict, account_id: str):
|
||||
"""
|
||||
Update tracing app config
|
||||
:param app_id: app id
|
||||
:param tracing_provider: tracing provider
|
||||
:param tracing_config: tracing config
|
||||
:param account_id: account id of the user updating the config
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
@@ -251,14 +259,18 @@ class OpsService:
|
||||
current_trace_config.tracing_config = tracing_config
|
||||
db.session.commit()
|
||||
|
||||
# Log the update with modifier information
|
||||
logger.info("Trace config updated: app_id=%s, provider=%s, updated_by=%s", app_id, tracing_provider, account_id)
|
||||
|
||||
return current_trace_config.to_dict()
|
||||
|
||||
@classmethod
|
||||
def delete_tracing_app_config(cls, app_id: str, tracing_provider: str):
|
||||
def delete_tracing_app_config(cls, app_id: str, tracing_provider: str, account_id: str):
|
||||
"""
|
||||
Delete tracing app config
|
||||
:param app_id: app id
|
||||
:param tracing_provider: tracing provider
|
||||
:param account_id: account id of the user deleting the config
|
||||
:return:
|
||||
"""
|
||||
trace_config = (
|
||||
@@ -270,6 +282,9 @@ class OpsService:
|
||||
if not trace_config:
|
||||
return None
|
||||
|
||||
# Log the deletion with modifier information
|
||||
logger.info("Trace config deleted: app_id=%s, provider=%s, deleted_by=%s", app_id, tracing_provider, account_id)
|
||||
|
||||
db.session.delete(trace_config)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user