mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 05:09:19 +08:00
refactor: select in 10 service files (#34373)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
@@ -2,6 +2,7 @@ import threading
|
||||
from typing import Any
|
||||
|
||||
import pytz
|
||||
from sqlalchemy import select
|
||||
|
||||
import contexts
|
||||
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
|
||||
@@ -23,25 +24,25 @@ class AgentService:
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
conversation: Conversation | None = (
|
||||
db.session.query(Conversation)
|
||||
conversation: Conversation | None = db.session.scalar(
|
||||
select(Conversation)
|
||||
.where(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise ValueError(f"Conversation not found: {conversation_id}")
|
||||
|
||||
message: Message | None = (
|
||||
db.session.query(Message)
|
||||
message: Message | None = db.session.scalar(
|
||||
select(Message)
|
||||
.where(
|
||||
Message.id == message_id,
|
||||
Message.conversation_id == conversation_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not message:
|
||||
@@ -51,16 +52,11 @@ class AgentService:
|
||||
|
||||
if conversation.from_end_user_id:
|
||||
# only select name field
|
||||
executor = (
|
||||
db.session.query(EndUser, EndUser.name).where(EndUser.id == conversation.from_end_user_id).first()
|
||||
)
|
||||
executor_name = db.session.scalar(select(EndUser.name).where(EndUser.id == conversation.from_end_user_id))
|
||||
else:
|
||||
executor = db.session.query(Account, Account.name).where(Account.id == conversation.from_account_id).first()
|
||||
executor_name = db.session.scalar(select(Account.name).where(Account.id == conversation.from_account_id))
|
||||
|
||||
if executor:
|
||||
executor = executor.name
|
||||
else:
|
||||
executor = "Unknown"
|
||||
executor = executor_name or "Unknown"
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.timezone is not None
|
||||
timezone = pytz.timezone(current_user.timezone)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
|
||||
from core.helper.encrypter import decrypt_token, encrypt_token
|
||||
from extensions.ext_database import db
|
||||
@@ -7,11 +9,12 @@ from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
class APIBasedExtensionService:
|
||||
@staticmethod
|
||||
def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
|
||||
extension_list = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter_by(tenant_id=tenant_id)
|
||||
.order_by(APIBasedExtension.created_at.desc())
|
||||
.all()
|
||||
extension_list = list(
|
||||
db.session.scalars(
|
||||
select(APIBasedExtension)
|
||||
.where(APIBasedExtension.tenant_id == tenant_id)
|
||||
.order_by(APIBasedExtension.created_at.desc())
|
||||
).all()
|
||||
)
|
||||
|
||||
for extension in extension_list:
|
||||
@@ -36,11 +39,10 @@ class APIBasedExtensionService:
|
||||
|
||||
@staticmethod
|
||||
def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
|
||||
extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter_by(tenant_id=tenant_id)
|
||||
.filter_by(id=api_based_extension_id)
|
||||
.first()
|
||||
extension = db.session.scalar(
|
||||
select(APIBasedExtension)
|
||||
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not extension:
|
||||
@@ -58,23 +60,27 @@ class APIBasedExtensionService:
|
||||
|
||||
if not extension_data.id:
|
||||
# case one: check new data, name must be unique
|
||||
is_name_existed = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter_by(tenant_id=extension_data.tenant_id)
|
||||
.filter_by(name=extension_data.name)
|
||||
.first()
|
||||
is_name_existed = db.session.scalar(
|
||||
select(APIBasedExtension)
|
||||
.where(
|
||||
APIBasedExtension.tenant_id == extension_data.tenant_id,
|
||||
APIBasedExtension.name == extension_data.name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if is_name_existed:
|
||||
raise ValueError("name must be unique, it is already existed")
|
||||
else:
|
||||
# case two: check existing data, name must be unique
|
||||
is_name_existed = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter_by(tenant_id=extension_data.tenant_id)
|
||||
.filter_by(name=extension_data.name)
|
||||
.where(APIBasedExtension.id != extension_data.id)
|
||||
.first()
|
||||
is_name_existed = db.session.scalar(
|
||||
select(APIBasedExtension)
|
||||
.where(
|
||||
APIBasedExtension.tenant_id == extension_data.tenant_id,
|
||||
APIBasedExtension.name == extension_data.name,
|
||||
APIBasedExtension.id != extension_data.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if is_name_existed:
|
||||
|
||||
@@ -6,6 +6,7 @@ import sqlalchemy as sa
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from constants.model_template import default_app_templates
|
||||
@@ -433,9 +434,7 @@ class AppService:
|
||||
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
|
||||
elif provider_type == "api":
|
||||
try:
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first()
|
||||
)
|
||||
provider: ApiToolProvider | None = db.session.get(ApiToolProvider, provider_id)
|
||||
if provider is None:
|
||||
raise ValueError(f"provider not found for tool {tool_name}")
|
||||
meta["tool_icons"][tool_name] = json.loads(provider.icon)
|
||||
@@ -451,7 +450,7 @@ class AppService:
|
||||
:param app_id: app id
|
||||
:return: app code
|
||||
"""
|
||||
site = db.session.query(Site).where(Site.app_id == app_id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_id).limit(1))
|
||||
if not site:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
return str(site.code)
|
||||
@@ -463,7 +462,7 @@ class AppService:
|
||||
:param app_code: app code
|
||||
:return: app id
|
||||
"""
|
||||
site = db.session.query(Site).where(Site.code == app_code).first()
|
||||
site = db.session.scalar(select(Site).where(Site.code == app_code).limit(1))
|
||||
if not site:
|
||||
raise ValueError(f"App with code {app_code} not found")
|
||||
return str(site.app_id)
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Response
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.enums import FeedbackRating
|
||||
@@ -41,8 +41,8 @@ class FeedbackService:
|
||||
raise ValueError(f"Unsupported format: {format_type}")
|
||||
|
||||
# Build base query
|
||||
query = (
|
||||
db.session.query(MessageFeedback, Message, Conversation, App, Account)
|
||||
stmt = (
|
||||
select(MessageFeedback, Message, Conversation, App, Account)
|
||||
.join(Message, MessageFeedback.message_id == Message.id)
|
||||
.join(Conversation, MessageFeedback.conversation_id == Conversation.id)
|
||||
.join(App, MessageFeedback.app_id == App.id)
|
||||
@@ -52,36 +52,36 @@ class FeedbackService:
|
||||
|
||||
# Apply filters
|
||||
if from_source:
|
||||
query = query.filter(MessageFeedback.from_source == from_source)
|
||||
stmt = stmt.where(MessageFeedback.from_source == from_source)
|
||||
|
||||
if rating:
|
||||
query = query.filter(MessageFeedback.rating == rating)
|
||||
stmt = stmt.where(MessageFeedback.rating == rating)
|
||||
|
||||
if has_comment is not None:
|
||||
if has_comment:
|
||||
query = query.filter(MessageFeedback.content.isnot(None), MessageFeedback.content != "")
|
||||
stmt = stmt.where(MessageFeedback.content.isnot(None), MessageFeedback.content != "")
|
||||
else:
|
||||
query = query.filter(or_(MessageFeedback.content.is_(None), MessageFeedback.content == ""))
|
||||
stmt = stmt.where(or_(MessageFeedback.content.is_(None), MessageFeedback.content == ""))
|
||||
|
||||
if start_date:
|
||||
try:
|
||||
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
query = query.filter(MessageFeedback.created_at >= start_dt)
|
||||
stmt = stmt.where(MessageFeedback.created_at >= start_dt)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid start_date format: {start_date}. Use YYYY-MM-DD")
|
||||
|
||||
if end_date:
|
||||
try:
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
query = query.filter(MessageFeedback.created_at <= end_dt)
|
||||
stmt = stmt.where(MessageFeedback.created_at <= end_dt)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid end_date format: {end_date}. Use YYYY-MM-DD")
|
||||
|
||||
# Order by creation date (newest first)
|
||||
query = query.order_by(MessageFeedback.created_at.desc())
|
||||
stmt = stmt.order_by(MessageFeedback.created_at.desc())
|
||||
|
||||
# Execute query
|
||||
results = query.all()
|
||||
results = db.session.execute(stmt).all()
|
||||
|
||||
# Prepare data for export
|
||||
export_data = []
|
||||
|
||||
@@ -6,6 +6,7 @@ from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import select
|
||||
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
@@ -26,7 +27,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class RagPipelineTransformService:
|
||||
def transform_dataset(self, dataset_id: str):
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = db.session.get(Dataset, dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE:
|
||||
@@ -306,7 +307,7 @@ class RagPipelineTransformService:
|
||||
jina_node_id = "1752491761974"
|
||||
firecrawl_node_id = "1752565402678"
|
||||
|
||||
documents = db.session.query(Document).where(Document.dataset_id == dataset.id).all()
|
||||
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset.id)).all()
|
||||
|
||||
for document in documents:
|
||||
data_source_info_dict = document.data_source_info_dict
|
||||
@@ -316,7 +317,7 @@ class RagPipelineTransformService:
|
||||
document.data_source_type = DataSourceType.LOCAL_FILE
|
||||
file_id = data_source_info_dict.get("upload_file_id")
|
||||
if file_id:
|
||||
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
file = db.session.get(UploadFile, file_id)
|
||||
if file:
|
||||
data_source_info = json.dumps(
|
||||
{
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from models.model import AccountTrialAppRecord, TrialApp
|
||||
@@ -27,7 +29,7 @@ class RecommendedAppService:
|
||||
apps = result["recommended_apps"]
|
||||
for app in apps:
|
||||
app_id = app["app_id"]
|
||||
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
|
||||
trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))
|
||||
if trial_app_model:
|
||||
app["can_trial"] = True
|
||||
else:
|
||||
@@ -46,7 +48,7 @@ class RecommendedAppService:
|
||||
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
|
||||
if FeatureService.get_system_features().enable_trial_app:
|
||||
app_id = result["id"]
|
||||
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
|
||||
trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))
|
||||
if trial_app_model:
|
||||
result["can_trial"] = True
|
||||
else:
|
||||
@@ -60,10 +62,10 @@ class RecommendedAppService:
|
||||
:param app_id: app id
|
||||
:return:
|
||||
"""
|
||||
account_trial_app_record = (
|
||||
db.session.query(AccountTrialAppRecord)
|
||||
account_trial_app_record = db.session.scalar(
|
||||
select(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if account_trial_app_record:
|
||||
account_trial_app_record.count += 1
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
@@ -16,16 +18,15 @@ class SavedMessageService:
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
raise ValueError("User is required")
|
||||
saved_messages = (
|
||||
db.session.query(SavedMessage)
|
||||
saved_messages = db.session.scalars(
|
||||
select(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
SavedMessage.created_by == user.id,
|
||||
)
|
||||
.order_by(SavedMessage.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
message_ids = [sm.message_id for sm in saved_messages]
|
||||
|
||||
return MessageService.pagination_by_last_id(
|
||||
@@ -36,15 +37,15 @@ class SavedMessageService:
|
||||
def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
|
||||
if not user:
|
||||
return
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
saved_message = db.session.scalar(
|
||||
select(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.message_id == message_id,
|
||||
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
SavedMessage.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if saved_message:
|
||||
@@ -66,15 +67,15 @@ class SavedMessageService:
|
||||
def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
|
||||
if not user:
|
||||
return
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
saved_message = db.session.scalar(
|
||||
select(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.message_id == message_id,
|
||||
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
SavedMessage.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not saved_message:
|
||||
|
||||
@@ -332,12 +332,11 @@ class BuiltinToolManageService:
|
||||
get builtin tool provider credentials
|
||||
"""
|
||||
with db.session.no_autoflush:
|
||||
providers = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider_name)
|
||||
providers = db.session.scalars(
|
||||
select(BuiltinToolProvider)
|
||||
.where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
if len(providers) == 0:
|
||||
return []
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
@@ -29,7 +30,7 @@ class VectorService:
|
||||
|
||||
for segment in segments:
|
||||
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
|
||||
dataset_document = db.session.get(DatasetDocument, segment.document_id)
|
||||
if not dataset_document:
|
||||
logger.warning(
|
||||
"Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
|
||||
@@ -38,11 +39,7 @@ class VectorService:
|
||||
)
|
||||
continue
|
||||
# get the process rule
|
||||
processing_rule = (
|
||||
db.session.query(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
|
||||
.first()
|
||||
)
|
||||
processing_rule = db.session.get(DatasetProcessRule, dataset_document.dataset_process_rule_id)
|
||||
if not processing_rule:
|
||||
raise ValueError("No processing rule found.")
|
||||
# get embedding model instance
|
||||
@@ -271,8 +268,8 @@ class VectorService:
|
||||
vector.delete_by_ids(old_attachment_ids)
|
||||
|
||||
# Delete existing segment attachment bindings in one operation
|
||||
db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete(
|
||||
synchronize_session=False
|
||||
db.session.execute(
|
||||
delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id)
|
||||
)
|
||||
|
||||
if not attachment_ids:
|
||||
@@ -280,7 +277,7 @@ class VectorService:
|
||||
return
|
||||
|
||||
# Bulk fetch upload files - only fetch needed fields
|
||||
upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
|
||||
upload_file_list = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all()
|
||||
|
||||
if not upload_file_list:
|
||||
db.session.commit()
|
||||
|
||||
@@ -138,14 +138,14 @@ class WorkflowService:
|
||||
if workflow_id:
|
||||
return self.get_published_workflow_by_id(app_model, workflow_id)
|
||||
# fetch draft workflow by app_model
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# return draft workflow
|
||||
@@ -155,14 +155,14 @@ class WorkflowService:
|
||||
"""
|
||||
fetch published workflow by workflow_id
|
||||
"""
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == workflow_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not workflow:
|
||||
return None
|
||||
@@ -182,14 +182,14 @@ class WorkflowService:
|
||||
return None
|
||||
|
||||
# fetch published workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == app_model.workflow_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
return workflow
|
||||
@@ -544,14 +544,14 @@ class WorkflowService:
|
||||
|
||||
# Use the same fallback logic as runtime: get the first available credential
|
||||
# ordered by is_default DESC, created_at ASC (same as tool_manager.py)
|
||||
default_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
default_provider = db.session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not default_provider:
|
||||
|
||||
@@ -99,7 +99,7 @@ class TestFeedbackService:
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
|
||||
# Test CSV export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
@@ -138,7 +138,7 @@ class TestFeedbackService:
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
|
||||
# Test JSON export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
@@ -175,7 +175,7 @@ class TestFeedbackService:
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
|
||||
# Test with filters
|
||||
result = FeedbackService.export_feedbacks(
|
||||
@@ -188,11 +188,8 @@ class TestFeedbackService:
|
||||
format_type="csv",
|
||||
)
|
||||
|
||||
# Verify filters were applied
|
||||
assert mock_query.filter.called
|
||||
filter_calls = mock_query.filter.call_args_list
|
||||
# At least three filter invocations are expected (source, rating, comment)
|
||||
assert len(filter_calls) >= 3
|
||||
# Verify query was executed (filters are baked into the select statement)
|
||||
assert mock_db_session.execute.called
|
||||
|
||||
def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback when no data exists."""
|
||||
@@ -206,7 +203,7 @@ class TestFeedbackService:
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = []
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
|
||||
@@ -271,7 +268,7 @@ class TestFeedbackService:
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
@@ -329,7 +326,7 @@ class TestFeedbackService:
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
@@ -367,7 +364,7 @@ class TestFeedbackService:
|
||||
),
|
||||
]
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.execute.return_value = mock_query
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
|
||||
@@ -77,22 +77,12 @@ def _make_segment(
|
||||
def _mock_db_session_for_update_multimodel(*, upload_files: list[_UploadFileStub] | None) -> MagicMock:
|
||||
session = MagicMock(name="session")
|
||||
|
||||
binding_query = MagicMock(name="binding_query")
|
||||
binding_query.where.return_value = binding_query
|
||||
binding_query.delete.return_value = 1
|
||||
# db.session.execute() is used for delete(SegmentAttachmentBinding).where(...)
|
||||
session.execute = MagicMock(name="execute")
|
||||
|
||||
upload_query = MagicMock(name="upload_query")
|
||||
upload_query.where.return_value = upload_query
|
||||
upload_query.all.return_value = upload_files or []
|
||||
# db.session.scalars(select(UploadFile).where(...)).all() returns upload files
|
||||
session.scalars.return_value.all.return_value = upload_files or []
|
||||
|
||||
def query_side_effect(model: object) -> MagicMock:
|
||||
if model is vector_service_module.SegmentAttachmentBinding:
|
||||
return binding_query
|
||||
if model is vector_service_module.UploadFile:
|
||||
return upload_query
|
||||
return MagicMock(name=f"query({model})")
|
||||
|
||||
session.query.side_effect = query_side_effect
|
||||
db_mock = MagicMock(name="db")
|
||||
db_mock.session = session
|
||||
return db_mock
|
||||
@@ -165,22 +155,15 @@ def _mock_parent_child_queries(
|
||||
) -> MagicMock:
|
||||
session = MagicMock(name="session")
|
||||
|
||||
doc_query = MagicMock(name="doc_query")
|
||||
doc_query.filter_by.return_value = doc_query
|
||||
doc_query.first.return_value = dataset_document
|
||||
get_dispatch: dict[object, object | None] = {
|
||||
vector_service_module.DatasetDocument: dataset_document,
|
||||
vector_service_module.DatasetProcessRule: processing_rule,
|
||||
}
|
||||
|
||||
rule_query = MagicMock(name="rule_query")
|
||||
rule_query.where.return_value = rule_query
|
||||
rule_query.first.return_value = processing_rule
|
||||
def get_side_effect(model: object, pk: object) -> object | None:
|
||||
return get_dispatch.get(model)
|
||||
|
||||
def query_side_effect(model: object) -> MagicMock:
|
||||
if model is vector_service_module.DatasetDocument:
|
||||
return doc_query
|
||||
if model is vector_service_module.DatasetProcessRule:
|
||||
return rule_query
|
||||
return MagicMock(name=f"query({model})")
|
||||
|
||||
session.query.side_effect = query_side_effect
|
||||
session.get.side_effect = get_side_effect
|
||||
db_mock = MagicMock(name="db")
|
||||
db_mock.session = session
|
||||
return db_mock
|
||||
@@ -609,7 +592,7 @@ def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids(
|
||||
|
||||
vector_cls.assert_called_once_with(dataset=dataset)
|
||||
vector_instance.delete_by_ids.assert_called_once_with(["old-1", "old-2"])
|
||||
db_mock.session.query.assert_called_once_with(vector_service_module.SegmentAttachmentBinding)
|
||||
db_mock.session.execute.assert_called_once()
|
||||
db_mock.session.commit.assert_called_once()
|
||||
db_mock.session.add_all.assert_not_called()
|
||||
vector_instance.add_texts.assert_not_called()
|
||||
@@ -644,6 +627,8 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up
|
||||
|
||||
binding_ctor = MagicMock(side_effect=lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(vector_service_module, "SegmentAttachmentBinding", binding_ctor)
|
||||
monkeypatch.setattr(vector_service_module, "delete", MagicMock())
|
||||
monkeypatch.setattr(vector_service_module, "select", MagicMock())
|
||||
|
||||
logger_mock = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "logger", logger_mock)
|
||||
@@ -677,6 +662,8 @@ def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops
|
||||
monkeypatch.setattr(
|
||||
vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs)
|
||||
)
|
||||
monkeypatch.setattr(vector_service_module, "delete", MagicMock())
|
||||
monkeypatch.setattr(vector_service_module, "select", MagicMock())
|
||||
|
||||
VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset)
|
||||
|
||||
@@ -698,6 +685,8 @@ def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch:
|
||||
monkeypatch.setattr(
|
||||
vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs)
|
||||
)
|
||||
monkeypatch.setattr(vector_service_module, "delete", MagicMock())
|
||||
monkeypatch.setattr(vector_service_module, "select", MagicMock())
|
||||
|
||||
logger_mock = MagicMock()
|
||||
monkeypatch.setattr(vector_service_module, "logger", logger_mock)
|
||||
|
||||
@@ -268,7 +268,7 @@ class TestWorkflowService:
|
||||
Provides mock implementations of:
|
||||
- session.add(): Adding new records
|
||||
- session.commit(): Committing transactions
|
||||
- session.query(): Querying database
|
||||
- session.scalar(): Scalar queries
|
||||
- session.execute(): Executing SQL statements
|
||||
"""
|
||||
with patch("services.workflow_service.db") as mock_db:
|
||||
@@ -276,7 +276,7 @@ class TestWorkflowService:
|
||||
mock_db.session = mock_session
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.query = MagicMock()
|
||||
mock_session.scalar = MagicMock()
|
||||
mock_session.execute = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
@@ -338,10 +338,8 @@ class TestWorkflowService:
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock()
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock()
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
# Mock db.session.scalar() used by get_draft_workflow
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
|
||||
result = workflow_service.get_draft_workflow(app)
|
||||
|
||||
@@ -351,10 +349,8 @@ class TestWorkflowService:
|
||||
"""Test get_draft_workflow returns None when no draft exists."""
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock()
|
||||
|
||||
# Mock database query to return None
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
# Mock db.session.scalar() to return None
|
||||
mock_db_session.session.scalar.return_value = None
|
||||
|
||||
result = workflow_service.get_draft_workflow(app)
|
||||
|
||||
@@ -366,10 +362,8 @@ class TestWorkflowService:
|
||||
workflow_id = "workflow-123"
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1")
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
# Mock db.session.scalar() used by get_published_workflow_by_id
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
|
||||
result = workflow_service.get_draft_workflow(app, workflow_id=workflow_id)
|
||||
|
||||
@@ -384,10 +378,8 @@ class TestWorkflowService:
|
||||
workflow_id = "workflow-123"
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
# Mock db.session.scalar() used by get_published_workflow_by_id
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
|
||||
result = workflow_service.get_published_workflow_by_id(app, workflow_id)
|
||||
|
||||
@@ -406,10 +398,8 @@ class TestWorkflowService:
|
||||
workflow_id=workflow_id, version=Workflow.VERSION_DRAFT
|
||||
)
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
# Mock db.session.scalar() used by get_published_workflow_by_id
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
|
||||
with pytest.raises(IsDraftWorkflowError):
|
||||
workflow_service.get_published_workflow_by_id(app, workflow_id)
|
||||
@@ -419,10 +409,8 @@ class TestWorkflowService:
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock()
|
||||
workflow_id = "nonexistent-workflow"
|
||||
|
||||
# Mock database query to return None
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
# Mock db.session.scalar() to return None
|
||||
mock_db_session.session.scalar.return_value = None
|
||||
|
||||
result = workflow_service.get_published_workflow_by_id(app, workflow_id)
|
||||
|
||||
@@ -434,10 +422,8 @@ class TestWorkflowService:
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=workflow_id)
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
# Mock db.session.scalar() used by get_published_workflow
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
|
||||
result = workflow_service.get_published_workflow(app)
|
||||
|
||||
@@ -466,11 +452,9 @@ class TestWorkflowService:
|
||||
graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
|
||||
features = {"file_upload": {"enabled": False}}
|
||||
|
||||
# Mock get_draft_workflow to return None (no existing draft)
|
||||
# Mock db.session.scalar() to return None (no existing draft)
|
||||
# This simulates the first time a workflow is created for an app
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
mock_db_session.session.scalar.return_value = None
|
||||
|
||||
with (
|
||||
patch.object(workflow_service, "validate_features_structure"),
|
||||
@@ -504,12 +488,10 @@ class TestWorkflowService:
|
||||
features = {"file_upload": {"enabled": False}}
|
||||
unique_hash = "test-hash-123"
|
||||
|
||||
# Mock existing draft workflow
|
||||
# Mock existing draft workflow via db.session.scalar()
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash=unique_hash)
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
|
||||
with (
|
||||
patch.object(workflow_service, "validate_features_structure"),
|
||||
@@ -545,12 +527,10 @@ class TestWorkflowService:
|
||||
graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
|
||||
features = {}
|
||||
|
||||
# Mock existing draft workflow with different hash
|
||||
# Mock existing draft workflow with different hash via db.session.scalar()
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash="old-hash")
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
|
||||
with pytest.raises(WorkflowHashNotEqualError):
|
||||
workflow_service.sync_draft_workflow(
|
||||
|
||||
@@ -347,7 +347,7 @@ class TestGetBuiltinToolProviderCredentials:
|
||||
def test_returns_empty_when_no_providers(self, mock_db):
|
||||
mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None)
|
||||
mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = []
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google")
|
||||
|
||||
@@ -362,7 +362,7 @@ class TestGetBuiltinToolProviderCredentials:
|
||||
mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
provider = MagicMock(provider="google", is_default=False)
|
||||
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [provider]
|
||||
mock_db.session.scalars.return_value.all.return_value = [provider]
|
||||
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"key": "decrypted"}
|
||||
|
||||
Reference in New Issue
Block a user