mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 02:19:20 +08:00
refactor: select in message_service and ops_service (#34414)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Union
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
@@ -75,17 +76,15 @@ class MessageService:
|
||||
fetch_limit = limit + 1
|
||||
|
||||
if first_id:
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id, Message.id == first_id)
|
||||
.first()
|
||||
first_message = db.session.scalar(
|
||||
select(Message).where(Message.conversation_id == conversation.id, Message.id == first_id).limit(1)
|
||||
)
|
||||
|
||||
if not first_message:
|
||||
raise FirstMessageNotExistsError()
|
||||
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
history_messages = db.session.scalars(
|
||||
select(Message)
|
||||
.where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < first_message.created_at,
|
||||
@@ -93,16 +92,14 @@ class MessageService:
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(fetch_limit)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
history_messages = db.session.scalars(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(fetch_limit)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
has_more = False
|
||||
if len(history_messages) > limit:
|
||||
@@ -129,7 +126,7 @@ class MessageService:
|
||||
if not user:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
|
||||
base_query = db.session.query(Message)
|
||||
stmt = select(Message)
|
||||
|
||||
fetch_limit = limit + 1
|
||||
|
||||
@@ -138,28 +135,27 @@ class MessageService:
|
||||
app_model=app_model, user=user, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
base_query = base_query.where(Message.conversation_id == conversation.id)
|
||||
stmt = stmt.where(Message.conversation_id == conversation.id)
|
||||
|
||||
# Check if include_ids is not None and not empty to avoid WHERE false condition
|
||||
if include_ids is not None:
|
||||
if len(include_ids) == 0:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
base_query = base_query.where(Message.id.in_(include_ids))
|
||||
stmt = stmt.where(Message.id.in_(include_ids))
|
||||
|
||||
if last_id:
|
||||
last_message = base_query.where(Message.id == last_id).first()
|
||||
last_message = db.session.scalar(stmt.where(Message.id == last_id).limit(1))
|
||||
|
||||
if not last_message:
|
||||
raise LastMessageNotExistsError()
|
||||
|
||||
history_messages = (
|
||||
base_query.where(Message.created_at < last_message.created_at, Message.id != last_message.id)
|
||||
history_messages = db.session.scalars(
|
||||
stmt.where(Message.created_at < last_message.created_at, Message.id != last_message.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(fetch_limit)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
history_messages = base_query.order_by(Message.created_at.desc()).limit(fetch_limit).all()
|
||||
history_messages = db.session.scalars(stmt.order_by(Message.created_at.desc()).limit(fetch_limit)).all()
|
||||
|
||||
has_more = False
|
||||
if len(history_messages) > limit:
|
||||
@@ -214,21 +210,20 @@ class MessageService:
|
||||
def get_all_messages_feedbacks(cls, app_model: App, page: int, limit: int):
|
||||
"""Get all feedbacks of an app"""
|
||||
offset = (page - 1) * limit
|
||||
feedbacks = (
|
||||
db.session.query(MessageFeedback)
|
||||
feedbacks = db.session.scalars(
|
||||
select(MessageFeedback)
|
||||
.where(MessageFeedback.app_id == app_model.id)
|
||||
.order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return [record.to_dict() for record in feedbacks]
|
||||
|
||||
@classmethod
|
||||
def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
message = db.session.scalar(
|
||||
select(Message)
|
||||
.where(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
@@ -236,7 +231,7 @@ class MessageService:
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not message:
|
||||
@@ -282,10 +277,10 @@ class MessageService:
|
||||
)
|
||||
else:
|
||||
if not conversation.override_model_configs:
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig)
|
||||
app_model_config = db.session.scalar(
|
||||
select(AppModelConfig)
|
||||
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
conversation_override_model_configs = _app_model_config_adapter.validate_json(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
|
||||
from extensions.ext_database import db
|
||||
@@ -15,17 +17,17 @@ class OpsService:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
trace_config_data: TraceAppConfig | None = (
|
||||
db.session.query(TraceAppConfig)
|
||||
trace_config_data: TraceAppConfig | None = db.session.scalar(
|
||||
select(TraceAppConfig)
|
||||
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not trace_config_data:
|
||||
return None
|
||||
|
||||
# decrypt_token and obfuscated_token
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
app = db.session.get(App, app_id)
|
||||
if not app:
|
||||
return None
|
||||
tenant_id = app.tenant_id
|
||||
@@ -182,17 +184,17 @@ class OpsService:
|
||||
project_url = None
|
||||
|
||||
# check if trace config already exists
|
||||
trace_config_data: TraceAppConfig | None = (
|
||||
db.session.query(TraceAppConfig)
|
||||
trace_config_data: TraceAppConfig | None = db.session.scalar(
|
||||
select(TraceAppConfig)
|
||||
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if trace_config_data:
|
||||
return None
|
||||
|
||||
# get tenant id
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
app = db.session.get(App, app_id)
|
||||
if not app:
|
||||
return None
|
||||
tenant_id = app.tenant_id
|
||||
@@ -224,17 +226,17 @@ class OpsService:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
|
||||
# check if trace config already exists
|
||||
current_trace_config = (
|
||||
db.session.query(TraceAppConfig)
|
||||
current_trace_config = db.session.scalar(
|
||||
select(TraceAppConfig)
|
||||
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not current_trace_config:
|
||||
return None
|
||||
|
||||
# get tenant id
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
app = db.session.get(App, app_id)
|
||||
if not app:
|
||||
return None
|
||||
tenant_id = app.tenant_id
|
||||
@@ -261,10 +263,10 @@ class OpsService:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
trace_config = (
|
||||
db.session.query(TraceAppConfig)
|
||||
trace_config = db.session.scalar(
|
||||
select(TraceAppConfig)
|
||||
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not trace_config:
|
||||
|
||||
@@ -151,12 +151,7 @@ class TestMessageServicePaginationByFirstId:
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
mock_db.session.scalars.return_value.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@@ -196,12 +191,7 @@ class TestMessageServicePaginationByFirstId:
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
mock_db.session.scalars.return_value.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@@ -246,31 +236,8 @@ class TestMessageServicePaginationByFirstId:
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
# Setup query mocks
|
||||
mock_query_first = MagicMock()
|
||||
mock_query_history = MagicMock()
|
||||
|
||||
query_calls = []
|
||||
|
||||
def query_side_effect(*args):
|
||||
if args[0] == Message:
|
||||
query_calls.append(args)
|
||||
if len(query_calls) == 1:
|
||||
return mock_query_first
|
||||
else:
|
||||
return mock_query_history
|
||||
|
||||
mock_db.session.query.side_effect = [mock_query_first, mock_query_history]
|
||||
|
||||
# Setup first message query
|
||||
mock_query_first.where.return_value = mock_query_first
|
||||
mock_query_first.first.return_value = first_message
|
||||
|
||||
# Setup history messages query
|
||||
mock_query_history.where.return_value = mock_query_history
|
||||
mock_query_history.order_by.return_value = mock_query_history
|
||||
mock_query_history.limit.return_value = mock_query_history
|
||||
mock_query_history.all.return_value = history_messages
|
||||
mock_db.session.scalar.return_value = first_message
|
||||
mock_db.session.scalars.return_value.all.return_value = history_messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@@ -285,8 +252,6 @@ class TestMessageServicePaginationByFirstId:
|
||||
# Assert
|
||||
assert len(result.data) == 5
|
||||
assert result.has_more is False
|
||||
mock_query_first.where.assert_called_once()
|
||||
mock_query_history.where.assert_called_once()
|
||||
|
||||
# Test 06: First message not found
|
||||
@patch("services.message_service.db")
|
||||
@@ -300,10 +265,7 @@ class TestMessageServicePaginationByFirstId:
|
||||
|
||||
mock_conversation_service.get_conversation.return_value = conversation
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None # Message not found
|
||||
mock_db.session.scalar.return_value = None # Message not found
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(FirstMessageNotExistsError):
|
||||
@@ -336,12 +298,7 @@ class TestMessageServicePaginationByFirstId:
|
||||
for i in range(11)
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
mock_db.session.scalars.return_value.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@@ -369,12 +326,7 @@ class TestMessageServicePaginationByFirstId:
|
||||
|
||||
mock_conversation_service.get_conversation.return_value = conversation
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = []
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@@ -443,12 +395,7 @@ class TestMessageServicePaginationByLastId:
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
mock_db.session.scalars.return_value.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_last_id(
|
||||
@@ -485,22 +432,8 @@ class TestMessageServicePaginationByLastId:
|
||||
for i in range(6, 10)
|
||||
]
|
||||
|
||||
# Setup base query mock that returns itself for chaining
|
||||
mock_base_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_base_query
|
||||
|
||||
# First where() call for last_id lookup
|
||||
mock_query_last = MagicMock()
|
||||
mock_query_last.first.return_value = last_message
|
||||
|
||||
# Second where() call for history messages
|
||||
mock_query_history = MagicMock()
|
||||
mock_query_history.order_by.return_value = mock_query_history
|
||||
mock_query_history.limit.return_value = mock_query_history
|
||||
mock_query_history.all.return_value = new_messages
|
||||
|
||||
# Setup where() to return different mocks on consecutive calls
|
||||
mock_base_query.where.side_effect = [mock_query_last, mock_query_history]
|
||||
mock_db.session.scalar.return_value = last_message
|
||||
mock_db.session.scalars.return_value.all.return_value = new_messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_last_id(
|
||||
@@ -522,10 +455,7 @@ class TestMessageServicePaginationByLastId:
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None # Message not found
|
||||
mock_db.session.scalar.return_value = None # Message not found
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(LastMessageNotExistsError):
|
||||
@@ -557,12 +487,7 @@ class TestMessageServicePaginationByLastId:
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
mock_db.session.scalars.return_value.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_last_id(
|
||||
@@ -576,8 +501,6 @@ class TestMessageServicePaginationByLastId:
|
||||
# Assert
|
||||
assert len(result.data) == 5
|
||||
assert result.has_more is False
|
||||
# Verify conversation_id was used in query
|
||||
mock_query.where.assert_called()
|
||||
mock_conversation_service.get_conversation.assert_called_once()
|
||||
|
||||
# Test 14: Pagination with include_ids filter
|
||||
@@ -594,12 +517,7 @@ class TestMessageServicePaginationByLastId:
|
||||
factory.create_message_mock(message_id="msg-003"),
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
mock_db.session.scalars.return_value.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_last_id(
|
||||
@@ -632,12 +550,7 @@ class TestMessageServicePaginationByLastId:
|
||||
for i in range(11)
|
||||
]
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.all.return_value = messages
|
||||
mock_db.session.scalars.return_value.all.return_value = messages
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_last_id(
|
||||
@@ -743,17 +656,13 @@ class TestMessageServiceGetMessage:
|
||||
user = factory.create_end_user_mock(user_id="end-user-123")
|
||||
message = factory.create_message_mock()
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = message
|
||||
mock_db.session.scalar.return_value = message
|
||||
|
||||
# Act
|
||||
result = MessageService.get_message(app_model=app, user=user, message_id="msg-123")
|
||||
|
||||
# Assert
|
||||
assert result == message
|
||||
mock_query.where.assert_called_once()
|
||||
|
||||
# Test 21: get_message success for Account (Admin)
|
||||
@patch("services.message_service.db")
|
||||
@@ -767,10 +676,7 @@ class TestMessageServiceGetMessage:
|
||||
user.id = "account-123"
|
||||
message = factory.create_message_mock()
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = message
|
||||
mock_db.session.scalar.return_value = message
|
||||
|
||||
# Act
|
||||
result = MessageService.get_message(app_model=app, user=user, message_id="msg-123")
|
||||
@@ -786,10 +692,7 @@ class TestMessageServiceGetMessage:
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
@@ -899,21 +802,13 @@ class TestMessageServiceFeedback:
|
||||
feedback = MagicMock()
|
||||
feedback.to_dict.return_value = {"id": "fb-1"}
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.limit.return_value = mock_query
|
||||
mock_query.offset.return_value = mock_query
|
||||
mock_query.all.return_value = [feedback]
|
||||
mock_db.session.scalars.return_value.all.return_value = [feedback]
|
||||
|
||||
# Act
|
||||
result = MessageService.get_all_messages_feedbacks(app_model=app, page=1, limit=10)
|
||||
|
||||
# Assert
|
||||
assert result == [{"id": "fb-1"}]
|
||||
mock_query.limit.assert_called_with(10)
|
||||
mock_query.offset.assert_called_with(0)
|
||||
|
||||
|
||||
class TestMessageServiceSuggestedQuestions:
|
||||
@@ -1015,10 +910,7 @@ class TestMessageServiceSuggestedQuestions:
|
||||
app_model_config.suggested_questions_after_answer_dict = {"enabled": True}
|
||||
app_model_config.model_dict = {"provider": "openai", "name": "gpt-4"}
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = app_model_config
|
||||
mock_db.session.scalar.return_value = app_model_config
|
||||
|
||||
mock_llm_gen.generate_suggested_questions_after_answer.return_value = ["Q1?"]
|
||||
|
||||
@@ -1029,7 +921,6 @@ class TestMessageServiceSuggestedQuestions:
|
||||
|
||||
# Assert
|
||||
assert result == ["Q1?"]
|
||||
mock_query.first.assert_called_once()
|
||||
mock_llm_gen.generate_suggested_questions_after_answer.assert_called_once()
|
||||
|
||||
# Test 30: get_suggested_questions_after_answer - Disabled Error
|
||||
|
||||
@@ -12,28 +12,27 @@ class TestOpsService:
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_db.session.query.assert_called_with(TraceAppConfig)
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, None]
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
assert mock_db.session.query.call_count == 2
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@@ -43,7 +42,8 @@ class TestOpsService:
|
||||
trace_config.tracing_config = None
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Tracing config cannot be None."):
|
||||
@@ -72,7 +72,8 @@ class TestOpsService:
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
|
||||
@@ -97,7 +98,8 @@ class TestOpsService:
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
|
||||
@@ -118,7 +120,8 @@ class TestOpsService:
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
@@ -139,7 +142,8 @@ class TestOpsService:
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
@@ -189,7 +193,7 @@ class TestOpsService:
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
|
||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, config)
|
||||
@@ -206,7 +210,8 @@ class TestOpsService:
|
||||
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app]
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
||||
|
||||
# Act
|
||||
@@ -223,7 +228,7 @@ class TestOpsService:
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
@@ -237,7 +242,8 @@ class TestOpsService:
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [None, None]
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
@@ -253,7 +259,8 @@ class TestOpsService:
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app]
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
||||
|
||||
# Act
|
||||
@@ -274,7 +281,8 @@ class TestOpsService:
|
||||
mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url"
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app]
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"}
|
||||
|
||||
# Act
|
||||
@@ -297,7 +305,7 @@ class TestOpsService:
|
||||
def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
@@ -311,7 +319,8 @@ class TestOpsService:
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, None]
|
||||
mock_db.session.scalar.return_value = current_config
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
@@ -327,7 +336,8 @@ class TestOpsService:
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app]
|
||||
mock_db.session.scalar.return_value = current_config
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
||||
|
||||
@@ -344,7 +354,8 @@ class TestOpsService:
|
||||
current_config.to_dict.return_value = {"some": "data"}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app]
|
||||
mock_db.session.scalar.return_value = current_config
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
|
||||
@@ -358,7 +369,7 @@ class TestOpsService:
|
||||
@patch("services.ops_service.db")
|
||||
def test_delete_tracing_app_config_no_config(self, mock_db):
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.delete_tracing_app_config("app_id", "arize")
|
||||
@@ -370,7 +381,7 @@ class TestOpsService:
|
||||
def test_delete_tracing_app_config_success(self, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = trace_config
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
|
||||
# Act
|
||||
result = OpsService.delete_tracing_app_config("app_id", "arize")
|
||||
|
||||
Reference in New Issue
Block a user