mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 02:19:20 +08:00
refactor: select in tag_service (#34441)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -14,8 +14,8 @@ from models.model import App, Tag, TagBinding
|
||||
class TagService:
|
||||
@staticmethod
|
||||
def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
|
||||
query = (
|
||||
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
|
||||
stmt = (
|
||||
select(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
|
||||
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
|
||||
)
|
||||
@@ -23,9 +23,9 @@ class TagService:
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_keyword = escape_like_pattern(keyword)
|
||||
query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
|
||||
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
|
||||
results: list = query.order_by(Tag.created_at.desc()).all()
|
||||
stmt = stmt.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
|
||||
stmt = stmt.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
|
||||
results: list = list(db.session.execute(stmt.order_by(Tag.created_at.desc())).all())
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
@@ -64,8 +64,8 @@ class TagService:
|
||||
|
||||
@staticmethod
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str):
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
tags = db.session.scalars(
|
||||
select(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(
|
||||
TagBinding.target_id == target_id,
|
||||
@@ -73,8 +73,7 @@ class TagService:
|
||||
Tag.tenant_id == current_tenant_id,
|
||||
Tag.type == tag_type,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return tags or []
|
||||
|
||||
@@ -97,7 +96,7 @@ class TagService:
|
||||
def update_tags(args: dict, tag_id: str) -> Tag:
|
||||
if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")):
|
||||
raise ValueError("Tag name already exists")
|
||||
tag = db.session.query(Tag).where(Tag.id == tag_id).first()
|
||||
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
|
||||
if not tag:
|
||||
raise NotFound("Tag not found")
|
||||
tag.name = args["name"]
|
||||
@@ -106,12 +105,12 @@ class TagService:
|
||||
|
||||
@staticmethod
|
||||
def get_tag_binding_count(tag_id: str) -> int:
|
||||
count = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).count()
|
||||
count = db.session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def delete_tag(tag_id: str):
|
||||
tag = db.session.query(Tag).where(Tag.id == tag_id).first()
|
||||
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
|
||||
if not tag:
|
||||
raise NotFound("Tag not found")
|
||||
db.session.delete(tag)
|
||||
@@ -128,10 +127,10 @@ class TagService:
|
||||
TagService.check_target_exists(args["type"], args["target_id"])
|
||||
# save tag binding
|
||||
for tag_id in args["tag_ids"]:
|
||||
tag_binding = (
|
||||
db.session.query(TagBinding)
|
||||
tag_binding = db.session.scalar(
|
||||
select(TagBinding)
|
||||
.where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if tag_binding:
|
||||
continue
|
||||
@@ -149,10 +148,10 @@ class TagService:
|
||||
# check if target exists
|
||||
TagService.check_target_exists(args["type"], args["target_id"])
|
||||
# delete tag binding
|
||||
tag_bindings = (
|
||||
db.session.query(TagBinding)
|
||||
.where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
|
||||
.first()
|
||||
tag_bindings = db.session.scalar(
|
||||
select(TagBinding)
|
||||
.where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == args["tag_id"])
|
||||
.limit(1)
|
||||
)
|
||||
if tag_bindings:
|
||||
db.session.delete(tag_bindings)
|
||||
@@ -161,18 +160,16 @@ class TagService:
|
||||
@staticmethod
|
||||
def check_target_exists(type: str, target_id: str):
|
||||
if type == "knowledge":
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found")
|
||||
elif type == "app":
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
|
||||
.first()
|
||||
app = db.session.scalar(
|
||||
select(App).where(App.tenant_id == current_user.current_tenant_id, App.id == target_id).limit(1)
|
||||
)
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
Reference in New Issue
Block a user