mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 06:19:25 +08:00
refactor(api): type OpsTraceProviderConfigMap with TracingProviderCon… (#34424)
This commit is contained in:
@@ -19,6 +19,7 @@ from typing_extensions import TypedDict
|
|||||||
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
||||||
from core.ops.entities.config_entity import (
|
from core.ops.entities.config_entity import (
|
||||||
OPS_FILE_PATH,
|
OPS_FILE_PATH,
|
||||||
|
BaseTracingConfig,
|
||||||
TracingProviderEnum,
|
TracingProviderEnum,
|
||||||
)
|
)
|
||||||
from core.ops.entities.trace_entity import (
|
from core.ops.entities.trace_entity import (
|
||||||
@@ -195,8 +196,15 @@ def _lookup_llm_credential_info(
|
|||||||
return None, ""
|
return None, ""
|
||||||
|
|
||||||
|
|
||||||
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
class TracingProviderConfigEntry(TypedDict):
|
||||||
def __getitem__(self, provider: str) -> dict[str, Any]:
|
config_class: type[BaseTracingConfig]
|
||||||
|
secret_keys: list[str]
|
||||||
|
other_keys: list[str]
|
||||||
|
trace_instance: type[Any]
|
||||||
|
|
||||||
|
|
||||||
|
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
|
||||||
|
def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
|
||||||
match provider:
|
match provider:
|
||||||
case TracingProviderEnum.LANGFUSE:
|
case TracingProviderEnum.LANGFUSE:
|
||||||
from core.ops.entities.config_entity import LangfuseConfig
|
from core.ops.entities.config_entity import LangfuseConfig
|
||||||
@@ -585,8 +593,8 @@ class OpsTraceManager:
|
|||||||
provider_config_map[tracing_provider]["config_class"],
|
provider_config_map[tracing_provider]["config_class"],
|
||||||
provider_config_map[tracing_provider]["trace_instance"],
|
provider_config_map[tracing_provider]["trace_instance"],
|
||||||
)
|
)
|
||||||
tracing_config = config_type(**tracing_config)
|
config = config_type(**tracing_config)
|
||||||
return trace_instance(tracing_config).api_check()
|
return trace_instance(config).api_check()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
|
def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
|
||||||
@@ -600,8 +608,8 @@ class OpsTraceManager:
|
|||||||
provider_config_map[tracing_provider]["config_class"],
|
provider_config_map[tracing_provider]["config_class"],
|
||||||
provider_config_map[tracing_provider]["trace_instance"],
|
provider_config_map[tracing_provider]["trace_instance"],
|
||||||
)
|
)
|
||||||
tracing_config = config_type(**tracing_config)
|
config = config_type(**tracing_config)
|
||||||
return trace_instance(tracing_config).get_project_key()
|
return trace_instance(config).get_project_key()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
|
def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
|
||||||
@@ -615,8 +623,8 @@ class OpsTraceManager:
|
|||||||
provider_config_map[tracing_provider]["config_class"],
|
provider_config_map[tracing_provider]["config_class"],
|
||||||
provider_config_map[tracing_provider]["trace_instance"],
|
provider_config_map[tracing_provider]["trace_instance"],
|
||||||
)
|
)
|
||||||
tracing_config = config_type(**tracing_config)
|
config = config_type(**tracing_config)
|
||||||
return trace_instance(tracing_config).get_project_url()
|
return trace_instance(config).get_project_url()
|
||||||
|
|
||||||
|
|
||||||
class TraceTask:
|
class TraceTask:
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from core.ops.entities.config_entity import BaseTracingConfig
|
from core.ops.entities.config_entity import BaseTracingConfig
|
||||||
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
|
from core.ops.ops_trace_manager import OpsTraceManager, TracingProviderConfigEntry, provider_config_map
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, TraceAppConfig
|
from models.model import App, TraceAppConfig
|
||||||
|
|
||||||
@@ -150,7 +148,7 @@ class OpsService:
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
return {"error": f"Invalid tracing provider: {tracing_provider}"}
|
return {"error": f"Invalid tracing provider: {tracing_provider}"}
|
||||||
|
|
||||||
provider_config: dict[str, Any] = provider_config_map[tracing_provider]
|
provider_config: TracingProviderConfigEntry = provider_config_map[tracing_provider]
|
||||||
config_class: type[BaseTracingConfig] = provider_config["config_class"]
|
config_class: type[BaseTracingConfig] = provider_config["config_class"]
|
||||||
other_keys: list[str] = provider_config["other_keys"]
|
other_keys: list[str] = provider_config["other_keys"]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user