refactor: remove all reqparser (#29289)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2026-02-01 13:43:14 +09:00
committed by GitHub
parent b8cb5f5ea2
commit 7828508b30
10 changed files with 434 additions and 412 deletions

View File

@@ -53,6 +53,7 @@ select = [
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers. "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module "S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage, "S311", # suspicious-non-cryptographic-random-usage,
"TID", # flake8-tidy-imports
] ]
@@ -88,6 +89,7 @@ ignore = [
"SIM113", # enumerate-for-loop "SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements "SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false "SIM210", # if-expr-with-true-false
"TID252", # allow relative imports from parent modules
] ]
[lint.per-file-ignores] [lint.per-file-ignores]
@@ -109,10 +111,20 @@ ignore = [
"S110", # allow ignoring exceptions in tests code (currently) "S110", # allow ignoring exceptions in tests code (currently)
] ]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]
[lint.pyflakes] [lint.pyflakes]
allowed-unused-imports = [ allowed-unused-imports = [
"_pytest.monkeypatch",
"tests.integration_tests", "tests.integration_tests",
"tests.unit_tests", "tests.unit_tests",
] ]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"]
msg = "Use Pydantic payload/query models instead of reqparse."

View File

@@ -1,10 +1,9 @@
import json import json
import logging import logging
from typing import Any, Literal, cast from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request from flask import abort, request
from flask_restx import Resource, marshal_with, reqparse # type: ignore from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -38,7 +37,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db from extensions.ext_database import db
from factories import variable_factory from factories import variable_factory
from libs import helper from libs import helper
from libs.helper import TimestampField from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.login import current_account_with_tenant, current_user, login_required from libs.login import current_account_with_tenant, current_user, login_required
from models import Account from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
@@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel):
class WorkflowRunQuery(BaseModel): class WorkflowRunQuery(BaseModel):
last_id: UUID | None = None last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100) limit: int = Field(default=20, ge=1, le=100)
@@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel):
start_node_title: str start_node_title: str
class RagPipelineRecommendedPluginQuery(BaseModel):
type: str = "all"
register_schema_models( register_schema_models(
console_ns, console_ns,
DraftWorkflowSyncPayload, DraftWorkflowSyncPayload,
@@ -135,6 +138,7 @@ register_schema_models(
NodeIdQuery, NodeIdQuery,
WorkflowRunQuery, WorkflowRunQuery,
DatasourceVariablesPayload, DatasourceVariablesPayload,
RagPipelineRecommendedPluginQuery,
) )
@@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
parser = reqparse.RequestParser() query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
parser.add_argument("type", type=str, location="args", required=False, default="all")
args = parser.parse_args()
type = args["type"]
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type) recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
return recommended_plugins return recommended_plugins

File diff suppressed because it is too large Load Diff

View File

@@ -30,6 +30,7 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs import helper from libs import helper
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService from services.app_task_service import AppTaskService
@@ -52,7 +53,7 @@ class ChatRequestPayload(BaseModel):
query: str query: str
files: list[dict[str, Any]] | None = None files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: str | None = Field(default=None, description="Conversation UUID") conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID")
retriever_from: str = Field(default="dev") retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")

View File

@@ -1,5 +1,4 @@
from typing import Any, Literal from typing import Any, Literal
from uuid import UUID
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
@@ -23,12 +22,13 @@ from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model, build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model, build_conversation_variable_model,
) )
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
class ConversationListQuery(BaseModel): class ConversationListQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination") last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return") limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field( sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort order for conversations" default="-updated_at", description="Sort order for conversations"
@@ -48,7 +48,7 @@ class ConversationRenamePayload(BaseModel):
class ConversationVariablesQuery(BaseModel): class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
variable_name: str | None = Field( variable_name: str | None = Field(
default=None, description="Filter variables by name", min_length=1, max_length=255 default=None, description="Filter variables by name", min_length=1, max_length=255

View File

@@ -1,6 +1,5 @@
import logging import logging
from typing import Literal from typing import Literal
from uuid import UUID
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
@@ -15,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
from services.errors.message import ( from services.errors.message import (
FirstMessageNotExistsError, FirstMessageNotExistsError,
@@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel): class MessageListQuery(BaseModel):
conversation_id: UUID conversation_id: UUIDStrOrEmpty
first_id: UUID | None = None first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return") limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")

View File

@@ -1,7 +1,10 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.common.schema import register_schema_model
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
register_schema_model(service_api_ns, HitTestingPayload)
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve") @service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
@@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
404: "Dataset not found", 404: "Dataset not found",
} }
) )
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
"""Perform hit testing on a dataset. """Perform hit testing on a dataset.

View File

@@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity
class WorkflowToolConfigurationUtils: class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod @classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
""" """

View File

@@ -1,8 +1,6 @@
import json import json
import logging import logging
from collections.abc import Mapping
from datetime import datetime from datetime import datetime
from typing import Any
from sqlalchemy import or_, select from sqlalchemy import or_, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -10,8 +8,8 @@ from sqlalchemy.orm import Session
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db from extensions.ext_database import db
@@ -38,12 +36,10 @@ class WorkflowToolManageService:
label: str, label: str,
icon: dict, icon: dict,
description: str, description: str,
parameters: list[Mapping[str, Any]], parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "", privacy_policy: str = "",
labels: list[str] | None = None, labels: list[str] | None = None,
): ):
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique # check if the name is unique
existing_workflow_tool_provider = ( existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
@@ -75,7 +71,7 @@ class WorkflowToolManageService:
label=label, label=label,
icon=json.dumps(icon), icon=json.dumps(icon),
description=description, description=description,
parameter_configuration=json.dumps(parameters), parameter_configuration=json.dumps([p.model_dump() for p in parameters]),
privacy_policy=privacy_policy, privacy_policy=privacy_policy,
version=workflow.version, version=workflow.version,
) )
@@ -104,7 +100,7 @@ class WorkflowToolManageService:
label: str, label: str,
icon: dict, icon: dict,
description: str, description: str,
parameters: list[Mapping[str, Any]], parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "", privacy_policy: str = "",
labels: list[str] | None = None, labels: list[str] | None = None,
): ):
@@ -122,8 +118,6 @@ class WorkflowToolManageService:
:param labels: labels :param labels: labels
:return: the updated tool :return: the updated tool
""" """
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique # check if the name is unique
existing_workflow_tool_provider = ( existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
@@ -162,7 +156,7 @@ class WorkflowToolManageService:
workflow_tool_provider.label = label workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon) workflow_tool_provider.icon = json.dumps(icon)
workflow_tool_provider.description = description workflow_tool_provider.description = description
workflow_tool_provider.parameter_configuration = json.dumps(parameters) workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
workflow_tool_provider.privacy_policy = privacy_policy workflow_tool_provider.privacy_policy = privacy_policy
workflow_tool_provider.version = workflow.version workflow_tool_provider.version = workflow.version
workflow_tool_provider.updated_at = datetime.now() workflow_tool_provider.updated_at = datetime.now()

View File

@@ -3,7 +3,9 @@ from unittest.mock import patch
import pytest import pytest
from faker import Faker from faker import Faker
from pydantic import ValidationError
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from models.tools import WorkflowToolProvider from models.tools import WorkflowToolProvider
from models.workflow import Workflow as WorkflowModel from models.workflow import Workflow as WorkflowModel
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
@@ -130,20 +132,24 @@ class TestWorkflowToolManageService:
def _create_test_workflow_tool_parameters(self): def _create_test_workflow_tool_parameters(self):
"""Helper method to create valid workflow tool parameters.""" """Helper method to create valid workflow tool parameters."""
return [ return [
{ WorkflowToolParameterConfiguration.model_validate(
"name": "input_text", {
"description": "Input text for processing", "name": "input_text",
"form": "form", "description": "Input text for processing",
"type": "string", "form": "form",
"required": True, "type": "string",
}, "required": True,
{ }
"name": "output_format", ),
"description": "Output format specification", WorkflowToolParameterConfiguration.model_validate(
"form": "form", {
"type": "select", "name": "output_format",
"required": False, "description": "Output format specification",
}, "form": "form",
"type": "select",
"required": False,
}
),
] ]
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
@@ -208,7 +214,7 @@ class TestWorkflowToolManageService:
assert created_tool_provider.label == tool_label assert created_tool_provider.label == tool_label
assert created_tool_provider.icon == json.dumps(tool_icon) assert created_tool_provider.icon == json.dumps(tool_icon)
assert created_tool_provider.description == tool_description assert created_tool_provider.description == tool_description
assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters) assert created_tool_provider.parameter_configuration == json.dumps([p.model_dump() for p in tool_parameters])
assert created_tool_provider.privacy_policy == tool_privacy_policy assert created_tool_provider.privacy_policy == tool_privacy_policy
assert created_tool_provider.version == workflow.version assert created_tool_provider.version == workflow.version
assert created_tool_provider.user_id == account.id assert created_tool_provider.user_id == account.id
@@ -353,18 +359,9 @@ class TestWorkflowToolManageService:
app, account, workflow = self._create_test_app_and_account( app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies db_session_with_containers, mock_external_service_dependencies
) )
# Setup invalid workflow tool parameters (missing required fields)
invalid_parameters = [
{
"name": "input_text",
# Missing description and form fields
"type": "string",
"required": True,
}
]
# Attempt to create workflow tool with invalid parameters # Attempt to create workflow tool with invalid parameters
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValidationError) as exc_info:
# Setup invalid workflow tool parameters (missing required fields)
WorkflowToolManageService.create_workflow_tool( WorkflowToolManageService.create_workflow_tool(
user_id=account.id, user_id=account.id,
tenant_id=account.current_tenant.id, tenant_id=account.current_tenant.id,
@@ -373,7 +370,16 @@ class TestWorkflowToolManageService:
label=fake.word(), label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"}, icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200), description=fake.text(max_nb_chars=200),
parameters=invalid_parameters, parameters=[
WorkflowToolParameterConfiguration.model_validate(
{
"name": "input_text",
# Missing description and form fields
"type": "string",
"required": True,
}
)
],
) )
# Verify error message contains validation error # Verify error message contains validation error
@@ -579,11 +585,12 @@ class TestWorkflowToolManageService:
# Verify database state was updated # Verify database state was updated
db.session.refresh(created_tool) db.session.refresh(created_tool)
assert created_tool is not None
assert created_tool.name == updated_tool_name assert created_tool.name == updated_tool_name
assert created_tool.label == updated_tool_label assert created_tool.label == updated_tool_label
assert created_tool.icon == json.dumps(updated_tool_icon) assert created_tool.icon == json.dumps(updated_tool_icon)
assert created_tool.description == updated_tool_description assert created_tool.description == updated_tool_description
assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters) assert created_tool.parameter_configuration == json.dumps([p.model_dump() for p in updated_tool_parameters])
assert created_tool.privacy_policy == updated_tool_privacy_policy assert created_tool.privacy_policy == updated_tool_privacy_policy
assert created_tool.version == workflow.version assert created_tool.version == workflow.version
assert created_tool.updated_at is not None assert created_tool.updated_at is not None
@@ -750,13 +757,15 @@ class TestWorkflowToolManageService:
# Setup workflow tool parameters with FILE type # Setup workflow tool parameters with FILE type
file_parameters = [ file_parameters = [
{ WorkflowToolParameterConfiguration.model_validate(
"name": "document", {
"description": "Upload a document", "name": "document",
"form": "form", "description": "Upload a document",
"type": "file", "form": "form",
"required": False, "type": "file",
} "required": False,
}
)
] ]
# Execute the method under test # Execute the method under test
@@ -823,13 +832,15 @@ class TestWorkflowToolManageService:
# Setup workflow tool parameters with FILES type # Setup workflow tool parameters with FILES type
files_parameters = [ files_parameters = [
{ WorkflowToolParameterConfiguration.model_validate(
"name": "documents", {
"description": "Upload multiple documents", "name": "documents",
"form": "form", "description": "Upload multiple documents",
"type": "files", "form": "form",
"required": False, "type": "files",
} "required": False,
}
)
] ]
# Execute the method under test # Execute the method under test