From cb9ee5903a4f2511a29982ef3815ccd5aece7c8b Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Thu, 2 Apr 2026 07:04:36 +0200 Subject: [PATCH] refactor: select in tag_service (#34441) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/tag_service.py | 49 +++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 70bf7f16f24..194622bd862 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -14,8 +14,8 @@ from models.model import App, Tag, TagBinding class TagService: @staticmethod def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None): - query = ( - db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) + stmt = ( + select(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) .outerjoin(TagBinding, Tag.id == TagBinding.tag_id) .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) @@ -23,9 +23,9 @@ class TagService: from libs.helper import escape_like_pattern escaped_keyword = escape_like_pattern(keyword) - query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\"))) - query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) - results: list = query.order_by(Tag.created_at.desc()).all() + stmt = stmt.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\"))) + stmt = stmt.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) + results: list = list(db.session.execute(stmt.order_by(Tag.created_at.desc())).all()) return results @staticmethod @@ -64,8 +64,8 @@ class TagService: @staticmethod def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str): - tags = ( - db.session.query(Tag) + tags = db.session.scalars( + select(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) .where( TagBinding.target_id == target_id, @@ -73,8 +73,7 @@ class TagService: Tag.tenant_id == current_tenant_id, Tag.type == tag_type, ) - .all() - ) + ).all() return tags or [] @@ -97,7 +96,7 @@ class TagService: def update_tags(args: dict, tag_id: str) -> Tag: if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")): raise ValueError("Tag name already exists") - tag = db.session.query(Tag).where(Tag.id == tag_id).first() + tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1)) if not tag: raise NotFound("Tag not found") tag.name = args["name"] @@ -106,12 +105,12 @@ class TagService: @staticmethod def get_tag_binding_count(tag_id: str) -> int: - count = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).count() + count = db.session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0 return count @staticmethod def delete_tag(tag_id: str): - tag = db.session.query(Tag).where(Tag.id == tag_id).first() + tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1)) if not tag: raise NotFound("Tag not found") db.session.delete(tag) @@ -128,10 +127,10 @@ class TagService: TagService.check_target_exists(args["type"], args["target_id"]) # save tag binding for tag_id in args["tag_ids"]: - tag_binding = ( - db.session.query(TagBinding) + tag_binding = db.session.scalar( + select(TagBinding) .where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) - .first() + .limit(1) ) if tag_binding: continue @@ -149,10 +148,10 @@ class TagService: # check if target exists TagService.check_target_exists(args["type"], args["target_id"]) # delete tag binding - tag_bindings = ( - db.session.query(TagBinding) - .where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"])) - .first() + tag_bindings = db.session.scalar( + select(TagBinding) + .where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == args["tag_id"]) + .limit(1) ) if tag_bindings: db.session.delete(tag_bindings) @@ -161,18 +160,16 @@ class TagService: @staticmethod def check_target_exists(type: str, target_id: str): if type == "knowledge": - dataset = ( - db.session.query(Dataset) + dataset = db.session.scalar( + select(Dataset) .where(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) - .first() + .limit(1) ) if not dataset: raise NotFound("Dataset not found") elif type == "app": - app = ( - db.session.query(App) - .where(App.tenant_id == current_user.current_tenant_id, App.id == target_id) - .first() + app = db.session.scalar( + select(App).where(App.tenant_id == current_user.current_tenant_id, App.id == target_id).limit(1) ) if not app: raise NotFound("App not found")