refactor(api): type OpsTraceProviderConfigMap with TracingProviderCon… (#34424)

This commit is contained in:
YBoy
2026-04-02 03:47:08 +02:00
committed by GitHub
parent 725f9e3dc4
commit 2d29345f26
2 changed files with 18 additions and 12 deletions

View File

@@ -19,6 +19,7 @@ from typing_extensions import TypedDict
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import ( from core.ops.entities.config_entity import (
OPS_FILE_PATH, OPS_FILE_PATH,
BaseTracingConfig,
TracingProviderEnum, TracingProviderEnum,
) )
from core.ops.entities.trace_entity import ( from core.ops.entities.trace_entity import (
@@ -195,8 +196,15 @@ def _lookup_llm_credential_info(
return None, "" return None, ""
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): class TracingProviderConfigEntry(TypedDict):
def __getitem__(self, provider: str) -> dict[str, Any]: config_class: type[BaseTracingConfig]
secret_keys: list[str]
other_keys: list[str]
trace_instance: type[Any]
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
match provider: match provider:
case TracingProviderEnum.LANGFUSE: case TracingProviderEnum.LANGFUSE:
from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.config_entity import LangfuseConfig
@@ -585,8 +593,8 @@ class OpsTraceManager:
provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["trace_instance"], provider_config_map[tracing_provider]["trace_instance"],
) )
tracing_config = config_type(**tracing_config) config = config_type(**tracing_config)
return trace_instance(tracing_config).api_check() return trace_instance(config).api_check()
@staticmethod @staticmethod
def get_trace_config_project_key(tracing_config: dict, tracing_provider: str): def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
@@ -600,8 +608,8 @@ class OpsTraceManager:
provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["trace_instance"], provider_config_map[tracing_provider]["trace_instance"],
) )
tracing_config = config_type(**tracing_config) config = config_type(**tracing_config)
return trace_instance(tracing_config).get_project_key() return trace_instance(config).get_project_key()
@staticmethod @staticmethod
def get_trace_config_project_url(tracing_config: dict, tracing_provider: str): def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
@@ -615,8 +623,8 @@ class OpsTraceManager:
provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["trace_instance"], provider_config_map[tracing_provider]["trace_instance"],
) )
tracing_config = config_type(**tracing_config) config = config_type(**tracing_config)
return trace_instance(tracing_config).get_project_url() return trace_instance(config).get_project_url()
class TraceTask: class TraceTask:

View File

@@ -1,9 +1,7 @@
from typing import Any
from sqlalchemy import select from sqlalchemy import select
from core.ops.entities.config_entity import BaseTracingConfig from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map from core.ops.ops_trace_manager import OpsTraceManager, TracingProviderConfigEntry, provider_config_map
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, TraceAppConfig from models.model import App, TraceAppConfig
@@ -150,7 +148,7 @@ class OpsService:
except KeyError: except KeyError:
return {"error": f"Invalid tracing provider: {tracing_provider}"} return {"error": f"Invalid tracing provider: {tracing_provider}"}
provider_config: dict[str, Any] = provider_config_map[tracing_provider] provider_config: TracingProviderConfigEntry = provider_config_map[tracing_provider]
config_class: type[BaseTracingConfig] = provider_config["config_class"] config_class: type[BaseTracingConfig] = provider_config["config_class"]
other_keys: list[str] = provider_config["other_keys"] other_keys: list[str] = provider_config["other_keys"]