refactor: select in 10 service files (#34373)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
Renzo
2026-04-01 10:03:49 +02:00
committed by GitHub
parent b23ea0397a
commit 31f7752ba9
14 changed files with 147 additions and 180 deletions

View File

@@ -2,6 +2,7 @@ import threading
from typing import Any
import pytz
from sqlalchemy import select
import contexts
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
@@ -23,25 +24,25 @@ class AgentService:
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
conversation: Conversation | None = (
db.session.query(Conversation)
conversation: Conversation | None = db.session.scalar(
select(Conversation)
.where(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
)
.first()
.limit(1)
)
if not conversation:
raise ValueError(f"Conversation not found: {conversation_id}")
message: Message | None = (
db.session.query(Message)
message: Message | None = db.session.scalar(
select(Message)
.where(
Message.id == message_id,
Message.conversation_id == conversation_id,
)
.first()
.limit(1)
)
if not message:
@@ -51,16 +52,11 @@ class AgentService:
if conversation.from_end_user_id:
# only select name field
executor = (
db.session.query(EndUser, EndUser.name).where(EndUser.id == conversation.from_end_user_id).first()
)
executor_name = db.session.scalar(select(EndUser.name).where(EndUser.id == conversation.from_end_user_id))
else:
executor = db.session.query(Account, Account.name).where(Account.id == conversation.from_account_id).first()
executor_name = db.session.scalar(select(Account.name).where(Account.id == conversation.from_account_id))
if executor:
executor = executor.name
else:
executor = "Unknown"
executor = executor_name or "Unknown"
assert isinstance(current_user, Account)
assert current_user.timezone is not None
timezone = pytz.timezone(current_user.timezone)

View File

@@ -1,3 +1,5 @@
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token, encrypt_token
from extensions.ext_database import db
@@ -7,11 +9,12 @@ from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
class APIBasedExtensionService:
@staticmethod
def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
extension_list = (
db.session.query(APIBasedExtension)
.filter_by(tenant_id=tenant_id)
.order_by(APIBasedExtension.created_at.desc())
.all()
extension_list = list(
db.session.scalars(
select(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id)
.order_by(APIBasedExtension.created_at.desc())
).all()
)
for extension in extension_list:
@@ -36,11 +39,10 @@ class APIBasedExtensionService:
@staticmethod
def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
extension = (
db.session.query(APIBasedExtension)
.filter_by(tenant_id=tenant_id)
.filter_by(id=api_based_extension_id)
.first()
extension = db.session.scalar(
select(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.limit(1)
)
if not extension:
@@ -58,23 +60,27 @@ class APIBasedExtensionService:
if not extension_data.id:
# case one: check new data, name must be unique
is_name_existed = (
db.session.query(APIBasedExtension)
.filter_by(tenant_id=extension_data.tenant_id)
.filter_by(name=extension_data.name)
.first()
is_name_existed = db.session.scalar(
select(APIBasedExtension)
.where(
APIBasedExtension.tenant_id == extension_data.tenant_id,
APIBasedExtension.name == extension_data.name,
)
.limit(1)
)
if is_name_existed:
raise ValueError("name must be unique, it is already existed")
else:
# case two: check existing data, name must be unique
is_name_existed = (
db.session.query(APIBasedExtension)
.filter_by(tenant_id=extension_data.tenant_id)
.filter_by(name=extension_data.name)
.where(APIBasedExtension.id != extension_data.id)
.first()
is_name_existed = db.session.scalar(
select(APIBasedExtension)
.where(
APIBasedExtension.tenant_id == extension_data.tenant_id,
APIBasedExtension.name == extension_data.name,
APIBasedExtension.id != extension_data.id,
)
.limit(1)
)
if is_name_existed:

View File

@@ -6,6 +6,7 @@ import sqlalchemy as sa
from flask_sqlalchemy.pagination import Pagination
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import select
from configs import dify_config
from constants.model_template import default_app_templates
@@ -433,9 +434,7 @@ class AppService:
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
elif provider_type == "api":
try:
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first()
)
provider: ApiToolProvider | None = db.session.get(ApiToolProvider, provider_id)
if provider is None:
raise ValueError(f"provider not found for tool {tool_name}")
meta["tool_icons"][tool_name] = json.loads(provider.icon)
@@ -451,7 +450,7 @@ class AppService:
:param app_id: app id
:return: app code
"""
site = db.session.query(Site).where(Site.app_id == app_id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_id).limit(1))
if not site:
raise ValueError(f"App with id {app_id} not found")
return str(site.code)
@@ -463,7 +462,7 @@ class AppService:
:param app_code: app code
:return: app id
"""
site = db.session.query(Site).where(Site.code == app_code).first()
site = db.session.scalar(select(Site).where(Site.code == app_code).limit(1))
if not site:
raise ValueError(f"App with code {app_code} not found")
return str(site.app_id)

View File

@@ -4,7 +4,7 @@ import json
from datetime import datetime
from flask import Response
from sqlalchemy import or_
from sqlalchemy import or_, select
from extensions.ext_database import db
from models.enums import FeedbackRating
@@ -41,8 +41,8 @@ class FeedbackService:
raise ValueError(f"Unsupported format: {format_type}")
# Build base query
query = (
db.session.query(MessageFeedback, Message, Conversation, App, Account)
stmt = (
select(MessageFeedback, Message, Conversation, App, Account)
.join(Message, MessageFeedback.message_id == Message.id)
.join(Conversation, MessageFeedback.conversation_id == Conversation.id)
.join(App, MessageFeedback.app_id == App.id)
@@ -52,36 +52,36 @@ class FeedbackService:
# Apply filters
if from_source:
query = query.filter(MessageFeedback.from_source == from_source)
stmt = stmt.where(MessageFeedback.from_source == from_source)
if rating:
query = query.filter(MessageFeedback.rating == rating)
stmt = stmt.where(MessageFeedback.rating == rating)
if has_comment is not None:
if has_comment:
query = query.filter(MessageFeedback.content.isnot(None), MessageFeedback.content != "")
stmt = stmt.where(MessageFeedback.content.isnot(None), MessageFeedback.content != "")
else:
query = query.filter(or_(MessageFeedback.content.is_(None), MessageFeedback.content == ""))
stmt = stmt.where(or_(MessageFeedback.content.is_(None), MessageFeedback.content == ""))
if start_date:
try:
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
query = query.filter(MessageFeedback.created_at >= start_dt)
stmt = stmt.where(MessageFeedback.created_at >= start_dt)
except ValueError:
raise ValueError(f"Invalid start_date format: {start_date}. Use YYYY-MM-DD")
if end_date:
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
query = query.filter(MessageFeedback.created_at <= end_dt)
stmt = stmt.where(MessageFeedback.created_at <= end_dt)
except ValueError:
raise ValueError(f"Invalid end_date format: {end_date}. Use YYYY-MM-DD")
# Order by creation date (newest first)
query = query.order_by(MessageFeedback.created_at.desc())
stmt = stmt.order_by(MessageFeedback.created_at.desc())
# Execute query
results = query.all()
results = db.session.execute(stmt).all()
# Prepare data for export
export_data = []

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
import yaml
from flask_login import current_user
from sqlalchemy import select
from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
@@ -26,7 +27,7 @@ logger = logging.getLogger(__name__)
class RagPipelineTransformService:
def transform_dataset(self, dataset_id: str):
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = db.session.get(Dataset, dataset_id)
if not dataset:
raise ValueError("Dataset not found")
if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE:
@@ -306,7 +307,7 @@ class RagPipelineTransformService:
jina_node_id = "1752491761974"
firecrawl_node_id = "1752565402678"
documents = db.session.query(Document).where(Document.dataset_id == dataset.id).all()
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset.id)).all()
for document in documents:
data_source_info_dict = document.data_source_info_dict
@@ -316,7 +317,7 @@ class RagPipelineTransformService:
document.data_source_type = DataSourceType.LOCAL_FILE
file_id = data_source_info_dict.get("upload_file_id")
if file_id:
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
file = db.session.get(UploadFile, file_id)
if file:
data_source_info = json.dumps(
{

View File

@@ -1,3 +1,5 @@
from sqlalchemy import select
from configs import dify_config
from extensions.ext_database import db
from models.model import AccountTrialAppRecord, TrialApp
@@ -27,7 +29,7 @@ class RecommendedAppService:
apps = result["recommended_apps"]
for app in apps:
app_id = app["app_id"]
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))
if trial_app_model:
app["can_trial"] = True
else:
@@ -46,7 +48,7 @@ class RecommendedAppService:
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
if FeatureService.get_system_features().enable_trial_app:
app_id = result["id"]
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))
if trial_app_model:
result["can_trial"] = True
else:
@@ -60,10 +62,10 @@ class RecommendedAppService:
:param app_id: app id
:return:
"""
account_trial_app_record = (
db.session.query(AccountTrialAppRecord)
account_trial_app_record = db.session.scalar(
select(AccountTrialAppRecord)
.where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
.first()
.limit(1)
)
if account_trial_app_record:
account_trial_app_record.count += 1

View File

@@ -1,5 +1,7 @@
from typing import Union
from sqlalchemy import select
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
@@ -16,16 +18,15 @@ class SavedMessageService:
) -> InfiniteScrollPagination:
if not user:
raise ValueError("User is required")
saved_messages = (
db.session.query(SavedMessage)
saved_messages = db.session.scalars(
select(SavedMessage)
.where(
SavedMessage.app_id == app_model.id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
SavedMessage.created_by == user.id,
)
.order_by(SavedMessage.created_at.desc())
.all()
)
).all()
message_ids = [sm.message_id for sm in saved_messages]
return MessageService.pagination_by_last_id(
@@ -36,15 +37,15 @@ class SavedMessageService:
def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
if not user:
return
saved_message = (
db.session.query(SavedMessage)
saved_message = db.session.scalar(
select(SavedMessage)
.where(
SavedMessage.app_id == app_model.id,
SavedMessage.message_id == message_id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
SavedMessage.created_by == user.id,
)
.first()
.limit(1)
)
if saved_message:
@@ -66,15 +67,15 @@ class SavedMessageService:
def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
if not user:
return
saved_message = (
db.session.query(SavedMessage)
saved_message = db.session.scalar(
select(SavedMessage)
.where(
SavedMessage.app_id == app_model.id,
SavedMessage.message_id == message_id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
SavedMessage.created_by == user.id,
)
.first()
.limit(1)
)
if not saved_message:

View File

@@ -332,12 +332,11 @@ class BuiltinToolManageService:
get builtin tool provider credentials
"""
with db.session.no_autoflush:
providers = (
db.session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider_name)
providers = db.session.scalars(
select(BuiltinToolProvider)
.where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.all()
)
).all()
if len(providers) == 0:
return []

View File

@@ -1,6 +1,7 @@
import logging
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import delete, select
from core.model_manager import ModelInstance, ModelManager
from core.rag.datasource.keyword.keyword_factory import Keyword
@@ -29,7 +30,7 @@ class VectorService:
for segment in segments:
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
dataset_document = db.session.get(DatasetDocument, segment.document_id)
if not dataset_document:
logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
@@ -38,11 +39,7 @@ class VectorService:
)
continue
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
processing_rule = db.session.get(DatasetProcessRule, dataset_document.dataset_process_rule_id)
if not processing_rule:
raise ValueError("No processing rule found.")
# get embedding model instance
@@ -271,8 +268,8 @@ class VectorService:
vector.delete_by_ids(old_attachment_ids)
# Delete existing segment attachment bindings in one operation
db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete(
synchronize_session=False
db.session.execute(
delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id)
)
if not attachment_ids:
@@ -280,7 +277,7 @@ class VectorService:
return
# Bulk fetch upload files - only fetch needed fields
upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
upload_file_list = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all()
if not upload_file_list:
db.session.commit()

View File

@@ -138,14 +138,14 @@ class WorkflowService:
if workflow_id:
return self.get_published_workflow_by_id(app_model, workflow_id)
# fetch draft workflow by app_model
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
.first()
.limit(1)
)
# return draft workflow
@@ -155,14 +155,14 @@ class WorkflowService:
"""
fetch published workflow by workflow_id
"""
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id,
)
.first()
.limit(1)
)
if not workflow:
return None
@@ -182,14 +182,14 @@ class WorkflowService:
return None
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == app_model.workflow_id,
)
.first()
.limit(1)
)
return workflow
@@ -544,14 +544,14 @@ class WorkflowService:
# Use the same fallback logic as runtime: get the first available credential
# ordered by is_default DESC, created_at ASC (same as tool_manager.py)
default_provider = (
db.session.query(BuiltinToolProvider)
default_provider = db.session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
.limit(1)
)
if not default_provider:

View File

@@ -99,7 +99,7 @@ class TestFeedbackService:
)
]
mock_db_session.query.return_value = mock_query
mock_db_session.execute.return_value = mock_query
# Test CSV export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@@ -138,7 +138,7 @@ class TestFeedbackService:
)
]
mock_db_session.query.return_value = mock_query
mock_db_session.execute.return_value = mock_query
# Test JSON export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@@ -175,7 +175,7 @@ class TestFeedbackService:
)
]
mock_db_session.query.return_value = mock_query
mock_db_session.execute.return_value = mock_query
# Test with filters
result = FeedbackService.export_feedbacks(
@@ -188,11 +188,8 @@ class TestFeedbackService:
format_type="csv",
)
# Verify filters were applied
assert mock_query.filter.called
filter_calls = mock_query.filter.call_args_list
# At least three filter invocations are expected (source, rating, comment)
assert len(filter_calls) >= 3
# Verify query was executed (filters are baked into the select statement)
assert mock_db_session.execute.called
def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
"""Test exporting feedback when no data exists."""
@@ -206,7 +203,7 @@ class TestFeedbackService:
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
mock_db_session.query.return_value = mock_query
mock_db_session.execute.return_value = mock_query
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@@ -271,7 +268,7 @@ class TestFeedbackService:
)
]
mock_db_session.query.return_value = mock_query
mock_db_session.execute.return_value = mock_query
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@@ -329,7 +326,7 @@ class TestFeedbackService:
)
]
mock_db_session.query.return_value = mock_query
mock_db_session.execute.return_value = mock_query
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@@ -367,7 +364,7 @@ class TestFeedbackService:
),
]
mock_db_session.query.return_value = mock_query
mock_db_session.execute.return_value = mock_query
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")

View File

@@ -77,22 +77,12 @@ def _make_segment(
def _mock_db_session_for_update_multimodel(*, upload_files: list[_UploadFileStub] | None) -> MagicMock:
session = MagicMock(name="session")
binding_query = MagicMock(name="binding_query")
binding_query.where.return_value = binding_query
binding_query.delete.return_value = 1
# db.session.execute() is used for delete(SegmentAttachmentBinding).where(...)
session.execute = MagicMock(name="execute")
upload_query = MagicMock(name="upload_query")
upload_query.where.return_value = upload_query
upload_query.all.return_value = upload_files or []
# db.session.scalars(select(UploadFile).where(...)).all() returns upload files
session.scalars.return_value.all.return_value = upload_files or []
def query_side_effect(model: object) -> MagicMock:
if model is vector_service_module.SegmentAttachmentBinding:
return binding_query
if model is vector_service_module.UploadFile:
return upload_query
return MagicMock(name=f"query({model})")
session.query.side_effect = query_side_effect
db_mock = MagicMock(name="db")
db_mock.session = session
return db_mock
@@ -165,22 +155,15 @@ def _mock_parent_child_queries(
) -> MagicMock:
session = MagicMock(name="session")
doc_query = MagicMock(name="doc_query")
doc_query.filter_by.return_value = doc_query
doc_query.first.return_value = dataset_document
get_dispatch: dict[object, object | None] = {
vector_service_module.DatasetDocument: dataset_document,
vector_service_module.DatasetProcessRule: processing_rule,
}
rule_query = MagicMock(name="rule_query")
rule_query.where.return_value = rule_query
rule_query.first.return_value = processing_rule
def get_side_effect(model: object, pk: object) -> object | None:
return get_dispatch.get(model)
def query_side_effect(model: object) -> MagicMock:
if model is vector_service_module.DatasetDocument:
return doc_query
if model is vector_service_module.DatasetProcessRule:
return rule_query
return MagicMock(name=f"query({model})")
session.query.side_effect = query_side_effect
session.get.side_effect = get_side_effect
db_mock = MagicMock(name="db")
db_mock.session = session
return db_mock
@@ -609,7 +592,7 @@ def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids(
vector_cls.assert_called_once_with(dataset=dataset)
vector_instance.delete_by_ids.assert_called_once_with(["old-1", "old-2"])
db_mock.session.query.assert_called_once_with(vector_service_module.SegmentAttachmentBinding)
db_mock.session.execute.assert_called_once()
db_mock.session.commit.assert_called_once()
db_mock.session.add_all.assert_not_called()
vector_instance.add_texts.assert_not_called()
@@ -644,6 +627,8 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up
binding_ctor = MagicMock(side_effect=lambda **kwargs: kwargs)
monkeypatch.setattr(vector_service_module, "SegmentAttachmentBinding", binding_ctor)
monkeypatch.setattr(vector_service_module, "delete", MagicMock())
monkeypatch.setattr(vector_service_module, "select", MagicMock())
logger_mock = MagicMock()
monkeypatch.setattr(vector_service_module, "logger", logger_mock)
@@ -677,6 +662,8 @@ def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops
monkeypatch.setattr(
vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs)
)
monkeypatch.setattr(vector_service_module, "delete", MagicMock())
monkeypatch.setattr(vector_service_module, "select", MagicMock())
VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset)
@@ -698,6 +685,8 @@ def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch:
monkeypatch.setattr(
vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs)
)
monkeypatch.setattr(vector_service_module, "delete", MagicMock())
monkeypatch.setattr(vector_service_module, "select", MagicMock())
logger_mock = MagicMock()
monkeypatch.setattr(vector_service_module, "logger", logger_mock)

View File

@@ -268,7 +268,7 @@ class TestWorkflowService:
Provides mock implementations of:
- session.add(): Adding new records
- session.commit(): Committing transactions
- session.query(): Querying database
- session.scalar(): Scalar queries
- session.execute(): Executing SQL statements
"""
with patch("services.workflow_service.db") as mock_db:
@@ -276,7 +276,7 @@ class TestWorkflowService:
mock_db.session = mock_session
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.query = MagicMock()
mock_session.scalar = MagicMock()
mock_session.execute = MagicMock()
yield mock_db
@@ -338,10 +338,8 @@ class TestWorkflowService:
app = TestWorkflowAssociatedDataFactory.create_app_mock()
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock()
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
# Mock db.session.scalar() used by get_draft_workflow
mock_db_session.session.scalar.return_value = mock_workflow
result = workflow_service.get_draft_workflow(app)
@@ -351,10 +349,8 @@ class TestWorkflowService:
"""Test get_draft_workflow returns None when no draft exists."""
app = TestWorkflowAssociatedDataFactory.create_app_mock()
# Mock database query to return None
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = None
# Mock db.session.scalar() to return None
mock_db_session.session.scalar.return_value = None
result = workflow_service.get_draft_workflow(app)
@@ -366,10 +362,8 @@ class TestWorkflowService:
workflow_id = "workflow-123"
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1")
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
# Mock db.session.scalar() used by get_published_workflow_by_id
mock_db_session.session.scalar.return_value = mock_workflow
result = workflow_service.get_draft_workflow(app, workflow_id=workflow_id)
@@ -384,10 +378,8 @@ class TestWorkflowService:
workflow_id = "workflow-123"
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
# Mock db.session.scalar() used by get_published_workflow_by_id
mock_db_session.session.scalar.return_value = mock_workflow
result = workflow_service.get_published_workflow_by_id(app, workflow_id)
@@ -406,10 +398,8 @@ class TestWorkflowService:
workflow_id=workflow_id, version=Workflow.VERSION_DRAFT
)
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
# Mock db.session.scalar() used by get_published_workflow_by_id
mock_db_session.session.scalar.return_value = mock_workflow
with pytest.raises(IsDraftWorkflowError):
workflow_service.get_published_workflow_by_id(app, workflow_id)
@@ -419,10 +409,8 @@ class TestWorkflowService:
app = TestWorkflowAssociatedDataFactory.create_app_mock()
workflow_id = "nonexistent-workflow"
# Mock database query to return None
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = None
# Mock db.session.scalar() to return None
mock_db_session.session.scalar.return_value = None
result = workflow_service.get_published_workflow_by_id(app, workflow_id)
@@ -434,10 +422,8 @@ class TestWorkflowService:
app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=workflow_id)
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
# Mock db.session.scalar() used by get_published_workflow
mock_db_session.session.scalar.return_value = mock_workflow
result = workflow_service.get_published_workflow(app)
@@ -466,11 +452,9 @@ class TestWorkflowService:
graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
features = {"file_upload": {"enabled": False}}
# Mock get_draft_workflow to return None (no existing draft)
# Mock db.session.scalar() to return None (no existing draft)
# This simulates the first time a workflow is created for an app
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = None
mock_db_session.session.scalar.return_value = None
with (
patch.object(workflow_service, "validate_features_structure"),
@@ -504,12 +488,10 @@ class TestWorkflowService:
features = {"file_upload": {"enabled": False}}
unique_hash = "test-hash-123"
# Mock existing draft workflow
# Mock existing draft workflow via db.session.scalar()
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash=unique_hash)
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
mock_db_session.session.scalar.return_value = mock_workflow
with (
patch.object(workflow_service, "validate_features_structure"),
@@ -545,12 +527,10 @@ class TestWorkflowService:
graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
features = {}
# Mock existing draft workflow with different hash
# Mock existing draft workflow with different hash via db.session.scalar()
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash="old-hash")
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
mock_db_session.session.scalar.return_value = mock_workflow
with pytest.raises(WorkflowHashNotEqualError):
workflow_service.sync_draft_workflow(

View File

@@ -347,7 +347,7 @@ class TestGetBuiltinToolProviderCredentials:
def test_returns_empty_when_no_providers(self, mock_db):
mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None)
mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False)
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = []
mock_db.session.scalars.return_value.all.return_value = []
result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google")
@@ -362,7 +362,7 @@ class TestGetBuiltinToolProviderCredentials:
mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False)
provider = MagicMock(provider="google", is_default=False)
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [provider]
mock_db.session.scalars.return_value.all.return_value = [provider]
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"key": "decrypted"}