mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 02:19:20 +08:00
refactor(api): type OpsTraceProviderConfigMap with TracingProviderCon… (#34424)
This commit is contained in:
@@ -19,6 +19,7 @@ from typing_extensions import TypedDict
|
||||
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
||||
from core.ops.entities.config_entity import (
|
||||
OPS_FILE_PATH,
|
||||
BaseTracingConfig,
|
||||
TracingProviderEnum,
|
||||
)
|
||||
from core.ops.entities.trace_entity import (
|
||||
@@ -195,8 +196,15 @@ def _lookup_llm_credential_info(
|
||||
return None, ""
|
||||
|
||||
|
||||
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
||||
def __getitem__(self, provider: str) -> dict[str, Any]:
|
||||
class TracingProviderConfigEntry(TypedDict):
|
||||
config_class: type[BaseTracingConfig]
|
||||
secret_keys: list[str]
|
||||
other_keys: list[str]
|
||||
trace_instance: type[Any]
|
||||
|
||||
|
||||
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
|
||||
def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
|
||||
match provider:
|
||||
case TracingProviderEnum.LANGFUSE:
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
@@ -585,8 +593,8 @@ class OpsTraceManager:
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["trace_instance"],
|
||||
)
|
||||
tracing_config = config_type(**tracing_config)
|
||||
return trace_instance(tracing_config).api_check()
|
||||
config = config_type(**tracing_config)
|
||||
return trace_instance(config).api_check()
|
||||
|
||||
@staticmethod
|
||||
def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
|
||||
@@ -600,8 +608,8 @@ class OpsTraceManager:
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["trace_instance"],
|
||||
)
|
||||
tracing_config = config_type(**tracing_config)
|
||||
return trace_instance(tracing_config).get_project_key()
|
||||
config = config_type(**tracing_config)
|
||||
return trace_instance(config).get_project_key()
|
||||
|
||||
@staticmethod
|
||||
def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
|
||||
@@ -615,8 +623,8 @@ class OpsTraceManager:
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["trace_instance"],
|
||||
)
|
||||
tracing_config = config_type(**tracing_config)
|
||||
return trace_instance(tracing_config).get_project_url()
|
||||
config = config_type(**tracing_config)
|
||||
return trace_instance(config).get_project_url()
|
||||
|
||||
|
||||
class TraceTask:
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
|
||||
from core.ops.ops_trace_manager import OpsTraceManager, TracingProviderConfigEntry, provider_config_map
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, TraceAppConfig
|
||||
|
||||
@@ -150,7 +148,7 @@ class OpsService:
|
||||
except KeyError:
|
||||
return {"error": f"Invalid tracing provider: {tracing_provider}"}
|
||||
|
||||
provider_config: dict[str, Any] = provider_config_map[tracing_provider]
|
||||
provider_config: TracingProviderConfigEntry = provider_config_map[tracing_provider]
|
||||
config_class: type[BaseTracingConfig] = provider_config["config_class"]
|
||||
other_keys: list[str] = provider_config["other_keys"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user