refactor: core/tools, agent, callback_handler, encrypter, llm_generator, plugin, inner_api (#34205)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
Renzo
2026-03-28 11:14:43 +01:00
committed by GitHub
parent 7cc81e9a43
commit 364d7ebc40
19 changed files with 99 additions and 118 deletions

View File

@@ -18,7 +18,7 @@ from graphon.model_runtime.entities import (
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from graphon.model_runtime.entities.model_entities import ModelFeature
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import select
from sqlalchemy import func, select
from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -104,11 +104,14 @@ class BaseAgentRunner(AppRunner):
)
# get how many agent thoughts have been created
self.agent_thought_count = (
db.session.query(MessageAgentThought)
.where(
MessageAgentThought.message_id == self.message.id,
db.session.scalar(
select(func.count())
.select_from(MessageAgentThought)
.where(
MessageAgentThought.message_id == self.message.id,
)
)
.count()
or 0
)
db.session.close()

View File

@@ -1,7 +1,7 @@
import logging
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy import select, update
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -70,23 +70,21 @@ class DatasetIndexToolCallbackHandler:
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
_ = (
db.session.query(DocumentSegment)
db.session.execute(
update(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
.values(hit_count=DocumentSegment.hit_count + 1)
)
else:
query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]]
if "dataset_id" in document.metadata:
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
db.session.execute(
update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1)
)
db.session.commit()

View File

@@ -19,7 +19,7 @@ def encrypt_token(tenant_id: str, token: str):
from extensions.ext_database import db
from models.account import Tenant
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
if not (tenant := db.session.get(Tenant, tenant_id)):
raise ValueError(f"Tenant with id {tenant_id} not found")
assert tenant.encrypt_public_key is not None
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)

View File

@@ -10,6 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult
from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from sqlalchemy import select
from core.app.app_config.entities import ModelConfig
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
@@ -410,8 +411,8 @@ class LLMGenerator:
model_config: ModelConfig,
ideal_output: str | None,
):
last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
last_run: Message | None = db.session.scalar(
select(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).limit(1)
)
if not last_run:
return LLMGenerator.__instruction_modify_common(

View File

@@ -227,7 +227,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get app
"""
try:
app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first()
app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
except Exception:
raise ValueError("app not found")

View File

@@ -1,4 +1,4 @@
from sqlalchemy import select
from sqlalchemy import delete, select
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@@ -31,7 +31,7 @@ class ToolLabelManager:
raise ValueError("Unsupported tool type")
# delete old labels
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id))
# insert new labels
for label in labels:

View File

@@ -255,11 +255,11 @@ class ToolManager:
if builtin_provider is None:
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
builtin_provider = db.session.scalar(
select(BuiltinToolProvider)
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
.limit(1)
)
if builtin_provider is None:
@@ -818,13 +818,13 @@ class ToolManager:
:return: the provider controller, the credentials
"""
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
provider: ApiToolProvider | None = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if provider is None:
@@ -872,13 +872,13 @@ class ToolManager:
get api provider
"""
provider_name = provider
provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
provider_obj: ApiToolProvider | None = db.session.scalar(
select(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
)
.first()
.limit(1)
)
if provider_obj is None:
@@ -964,10 +964,10 @@ class ToolManager:
@classmethod
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
workflow_provider: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
.limit(1)
)
if workflow_provider is None:
@@ -981,10 +981,10 @@ class ToolManager:
@classmethod
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
api_provider: ApiToolProvider | None = db.session.scalar(
select(ApiToolProvider)
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
.limit(1)
)
if api_provider is None:

View File

@@ -110,7 +110,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
context_list: list[RetrievalSourceMetadata] = []
resource_number = 1
for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
dataset = db.session.get(Dataset, segment.dataset_id)
document_stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,

View File

@@ -205,7 +205,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if self.return_resource:
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
dataset = db.session.get(Dataset, segment.dataset_id)
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,