refactor: select in 13 small service files (#34371)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo
2026-04-01 06:00:05 +02:00
committed by GitHub
parent 42d7623cc6
commit beda78e911
16 changed files with 72 additions and 102 deletions

View File

@@ -132,7 +132,7 @@ class AudioService:
uuid.UUID(message_id)
except ValueError:
return None
message = db.session.query(Message).where(Message.id == message_id).first()
message = db.session.get(Message, message_id)
if message is None:
return None
if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}:

View File

@@ -6,6 +6,7 @@ from typing import Literal
import httpx
from pydantic import TypeAdapter
from sqlalchemy import select
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from typing_extensions import TypedDict
from werkzeug.exceptions import InternalServerError
@@ -158,10 +159,10 @@ class BillingService:
def is_tenant_owner_or_admin(current_user: Account):
tenant_id = current_user.current_tenant_id
join: TenantAccountJoin | None = (
db.session.query(TenantAccountJoin)
join: TenantAccountJoin | None = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
.first()
.limit(1)
)
if not join:

View File

@@ -137,11 +137,11 @@ class ConversationService:
@classmethod
def auto_generate_name(cls, app_model: App, conversation: Conversation):
# get conversation first message
message = (
db.session.query(Message)
message = db.session.scalar(
select(Message)
.where(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
.order_by(Message.created_at.asc())
.first()
.limit(1)
)
if not message:
@@ -160,8 +160,8 @@ class ConversationService:
@classmethod
def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
conversation = (
db.session.query(Conversation)
conversation = db.session.scalar(
select(Conversation)
.where(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
@@ -170,7 +170,7 @@ class ConversationService:
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
Conversation.is_deleted == False,
)
.first()
.limit(1)
)
if not conversation:

View File

@@ -1,6 +1,6 @@
import logging
from sqlalchemy import update
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config
@@ -29,13 +29,13 @@ class CreditPoolService:
@classmethod
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None:
"""get tenant credit pool"""
return (
db.session.query(TenantCreditPool)
.filter_by(
tenant_id=tenant_id,
pool_type=pool_type,
return db.session.scalar(
select(TenantCreditPool)
.where(
TenantCreditPool.tenant_id == tenant_id,
TenantCreditPool.pool_type == pool_type,
)
.first()
.limit(1)
)
@classmethod

View File

@@ -4,6 +4,7 @@ import uuid
from datetime import UTC, datetime
from redis import RedisError
from sqlalchemy import select
from configs import dify_config
from extensions.ext_database import db
@@ -104,7 +105,9 @@ def sync_account_deletion(account_id: str, *, source: str) -> bool:
return True
# Fetch all workspaces the account belongs to
workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all()
workspace_joins = db.session.scalars(
select(TenantAccountJoin).where(TenantAccountJoin.account_id == account_id)
).all()
# Queue sync task for each workspace
success = True

View File

@@ -110,7 +110,7 @@ class PipelineGenerateService:
Update document status to waiting
:param document_id: document id
"""
document = db.session.query(Document).where(Document.id == document_id).first()
document = db.session.get(Document, document_id)
if document:
document.indexing_status = IndexingStatus.WAITING
db.session.add(document)

View File

@@ -1,4 +1,5 @@
import yaml
from sqlalchemy import select
from extensions.ext_database import db
from libs.login import current_account_with_tenant
@@ -32,12 +33,11 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param language: language
:return:
"""
pipeline_customized_templates = (
db.session.query(PipelineCustomizedTemplate)
pipeline_customized_templates = db.session.scalars(
select(PipelineCustomizedTemplate)
.where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
.order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
.all()
)
).all()
recommended_pipelines_results = []
for pipeline_customized_template in pipeline_customized_templates:
recommended_pipeline_result = {
@@ -59,9 +59,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param template_id: Template ID
:return:
"""
pipeline_template = (
db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)
pipeline_template = db.session.get(PipelineCustomizedTemplate, template_id)
if not pipeline_template:
return None

View File

@@ -1,4 +1,5 @@
import yaml
from sqlalchemy import select
from extensions.ext_database import db
from models.dataset import PipelineBuiltInTemplate
@@ -30,8 +31,10 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:return:
"""
pipeline_built_in_templates: list[PipelineBuiltInTemplate] = (
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all()
pipeline_built_in_templates = list(
db.session.scalars(
select(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language)
).all()
)
recommended_pipelines_results = []
@@ -58,9 +61,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:return:
"""
# is in public recommended list
pipeline_template = (
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first()
)
pipeline_template = db.session.get(PipelineBuiltInTemplate, template_id)
if not pipeline_template:
return None

View File

@@ -77,17 +77,15 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
:return:
"""
# is in public recommended list
recommended_app = (
db.session.query(RecommendedApp)
.where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
.first()
recommended_app = db.session.scalar(
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id).limit(1)
)
if not recommended_app:
return None
# get app detail
app_model = db.session.query(App).where(App.id == app_id).first()
app_model = db.session.get(App, app_id)
if not app_model or not app_model.is_public:
return None

View File

@@ -64,15 +64,15 @@ class WebConversationService:
def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
if not user:
return
pinned_conversation = (
db.session.query(PinnedConversation)
pinned_conversation = db.session.scalar(
select(PinnedConversation)
.where(
PinnedConversation.app_id == app_model.id,
PinnedConversation.conversation_id == conversation_id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
PinnedConversation.created_by == user.id,
)
.first()
.limit(1)
)
if pinned_conversation:
@@ -96,15 +96,15 @@ class WebConversationService:
def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
if not user:
return
pinned_conversation = (
db.session.query(PinnedConversation)
pinned_conversation = db.session.scalar(
select(PinnedConversation)
.where(
PinnedConversation.app_id == app_model.id,
PinnedConversation.conversation_id == conversation_id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
PinnedConversation.created_by == user.id,
)
.first()
.limit(1)
)
if not pinned_conversation:

View File

@@ -3,6 +3,7 @@ import secrets
from datetime import UTC, datetime, timedelta
from typing import Any
from sqlalchemy import select
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
@@ -92,10 +93,10 @@ class WebAppAuthService:
@classmethod
def create_end_user(cls, app_code, email) -> EndUser:
site = db.session.query(Site).where(Site.code == app_code).first()
site = db.session.scalar(select(Site).where(Site.code == app_code).limit(1))
if not site:
raise NotFound("Site not found.")
app_model = db.session.query(App).where(App.id == site.app_id).first()
app_model = db.session.get(App, site.app_id)
if not app_model:
raise NotFound("App not found.")
end_user = EndUser(

View File

@@ -6,6 +6,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.utils.encoders import jsonable_encoder
from graphon.nodes import BuiltinNodeTypes
from graphon.variables.input_entities import VariableEntity
from sqlalchemy import select
from typing_extensions import TypedDict
from core.app.app_config.entities import (
@@ -648,10 +649,10 @@ class WorkflowConverter:
:param api_based_extension_id: api based extension id
:return:
"""
api_based_extension = (
db.session.query(APIBasedExtension)
api_based_extension = db.session.scalar(
select(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
.limit(1)
)
if not api_based_extension:

View File

@@ -1,4 +1,5 @@
from flask_login import current_user
from sqlalchemy import select
from configs import dify_config
from enums.cloud_plan import CloudPlan
@@ -24,10 +25,10 @@ class WorkspaceService:
}
# Get role of user
tenant_account_join = (
db.session.query(TenantAccountJoin)
tenant_account_join = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
.first()
.limit(1)
)
assert tenant_account_join is not None, "TenantAccountJoin not found"
tenant_info["role"] = tenant_account_join.role

View File

@@ -421,11 +421,8 @@ class TestAudioServiceTTS:
answer="Message answer text",
)
# Mock database query
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = message
# Mock database lookup
mock_db_session.get.return_value = message
# Mock ModelManager
mock_model_manager = mock_model_manager_class.return_value
@@ -568,11 +565,8 @@ class TestAudioServiceTTS:
# Arrange
app = factory.create_app_mock()
# Mock database query returning None
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None
# Mock database lookup returning None
mock_db_session.get.return_value = None
# Act
result = AudioService.transcript_tts(
@@ -594,11 +588,8 @@ class TestAudioServiceTTS:
status=MessageStatus.NORMAL,
)
# Mock database query
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = message
# Mock database lookup
mock_db_session.get.return_value = message
# Act
result = AudioService.transcript_tts(

View File

@@ -865,16 +865,11 @@ class TestBillingServiceAccountManagement:
mock_join = MagicMock(spec=TenantAccountJoin)
mock_join.role = TenantAccountRole.OWNER
mock_query = MagicMock()
mock_query.where.return_value.first.return_value = mock_join
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = mock_join
# Act - should not raise exception
BillingService.is_tenant_owner_or_admin(current_user)
# Assert
mock_db_session.query.assert_called_once()
def test_is_tenant_owner_or_admin_admin(self, mock_db_session):
"""Test tenant owner/admin check for admin role."""
# Arrange
@@ -885,16 +880,11 @@ class TestBillingServiceAccountManagement:
mock_join = MagicMock(spec=TenantAccountJoin)
mock_join.role = TenantAccountRole.ADMIN
mock_query = MagicMock()
mock_query.where.return_value.first.return_value = mock_join
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = mock_join
# Act - should not raise exception
BillingService.is_tenant_owner_or_admin(current_user)
# Assert
mock_db_session.query.assert_called_once()
def test_is_tenant_owner_or_admin_normal_user_raises_error(self, mock_db_session):
"""Test tenant owner/admin check raises error for normal user."""
# Arrange
@@ -905,9 +895,7 @@ class TestBillingServiceAccountManagement:
mock_join = MagicMock(spec=TenantAccountJoin)
mock_join.role = TenantAccountRole.NORMAL
mock_query = MagicMock()
mock_query.where.return_value.first.return_value = mock_join
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = mock_join
# Act & Assert
with pytest.raises(ValueError) as exc_info:
@@ -921,9 +909,7 @@ class TestBillingServiceAccountManagement:
current_user.id = "account-123"
current_user.current_tenant_id = "tenant-456"
mock_query = MagicMock()
mock_query.where.return_value.first.return_value = None
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError) as exc_info:
@@ -1135,9 +1121,7 @@ class TestBillingServiceEdgeCases:
mock_join.role = TenantAccountRole.EDITOR # Editor is not privileged
with patch("services.billing_service.db.session") as mock_session:
mock_query = MagicMock()
mock_query.where.return_value.first.return_value = mock_join
mock_session.query.return_value = mock_query
mock_session.scalar.return_value = mock_join
# Act & Assert
with pytest.raises(ValueError) as exc_info:
@@ -1155,9 +1139,7 @@ class TestBillingServiceEdgeCases:
mock_join.role = TenantAccountRole.DATASET_OPERATOR # Dataset operator is not privileged
with patch("services.billing_service.db.session") as mock_session:
mock_query = MagicMock()
mock_query.where.return_value.first.return_value = mock_join
mock_session.query.return_value = mock_query
mock_session.scalar.return_value = mock_join
# Act & Assert
with pytest.raises(ValueError) as exc_info:

View File

@@ -355,15 +355,13 @@ class TestConversationServiceGetConversation:
from_account_id=user.id, from_source=ConversationFromSource.CONSOLE
)
mock_query = mock_db_session.query.return_value
mock_query.where.return_value.first.return_value = conversation
mock_db_session.scalar.return_value = conversation
# Act
result = ConversationService.get_conversation(app_model, "conv-123", user)
# Assert
assert result == conversation
mock_db_session.query.assert_called_once_with(Conversation)
@patch("services.conversation_service.db.session")
def test_get_conversation_success_with_end_user(self, mock_db_session):
@@ -379,8 +377,7 @@ class TestConversationServiceGetConversation:
from_end_user_id=user.id, from_source=ConversationFromSource.API
)
mock_query = mock_db_session.query.return_value
mock_query.where.return_value.first.return_value = conversation
mock_db_session.scalar.return_value = conversation
# Act
result = ConversationService.get_conversation(app_model, "conv-123", user)
@@ -399,8 +396,7 @@ class TestConversationServiceGetConversation:
app_model = ConversationServiceTestDataFactory.create_app_mock()
user = ConversationServiceTestDataFactory.create_account_mock()
mock_query = mock_db_session.query.return_value
mock_query.where.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
# Act & Assert
with pytest.raises(ConversationNotExistsError):
@@ -489,8 +485,7 @@ class TestConversationServiceAutoGenerateName:
)
# Mock database query to return message
mock_query = mock_db_session.query.return_value
mock_query.where.return_value.order_by.return_value.first.return_value = message
mock_db_session.scalar.return_value = message
# Mock LLM generator
mock_llm_generator.generate_conversation_name.return_value = "Generated Name"
@@ -518,8 +513,7 @@ class TestConversationServiceAutoGenerateName:
conversation = ConversationServiceTestDataFactory.create_conversation_mock()
# Mock database query to return None
mock_query = mock_db_session.query.return_value
mock_query.where.return_value.order_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
# Act & Assert
with pytest.raises(MessageNotExistsError):
@@ -541,8 +535,7 @@ class TestConversationServiceAutoGenerateName:
)
# Mock database query to return message
mock_query = mock_db_session.query.return_value
mock_query.where.return_value.order_by.return_value.first.return_value = message
mock_db_session.scalar.return_value = message
# Mock LLM generator to raise exception
mock_llm_generator.generate_conversation_name.side_effect = Exception("LLM Error")