mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 06:19:25 +08:00
refactor: select in core/ops trace manager and trace providers (#34197)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -27,9 +27,7 @@ DEFAULT_FRAMEWORK_NAME = "dify"
|
||||
def get_user_id_from_message_data(message_data) -> str:
|
||||
user_id = message_data.from_account_id
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
|
||||
if end_user_data is not None:
|
||||
user_id = end_user_data.session_id
|
||||
return user_id
|
||||
|
||||
@@ -410,9 +410,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
|
||||
# Add end user data if available
|
||||
if trace_info.message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
|
||||
)
|
||||
end_user_data: EndUser | None = db.session.get(EndUser, trace_info.message_data.from_end_user_id)
|
||||
if end_user_data is not None:
|
||||
metadata["end_user_id"] = end_user_data.session_id
|
||||
|
||||
|
||||
@@ -241,9 +241,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
|
||||
if end_user_data is not None:
|
||||
user_id = end_user_data.session_id
|
||||
metadata["user_id"] = user_id
|
||||
|
||||
@@ -259,9 +259,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
metadata["user_id"] = user_id
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
|
||||
if end_user_data is not None:
|
||||
end_user_id = end_user_data.session_id
|
||||
metadata["end_user_id"] = end_user_id
|
||||
|
||||
@@ -9,6 +9,7 @@ from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType
|
||||
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey
|
||||
from mlflow.tracing.fluent import start_span_no_context, update_current_trace
|
||||
from mlflow.tracing.provider import detach_span_from_context, set_span_in_context
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
|
||||
@@ -320,7 +321,7 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
|
||||
def _get_message_user_id(self, metadata: dict) -> str | None:
|
||||
if (end_user_id := metadata.get("from_end_user_id")) and (
|
||||
end_user_data := db.session.query(EndUser).where(EndUser.id == end_user_id).first()
|
||||
end_user_data := db.session.get(EndUser, end_user_id)
|
||||
):
|
||||
return end_user_data.session_id
|
||||
|
||||
@@ -447,25 +448,11 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
|
||||
def _get_workflow_nodes(self, workflow_run_id: str):
|
||||
"""Helper method to get workflow nodes"""
|
||||
workflow_nodes = (
|
||||
db.session.query(
|
||||
WorkflowNodeExecutionModel.id,
|
||||
WorkflowNodeExecutionModel.tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id,
|
||||
WorkflowNodeExecutionModel.title,
|
||||
WorkflowNodeExecutionModel.node_type,
|
||||
WorkflowNodeExecutionModel.status,
|
||||
WorkflowNodeExecutionModel.inputs,
|
||||
WorkflowNodeExecutionModel.outputs,
|
||||
WorkflowNodeExecutionModel.created_at,
|
||||
WorkflowNodeExecutionModel.elapsed_time,
|
||||
WorkflowNodeExecutionModel.process_data,
|
||||
WorkflowNodeExecutionModel.execution_metadata,
|
||||
)
|
||||
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
workflow_nodes = db.session.scalars(
|
||||
select(WorkflowNodeExecutionModel)
|
||||
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.order_by(WorkflowNodeExecutionModel.created_at)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
return workflow_nodes
|
||||
|
||||
def _get_node_span_type(self, node_type: str) -> str:
|
||||
|
||||
@@ -288,9 +288,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
metadata["file_list"] = file_list
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
|
||||
if end_user_data is not None:
|
||||
end_user_id = end_user_data.session_id
|
||||
metadata["end_user_id"] = end_user_id
|
||||
|
||||
@@ -420,10 +420,10 @@ class OpsTraceManager:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
trace_config_data: TraceAppConfig | None = (
|
||||
db.session.query(TraceAppConfig)
|
||||
trace_config_data: TraceAppConfig | None = db.session.scalar(
|
||||
select(TraceAppConfig)
|
||||
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not trace_config_data:
|
||||
@@ -463,7 +463,7 @@ class OpsTraceManager:
|
||||
if isinstance(app_id, str) and app_id.startswith("tenant-"):
|
||||
return None
|
||||
|
||||
app: App | None = db.session.query(App).where(App.id == app_id).first()
|
||||
app = db.session.get(App, app_id)
|
||||
|
||||
if app is None:
|
||||
return None
|
||||
@@ -537,7 +537,7 @@ class OpsTraceManager:
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
|
||||
app_config: App | None = db.session.query(App).where(App.id == app_id).first()
|
||||
app_config: App | None = db.session.get(App, app_id)
|
||||
if not app_config:
|
||||
raise ValueError("App not found")
|
||||
app_config.tracing = json.dumps(
|
||||
@@ -555,7 +555,7 @@ class OpsTraceManager:
|
||||
:param app_id: app id
|
||||
:return:
|
||||
"""
|
||||
app: App | None = db.session.query(App).where(App.id == app_id).first()
|
||||
app: App | None = db.session.get(App, app_id)
|
||||
if not app:
|
||||
raise ValueError("App not found")
|
||||
if not app.tracing:
|
||||
@@ -883,7 +883,7 @@ class TraceTask:
|
||||
inputs = message_data.message
|
||||
|
||||
# get message file data
|
||||
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
|
||||
message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1))
|
||||
file_list = []
|
||||
if message_file_data and message_file_data.url is not None:
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
@@ -972,8 +972,8 @@ class TraceTask:
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_id = None
|
||||
if message_data.workflow_run_id:
|
||||
workflow_app_log_data = (
|
||||
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
|
||||
workflow_app_log_data = db.session.scalar(
|
||||
select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1)
|
||||
)
|
||||
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
|
||||
|
||||
@@ -1015,8 +1015,8 @@ class TraceTask:
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_id = None
|
||||
if message_data.workflow_run_id:
|
||||
workflow_app_log_data = (
|
||||
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
|
||||
workflow_app_log_data = db.session.scalar(
|
||||
select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1)
|
||||
)
|
||||
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
|
||||
|
||||
@@ -1171,7 +1171,7 @@ class TraceTask:
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
file_url = ""
|
||||
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
|
||||
message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1))
|
||||
if message_file_data:
|
||||
message_file_id = message_file_data.id if message_file_data else None
|
||||
type = message_file_data.type
|
||||
|
||||
@@ -245,9 +245,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
attributes["user_id"] = user_id
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
|
||||
if end_user_data is not None:
|
||||
end_user_id = end_user_data.session_id
|
||||
attributes["end_user_id"] = end_user_id
|
||||
|
||||
@@ -45,11 +45,8 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch):
|
||||
end_user_data = MagicMock(spec=EndUser)
|
||||
end_user_data.session_id = "session_id"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = end_user_data
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_session.get.return_value = end_user_data
|
||||
|
||||
from core.ops.aliyun_trace.utils import db
|
||||
|
||||
@@ -63,11 +60,8 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
|
||||
message_data.from_account_id = "account_id"
|
||||
message_data.from_end_user_id = "end_user_id"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_session.get.return_value = None
|
||||
|
||||
from core.ops.aliyun_trace.utils import db
|
||||
|
||||
|
||||
@@ -365,9 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch):
|
||||
mock_end_user = MagicMock(spec=EndUser)
|
||||
mock_end_user.session_id = "session-id-123"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_end_user
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.query", lambda model: mock_query)
|
||||
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user)
|
||||
|
||||
trace_instance.add_trace = MagicMock()
|
||||
trace_instance.add_generation = MagicMock()
|
||||
|
||||
@@ -319,9 +319,7 @@ def test_message_trace(trace_instance, monkeypatch):
|
||||
# Mock EndUser lookup
|
||||
mock_end_user = MagicMock(spec=EndUser)
|
||||
mock_end_user.session_id = "session-id-123"
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_end_user
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.query", lambda model: mock_query)
|
||||
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user)
|
||||
|
||||
trace_instance.add_run = MagicMock()
|
||||
|
||||
|
||||
@@ -330,7 +330,7 @@ class TestTraceDispatcher:
|
||||
|
||||
class TestWorkflowTrace:
|
||||
def test_basic_workflow_no_nodes(self, trace_instance, mock_tracing, mock_db):
|
||||
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = []
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
@@ -343,7 +343,7 @@ class TestWorkflowTrace:
|
||||
span.end.assert_called_once()
|
||||
|
||||
def test_workflow_filters_sys_inputs_and_adds_query(self, trace_instance, mock_tracing, mock_db):
|
||||
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = []
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
@@ -374,7 +374,7 @@ class TestWorkflowTrace:
|
||||
),
|
||||
outputs='{"text": "hello world"}',
|
||||
)
|
||||
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [llm_node]
|
||||
mock_db.session.scalars.return_value.all.return_value = [llm_node]
|
||||
|
||||
workflow_span = MagicMock()
|
||||
node_span = MagicMock()
|
||||
@@ -397,7 +397,7 @@ class TestWorkflowTrace:
|
||||
}
|
||||
),
|
||||
)
|
||||
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [qc_node]
|
||||
mock_db.session.scalars.return_value.all.return_value = [qc_node]
|
||||
workflow_span = MagicMock()
|
||||
node_span = MagicMock()
|
||||
mock_tracing["start"].side_effect = [workflow_span, node_span]
|
||||
@@ -411,7 +411,7 @@ class TestWorkflowTrace:
|
||||
node_type=BuiltinNodeTypes.HTTP_REQUEST,
|
||||
process_data='{"url": "https://api.com"}',
|
||||
)
|
||||
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [http_node]
|
||||
mock_db.session.scalars.return_value.all.return_value = [http_node]
|
||||
workflow_span = MagicMock()
|
||||
node_span = MagicMock()
|
||||
mock_tracing["start"].side_effect = [workflow_span, node_span]
|
||||
@@ -434,7 +434,7 @@ class TestWorkflowTrace:
|
||||
}
|
||||
),
|
||||
)
|
||||
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [kr_node]
|
||||
mock_db.session.scalars.return_value.all.return_value = [kr_node]
|
||||
workflow_span = MagicMock()
|
||||
node_span = MagicMock()
|
||||
mock_tracing["start"].side_effect = [workflow_span, node_span]
|
||||
@@ -448,7 +448,7 @@ class TestWorkflowTrace:
|
||||
|
||||
def test_workflow_with_failed_node(self, trace_instance, mock_tracing, mock_db):
|
||||
failed_node = _make_node(status="failed")
|
||||
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [failed_node]
|
||||
mock_db.session.scalars.return_value.all.return_value = [failed_node]
|
||||
workflow_span = MagicMock()
|
||||
node_span = MagicMock()
|
||||
mock_tracing["start"].side_effect = [workflow_span, node_span]
|
||||
@@ -459,7 +459,7 @@ class TestWorkflowTrace:
|
||||
node_span.add_event.assert_called_once()
|
||||
|
||||
def test_workflow_with_workflow_error(self, trace_instance, mock_tracing, mock_db):
|
||||
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = []
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
workflow_span = MagicMock()
|
||||
mock_tracing["start"].return_value = workflow_span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
@@ -473,7 +473,7 @@ class TestWorkflowTrace:
|
||||
|
||||
def test_workflow_node_no_inputs_no_outputs(self, trace_instance, mock_tracing, mock_db):
|
||||
node = _make_node(inputs=None, outputs=None)
|
||||
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [node]
|
||||
mock_db.session.scalars.return_value.all.return_value = [node]
|
||||
workflow_span = MagicMock()
|
||||
node_span = MagicMock()
|
||||
mock_tracing["start"].side_effect = [workflow_span, node_span]
|
||||
@@ -486,7 +486,7 @@ class TestWorkflowTrace:
|
||||
assert end_call.kwargs["outputs"] == {}
|
||||
|
||||
def test_workflow_no_user_id_no_conversation_id(self, trace_instance, mock_tracing, mock_db):
|
||||
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = []
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
@@ -501,7 +501,7 @@ class TestWorkflowTrace:
|
||||
|
||||
def test_workflow_empty_query(self, trace_instance, mock_tracing, mock_db):
|
||||
"""When query is empty string, it's falsy so no query key added."""
|
||||
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = []
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
@@ -680,12 +680,12 @@ class TestGetMessageUserId:
|
||||
def test_returns_end_user_session_id(self, trace_instance, mock_db):
|
||||
end_user = MagicMock()
|
||||
end_user.session_id = "session-1"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = end_user
|
||||
mock_db.session.get.return_value = end_user
|
||||
result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1"})
|
||||
assert result == "session-1"
|
||||
|
||||
def test_returns_account_id_when_no_end_user(self, trace_instance, mock_db):
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1", "from_account_id": "acc-1"})
|
||||
assert result == "acc-1"
|
||||
|
||||
@@ -834,7 +834,7 @@ class TestGenerateNameTrace:
|
||||
|
||||
class TestGetWorkflowNodes:
|
||||
def test_queries_db(self, trace_instance, mock_db):
|
||||
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = ["n1", "n2"]
|
||||
mock_db.session.scalars.return_value.all.return_value = ["n1", "n2"]
|
||||
result = trace_instance._get_workflow_nodes("run-1")
|
||||
assert result == ["n1", "n2"]
|
||||
|
||||
|
||||
@@ -373,9 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch):
|
||||
mock_end_user = MagicMock(spec=EndUser)
|
||||
mock_end_user.session_id = "session-id-123"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_end_user
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.query", lambda model: mock_query)
|
||||
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user)
|
||||
|
||||
trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2"))
|
||||
trace_instance.add_span = MagicMock()
|
||||
|
||||
@@ -157,17 +157,19 @@ def make_workflow_run():
|
||||
)
|
||||
|
||||
|
||||
def configure_db_query(session, *, message_file=None, workflow_app_log=None):
|
||||
def _side_effect(model):
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value.first.return_value = None
|
||||
if message_file and model.__name__ == "MessageFile":
|
||||
query.filter_by.return_value.first.return_value = message_file
|
||||
if workflow_app_log and model.__name__ == "WorkflowAppLog":
|
||||
query.filter_by.return_value.first.return_value = workflow_app_log
|
||||
return query
|
||||
def configure_db_scalar(session, *, message_file=None, workflow_app_log=None):
|
||||
"""Configure session.scalar to return appropriate values for MessageFile/WorkflowAppLog lookups."""
|
||||
original_scalar = session.scalar
|
||||
|
||||
session.query.side_effect = _side_effect
|
||||
def _side_effect(stmt):
|
||||
stmt_str = str(stmt)
|
||||
if "message_file" in stmt_str.lower():
|
||||
return message_file
|
||||
if "workflow_app_log" in stmt_str.lower():
|
||||
return workflow_app_log
|
||||
return original_scalar(stmt)
|
||||
|
||||
session.scalar.side_effect = _side_effect
|
||||
|
||||
|
||||
class DummySessionContext:
|
||||
@@ -263,7 +265,7 @@ def workflow_repo_fixture(monkeypatch):
|
||||
def trace_task_message(monkeypatch, mock_db):
|
||||
message_data = make_message_data()
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data)
|
||||
configure_db_query(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id"))
|
||||
configure_db_scalar(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id"))
|
||||
return message_data
|
||||
|
||||
|
||||
@@ -307,56 +309,53 @@ def test_obfuscated_decrypt_token(encryption_mocks):
|
||||
|
||||
def test_get_decrypted_tracing_config_returns_config(encryption_mocks, mock_db):
|
||||
trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc", "other_value": "info"})
|
||||
mock_db.query.return_value.where.return_value.first.return_value = trace_config_data
|
||||
app = SimpleNamespace(id="app-id", tenant_id="tenant")
|
||||
mock_db.scalar.return_value = app
|
||||
mock_db.scalar.side_effect = [trace_config_data, app]
|
||||
|
||||
decrypted = OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
|
||||
assert decrypted["other_value"] == "info"
|
||||
|
||||
|
||||
def test_get_decrypted_tracing_config_missing_trace_config(mock_db):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.scalar.return_value = None
|
||||
assert OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") is None
|
||||
|
||||
|
||||
def test_get_decrypted_tracing_config_raises_for_missing_app(mock_db):
|
||||
trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc"})
|
||||
mock_db.query.return_value.where.return_value.first.return_value = trace_config_data
|
||||
mock_db.scalar.return_value = None
|
||||
mock_db.scalar.side_effect = [trace_config_data, None]
|
||||
with pytest.raises(ValueError, match="App not found"):
|
||||
OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
|
||||
|
||||
|
||||
def test_get_decrypted_tracing_config_raises_for_none_config(mock_db):
|
||||
trace_config_data = SimpleNamespace(tracing_config=None)
|
||||
mock_db.query.return_value.where.return_value.first.return_value = trace_config_data
|
||||
mock_db.scalar.return_value = SimpleNamespace(tenant_id="tenant")
|
||||
mock_db.scalar.side_effect = [trace_config_data, SimpleNamespace(tenant_id="tenant")]
|
||||
with pytest.raises(ValueError, match="Tracing config cannot be None"):
|
||||
OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
|
||||
|
||||
|
||||
def test_get_ops_trace_instance_handles_none_app(mock_db):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.get.return_value = None
|
||||
assert OpsTraceManager.get_ops_trace_instance("app-id") is None
|
||||
|
||||
|
||||
def test_get_ops_trace_instance_returns_none_when_disabled(mock_db, monkeypatch):
|
||||
app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": False}))
|
||||
mock_db.query.return_value.where.return_value.first.return_value = app
|
||||
mock_db.get.return_value = app
|
||||
assert OpsTraceManager.get_ops_trace_instance("app-id") is None
|
||||
|
||||
|
||||
def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch):
|
||||
app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "missing"}))
|
||||
mock_db.query.return_value.where.return_value.first.return_value = app
|
||||
mock_db.get.return_value = app
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({}))
|
||||
assert OpsTraceManager.get_ops_trace_instance("app-id") is None
|
||||
|
||||
|
||||
def test_get_ops_trace_instance_success(monkeypatch, mock_db):
|
||||
app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"}))
|
||||
mock_db.query.return_value.where.return_value.first.return_value = app
|
||||
mock_db.get.return_value = app
|
||||
monkeypatch.setattr(
|
||||
"core.ops.ops_trace_manager.OpsTraceManager.get_decrypted_tracing_config",
|
||||
classmethod(lambda cls, aid, provider: {"secret_value": "decrypted", "other_value": "info"}),
|
||||
@@ -390,7 +389,7 @@ def test_get_app_config_through_message_id_app_model_config(mock_db):
|
||||
|
||||
|
||||
def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.get.return_value = None
|
||||
with pytest.raises(ValueError, match="Invalid tracing provider"):
|
||||
OpsTraceManager.update_app_tracing_config("app", True, "bad")
|
||||
with pytest.raises(ValueError, match="App not found"):
|
||||
@@ -399,26 +398,26 @@ def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch):
|
||||
|
||||
def test_update_app_tracing_config_success(mock_db):
|
||||
app = SimpleNamespace(id="app-id", tracing="{}")
|
||||
mock_db.query.return_value.where.return_value.first.return_value = app
|
||||
mock_db.get.return_value = app
|
||||
OpsTraceManager.update_app_tracing_config("app-id", True, "dummy")
|
||||
assert app.tracing is not None
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_app_tracing_config_errors_when_missing(mock_db):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.get.return_value = None
|
||||
with pytest.raises(ValueError, match="App not found"):
|
||||
OpsTraceManager.get_app_tracing_config("app")
|
||||
|
||||
|
||||
def test_get_app_tracing_config_returns_defaults(mock_db):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=None)
|
||||
mock_db.get.return_value = SimpleNamespace(tracing=None)
|
||||
assert OpsTraceManager.get_app_tracing_config("app-id") == {"enabled": False, "tracing_provider": None}
|
||||
|
||||
|
||||
def test_get_app_tracing_config_returns_payload(mock_db):
|
||||
payload = {"enabled": True, "tracing_provider": "dummy"}
|
||||
mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=json.dumps(payload))
|
||||
mock_db.get.return_value = SimpleNamespace(tracing=json.dumps(payload))
|
||||
assert OpsTraceManager.get_app_tracing_config("app-id") == payload
|
||||
|
||||
|
||||
@@ -501,7 +500,7 @@ def test_trace_task_dataset_retrieval_trace(trace_task_message):
|
||||
def test_trace_task_tool_trace(monkeypatch, mock_db):
|
||||
custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))])
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message)
|
||||
configure_db_query(mock_db, message_file=FakeMessageFile())
|
||||
configure_db_scalar(mock_db, message_file=FakeMessageFile())
|
||||
task = TraceTask(trace_type=TraceTaskName.TOOL_TRACE, message_id="msg-id")
|
||||
timer = {"start": 1, "end": 5}
|
||||
result = task.tool_trace("msg-id", timer, tool_name="tool-a", tool_inputs={"foo": 1}, tool_outputs="result")
|
||||
|
||||
@@ -802,8 +802,8 @@ class TestMessageTrace:
|
||||
def test_basic_message_trace(self, trace_instance, monkeypatch):
|
||||
"""message_trace creates message run and llm child run."""
|
||||
monkeypatch.setattr(
|
||||
"core.ops.weave_trace.weave_trace.db.session.query",
|
||||
lambda model: MagicMock(where=lambda: MagicMock(first=lambda: None)),
|
||||
"core.ops.weave_trace.weave_trace.db.session.get",
|
||||
lambda model, pk: None,
|
||||
)
|
||||
|
||||
trace_instance.start_call = MagicMock()
|
||||
@@ -823,7 +823,7 @@ class TestMessageTrace:
|
||||
trace_instance.file_base_url = "http://files.test"
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
|
||||
|
||||
trace_instance.start_call = MagicMock()
|
||||
@@ -845,7 +845,7 @@ class TestMessageTrace:
|
||||
end_user.session_id = "session-xyz"
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = end_user
|
||||
mock_db.session.get.return_value = end_user
|
||||
monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
|
||||
|
||||
trace_instance.start_call = MagicMock()
|
||||
@@ -865,7 +865,7 @@ class TestMessageTrace:
|
||||
def test_message_trace_no_end_user(self, trace_instance, monkeypatch):
|
||||
"""message_trace handles when from_end_user_id is None."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
|
||||
|
||||
trace_instance.start_call = MagicMock()
|
||||
@@ -883,7 +883,7 @@ class TestMessageTrace:
|
||||
def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch):
|
||||
"""trace_id falls back to message_id when trace_id is None."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
|
||||
|
||||
trace_instance.start_call = MagicMock()
|
||||
@@ -898,7 +898,7 @@ class TestMessageTrace:
|
||||
def test_message_trace_file_list_none(self, trace_instance, monkeypatch):
|
||||
"""message_trace handles file_list=None gracefully."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
|
||||
|
||||
trace_instance.start_call = MagicMock()
|
||||
|
||||
Reference in New Issue
Block a user