refactor: select in tag_service (#34441)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo
2026-04-02 07:04:36 +02:00
committed by GitHub
parent cd406d2794
commit cb9ee5903a

View File

@@ -14,8 +14,8 @@ from models.model import App, Tag, TagBinding
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
query = (
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
stmt = (
select(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
)
@@ -23,9 +23,9 @@ class TagService:
from libs.helper import escape_like_pattern
escaped_keyword = escape_like_pattern(keyword)
query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all()
stmt = stmt.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
stmt = stmt.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = list(db.session.execute(stmt.order_by(Tag.created_at.desc())).all())
return results
@staticmethod
@@ -64,8 +64,8 @@ class TagService:
@staticmethod
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str):
tags = (
db.session.query(Tag)
tags = db.session.scalars(
select(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
TagBinding.target_id == target_id,
@@ -73,8 +73,7 @@ class TagService:
Tag.tenant_id == current_tenant_id,
Tag.type == tag_type,
)
.all()
)
).all()
return tags or []
@@ -97,7 +96,7 @@ class TagService:
def update_tags(args: dict, tag_id: str) -> Tag:
if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")):
raise ValueError("Tag name already exists")
tag = db.session.query(Tag).where(Tag.id == tag_id).first()
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
if not tag:
raise NotFound("Tag not found")
tag.name = args["name"]
@@ -106,12 +105,12 @@ class TagService:
@staticmethod
def get_tag_binding_count(tag_id: str) -> int:
count = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).count()
count = db.session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0
return count
@staticmethod
def delete_tag(tag_id: str):
tag = db.session.query(Tag).where(Tag.id == tag_id).first()
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
if not tag:
raise NotFound("Tag not found")
db.session.delete(tag)
@@ -128,10 +127,10 @@ class TagService:
TagService.check_target_exists(args["type"], args["target_id"])
# save tag binding
for tag_id in args["tag_ids"]:
tag_binding = (
db.session.query(TagBinding)
tag_binding = db.session.scalar(
select(TagBinding)
.where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
.first()
.limit(1)
)
if tag_binding:
continue
@@ -149,10 +148,10 @@ class TagService:
# check if target exists
TagService.check_target_exists(args["type"], args["target_id"])
# delete tag binding
tag_bindings = (
db.session.query(TagBinding)
.where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
.first()
tag_bindings = db.session.scalar(
select(TagBinding)
.where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == args["tag_id"])
.limit(1)
)
if tag_bindings:
db.session.delete(tag_bindings)
@@ -161,18 +160,16 @@ class TagService:
@staticmethod
def check_target_exists(type: str, target_id: str):
if type == "knowledge":
dataset = (
db.session.query(Dataset)
dataset = db.session.scalar(
select(Dataset)
.where(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
.first()
.limit(1)
)
if not dataset:
raise NotFound("Dataset not found")
elif type == "app":
app = (
db.session.query(App)
.where(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
.first()
app = db.session.scalar(
select(App).where(App.tenant_id == current_user.current_tenant_id, App.id == target_id).limit(1)
)
if not app:
raise NotFound("App not found")