mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 16:59:21 +08:00
refactor: remove all reqparser (#29289)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
This commit is contained in:
@@ -53,6 +53,7 @@ select = [
|
|||||||
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
|
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
|
||||||
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
|
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
|
||||||
"S311", # suspicious-non-cryptographic-random-usage,
|
"S311", # suspicious-non-cryptographic-random-usage,
|
||||||
|
"TID", # flake8-tidy-imports
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -88,6 +89,7 @@ ignore = [
|
|||||||
"SIM113", # enumerate-for-loop
|
"SIM113", # enumerate-for-loop
|
||||||
"SIM117", # multiple-with-statements
|
"SIM117", # multiple-with-statements
|
||||||
"SIM210", # if-expr-with-true-false
|
"SIM210", # if-expr-with-true-false
|
||||||
|
"TID252", # allow relative imports from parent modules
|
||||||
]
|
]
|
||||||
|
|
||||||
[lint.per-file-ignores]
|
[lint.per-file-ignores]
|
||||||
@@ -109,10 +111,20 @@ ignore = [
|
|||||||
"S110", # allow ignoring exceptions in tests code (currently)
|
"S110", # allow ignoring exceptions in tests code (currently)
|
||||||
|
|
||||||
]
|
]
|
||||||
|
"controllers/console/explore/trial.py" = ["TID251"]
|
||||||
|
"controllers/console/human_input_form.py" = ["TID251"]
|
||||||
|
"controllers/web/human_input_form.py" = ["TID251"]
|
||||||
|
|
||||||
[lint.pyflakes]
|
[lint.pyflakes]
|
||||||
allowed-unused-imports = [
|
allowed-unused-imports = [
|
||||||
"_pytest.monkeypatch",
|
|
||||||
"tests.integration_tests",
|
"tests.integration_tests",
|
||||||
"tests.unit_tests",
|
"tests.unit_tests",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[lint.flake8-tidy-imports]
|
||||||
|
|
||||||
|
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
|
||||||
|
msg = "Use Pydantic payload/query models instead of reqparse."
|
||||||
|
|
||||||
|
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"]
|
||||||
|
msg = "Use Pydantic payload/query models instead of reqparse."
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Literal, cast
|
from typing import Any, Literal, cast
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from flask import abort, request
|
from flask import abort, request
|
||||||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
from flask_restx import Resource, marshal_with # type: ignore
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
@@ -38,7 +37,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import variable_factory
|
from factories import variable_factory
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||||
from libs.login import current_account_with_tenant, current_user, login_required
|
from libs.login import current_account_with_tenant, current_user, login_required
|
||||||
from models import Account
|
from models import Account
|
||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
@@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class WorkflowRunQuery(BaseModel):
|
class WorkflowRunQuery(BaseModel):
|
||||||
last_id: UUID | None = None
|
last_id: UUIDStrOrEmpty | None = None
|
||||||
limit: int = Field(default=20, ge=1, le=100)
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
|
|
||||||
|
|
||||||
@@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel):
|
|||||||
start_node_title: str
|
start_node_title: str
|
||||||
|
|
||||||
|
|
||||||
|
class RagPipelineRecommendedPluginQuery(BaseModel):
|
||||||
|
type: str = "all"
|
||||||
|
|
||||||
|
|
||||||
register_schema_models(
|
register_schema_models(
|
||||||
console_ns,
|
console_ns,
|
||||||
DraftWorkflowSyncPayload,
|
DraftWorkflowSyncPayload,
|
||||||
@@ -135,6 +138,7 @@ register_schema_models(
|
|||||||
NodeIdQuery,
|
NodeIdQuery,
|
||||||
WorkflowRunQuery,
|
WorkflowRunQuery,
|
||||||
DatasourceVariablesPayload,
|
DatasourceVariablesPayload,
|
||||||
|
RagPipelineRecommendedPluginQuery,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
|
||||||
parser.add_argument("type", type=str, location="args", required=False, default="all")
|
|
||||||
args = parser.parse_args()
|
|
||||||
type = args["type"]
|
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
|
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
|
||||||
return recommended_plugins
|
return recommended_plugins
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -30,6 +30,7 @@ from core.errors.error import (
|
|||||||
from core.helper.trace_id_helper import get_external_trace_id
|
from core.helper.trace_id_helper import get_external_trace_id
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from libs import helper
|
from libs import helper
|
||||||
|
from libs.helper import UUIDStrOrEmpty
|
||||||
from models.model import App, AppMode, EndUser
|
from models.model import App, AppMode, EndUser
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.app_task_service import AppTaskService
|
from services.app_task_service import AppTaskService
|
||||||
@@ -52,7 +53,7 @@ class ChatRequestPayload(BaseModel):
|
|||||||
query: str
|
query: str
|
||||||
files: list[dict[str, Any]] | None = None
|
files: list[dict[str, Any]] | None = None
|
||||||
response_mode: Literal["blocking", "streaming"] | None = None
|
response_mode: Literal["blocking", "streaming"] | None = None
|
||||||
conversation_id: str | None = Field(default=None, description="Conversation UUID")
|
conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID")
|
||||||
retriever_from: str = Field(default="dev")
|
retriever_from: str = Field(default="dev")
|
||||||
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
|
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
|
||||||
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
|
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@@ -23,12 +22,13 @@ from fields.conversation_variable_fields import (
|
|||||||
build_conversation_variable_infinite_scroll_pagination_model,
|
build_conversation_variable_infinite_scroll_pagination_model,
|
||||||
build_conversation_variable_model,
|
build_conversation_variable_model,
|
||||||
)
|
)
|
||||||
|
from libs.helper import UUIDStrOrEmpty
|
||||||
from models.model import App, AppMode, EndUser
|
from models.model import App, AppMode, EndUser
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
|
|
||||||
|
|
||||||
class ConversationListQuery(BaseModel):
|
class ConversationListQuery(BaseModel):
|
||||||
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
|
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination")
|
||||||
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
|
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
|
||||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||||
default="-updated_at", description="Sort order for conversations"
|
default="-updated_at", description="Sort order for conversations"
|
||||||
@@ -48,7 +48,7 @@ class ConversationRenamePayload(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ConversationVariablesQuery(BaseModel):
|
class ConversationVariablesQuery(BaseModel):
|
||||||
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
|
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
|
||||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||||
variable_name: str | None = Field(
|
variable_name: str | None = Field(
|
||||||
default=None, description="Filter variables by name", min_length=1, max_length=255
|
default=None, description="Filter variables by name", min_length=1, max_length=255
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@@ -15,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from fields.conversation_fields import ResultResponse
|
from fields.conversation_fields import ResultResponse
|
||||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||||
|
from libs.helper import UUIDStrOrEmpty
|
||||||
from models.model import App, AppMode, EndUser
|
from models.model import App, AppMode, EndUser
|
||||||
from services.errors.message import (
|
from services.errors.message import (
|
||||||
FirstMessageNotExistsError,
|
FirstMessageNotExistsError,
|
||||||
@@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class MessageListQuery(BaseModel):
|
class MessageListQuery(BaseModel):
|
||||||
conversation_id: UUID
|
conversation_id: UUIDStrOrEmpty
|
||||||
first_id: UUID | None = None
|
first_id: UUIDStrOrEmpty | None = None
|
||||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
from controllers.common.schema import register_schema_model
|
||||||
|
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||||
|
|
||||||
|
register_schema_model(service_api_ns, HitTestingPayload)
|
||||||
|
|
||||||
|
|
||||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
|
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
|
||||||
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||||
@@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
|||||||
404: "Dataset not found",
|
404: "Dataset not found",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||||
def post(self, tenant_id, dataset_id):
|
def post(self, tenant_id, dataset_id):
|
||||||
"""Perform hit testing on a dataset.
|
"""Perform hit testing on a dataset.
|
||||||
|
|||||||
@@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity
|
|||||||
|
|
||||||
|
|
||||||
class WorkflowToolConfigurationUtils:
|
class WorkflowToolConfigurationUtils:
|
||||||
@classmethod
|
|
||||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
|
||||||
for configuration in configurations:
|
|
||||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Mapping
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import or_, select
|
from sqlalchemy import or_, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -10,8 +8,8 @@ from sqlalchemy.orm import Session
|
|||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||||
|
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
|
||||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@@ -38,12 +36,10 @@ class WorkflowToolManageService:
|
|||||||
label: str,
|
label: str,
|
||||||
icon: dict,
|
icon: dict,
|
||||||
description: str,
|
description: str,
|
||||||
parameters: list[Mapping[str, Any]],
|
parameters: list[WorkflowToolParameterConfiguration],
|
||||||
privacy_policy: str = "",
|
privacy_policy: str = "",
|
||||||
labels: list[str] | None = None,
|
labels: list[str] | None = None,
|
||||||
):
|
):
|
||||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
|
||||||
|
|
||||||
# check if the name is unique
|
# check if the name is unique
|
||||||
existing_workflow_tool_provider = (
|
existing_workflow_tool_provider = (
|
||||||
db.session.query(WorkflowToolProvider)
|
db.session.query(WorkflowToolProvider)
|
||||||
@@ -75,7 +71,7 @@ class WorkflowToolManageService:
|
|||||||
label=label,
|
label=label,
|
||||||
icon=json.dumps(icon),
|
icon=json.dumps(icon),
|
||||||
description=description,
|
description=description,
|
||||||
parameter_configuration=json.dumps(parameters),
|
parameter_configuration=json.dumps([p.model_dump() for p in parameters]),
|
||||||
privacy_policy=privacy_policy,
|
privacy_policy=privacy_policy,
|
||||||
version=workflow.version,
|
version=workflow.version,
|
||||||
)
|
)
|
||||||
@@ -104,7 +100,7 @@ class WorkflowToolManageService:
|
|||||||
label: str,
|
label: str,
|
||||||
icon: dict,
|
icon: dict,
|
||||||
description: str,
|
description: str,
|
||||||
parameters: list[Mapping[str, Any]],
|
parameters: list[WorkflowToolParameterConfiguration],
|
||||||
privacy_policy: str = "",
|
privacy_policy: str = "",
|
||||||
labels: list[str] | None = None,
|
labels: list[str] | None = None,
|
||||||
):
|
):
|
||||||
@@ -122,8 +118,6 @@ class WorkflowToolManageService:
|
|||||||
:param labels: labels
|
:param labels: labels
|
||||||
:return: the updated tool
|
:return: the updated tool
|
||||||
"""
|
"""
|
||||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
|
||||||
|
|
||||||
# check if the name is unique
|
# check if the name is unique
|
||||||
existing_workflow_tool_provider = (
|
existing_workflow_tool_provider = (
|
||||||
db.session.query(WorkflowToolProvider)
|
db.session.query(WorkflowToolProvider)
|
||||||
@@ -162,7 +156,7 @@ class WorkflowToolManageService:
|
|||||||
workflow_tool_provider.label = label
|
workflow_tool_provider.label = label
|
||||||
workflow_tool_provider.icon = json.dumps(icon)
|
workflow_tool_provider.icon = json.dumps(icon)
|
||||||
workflow_tool_provider.description = description
|
workflow_tool_provider.description = description
|
||||||
workflow_tool_provider.parameter_configuration = json.dumps(parameters)
|
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
|
||||||
workflow_tool_provider.privacy_policy = privacy_policy
|
workflow_tool_provider.privacy_policy = privacy_policy
|
||||||
workflow_tool_provider.version = workflow.version
|
workflow_tool_provider.version = workflow.version
|
||||||
workflow_tool_provider.updated_at = datetime.now()
|
workflow_tool_provider.updated_at = datetime.now()
|
||||||
|
|||||||
@@ -3,7 +3,9 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from faker import Faker
|
from faker import Faker
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||||
from models.tools import WorkflowToolProvider
|
from models.tools import WorkflowToolProvider
|
||||||
from models.workflow import Workflow as WorkflowModel
|
from models.workflow import Workflow as WorkflowModel
|
||||||
from services.account_service import AccountService, TenantService
|
from services.account_service import AccountService, TenantService
|
||||||
@@ -130,20 +132,24 @@ class TestWorkflowToolManageService:
|
|||||||
def _create_test_workflow_tool_parameters(self):
|
def _create_test_workflow_tool_parameters(self):
|
||||||
"""Helper method to create valid workflow tool parameters."""
|
"""Helper method to create valid workflow tool parameters."""
|
||||||
return [
|
return [
|
||||||
{
|
WorkflowToolParameterConfiguration.model_validate(
|
||||||
"name": "input_text",
|
{
|
||||||
"description": "Input text for processing",
|
"name": "input_text",
|
||||||
"form": "form",
|
"description": "Input text for processing",
|
||||||
"type": "string",
|
"form": "form",
|
||||||
"required": True,
|
"type": "string",
|
||||||
},
|
"required": True,
|
||||||
{
|
}
|
||||||
"name": "output_format",
|
),
|
||||||
"description": "Output format specification",
|
WorkflowToolParameterConfiguration.model_validate(
|
||||||
"form": "form",
|
{
|
||||||
"type": "select",
|
"name": "output_format",
|
||||||
"required": False,
|
"description": "Output format specification",
|
||||||
},
|
"form": "form",
|
||||||
|
"type": "select",
|
||||||
|
"required": False,
|
||||||
|
}
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
|
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
@@ -208,7 +214,7 @@ class TestWorkflowToolManageService:
|
|||||||
assert created_tool_provider.label == tool_label
|
assert created_tool_provider.label == tool_label
|
||||||
assert created_tool_provider.icon == json.dumps(tool_icon)
|
assert created_tool_provider.icon == json.dumps(tool_icon)
|
||||||
assert created_tool_provider.description == tool_description
|
assert created_tool_provider.description == tool_description
|
||||||
assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters)
|
assert created_tool_provider.parameter_configuration == json.dumps([p.model_dump() for p in tool_parameters])
|
||||||
assert created_tool_provider.privacy_policy == tool_privacy_policy
|
assert created_tool_provider.privacy_policy == tool_privacy_policy
|
||||||
assert created_tool_provider.version == workflow.version
|
assert created_tool_provider.version == workflow.version
|
||||||
assert created_tool_provider.user_id == account.id
|
assert created_tool_provider.user_id == account.id
|
||||||
@@ -353,18 +359,9 @@ class TestWorkflowToolManageService:
|
|||||||
app, account, workflow = self._create_test_app_and_account(
|
app, account, workflow = self._create_test_app_and_account(
|
||||||
db_session_with_containers, mock_external_service_dependencies
|
db_session_with_containers, mock_external_service_dependencies
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup invalid workflow tool parameters (missing required fields)
|
|
||||||
invalid_parameters = [
|
|
||||||
{
|
|
||||||
"name": "input_text",
|
|
||||||
# Missing description and form fields
|
|
||||||
"type": "string",
|
|
||||||
"required": True,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
# Attempt to create workflow tool with invalid parameters
|
# Attempt to create workflow tool with invalid parameters
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
# Setup invalid workflow tool parameters (missing required fields)
|
||||||
WorkflowToolManageService.create_workflow_tool(
|
WorkflowToolManageService.create_workflow_tool(
|
||||||
user_id=account.id,
|
user_id=account.id,
|
||||||
tenant_id=account.current_tenant.id,
|
tenant_id=account.current_tenant.id,
|
||||||
@@ -373,7 +370,16 @@ class TestWorkflowToolManageService:
|
|||||||
label=fake.word(),
|
label=fake.word(),
|
||||||
icon={"type": "emoji", "emoji": "🔧"},
|
icon={"type": "emoji", "emoji": "🔧"},
|
||||||
description=fake.text(max_nb_chars=200),
|
description=fake.text(max_nb_chars=200),
|
||||||
parameters=invalid_parameters,
|
parameters=[
|
||||||
|
WorkflowToolParameterConfiguration.model_validate(
|
||||||
|
{
|
||||||
|
"name": "input_text",
|
||||||
|
# Missing description and form fields
|
||||||
|
"type": "string",
|
||||||
|
"required": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify error message contains validation error
|
# Verify error message contains validation error
|
||||||
@@ -579,11 +585,12 @@ class TestWorkflowToolManageService:
|
|||||||
|
|
||||||
# Verify database state was updated
|
# Verify database state was updated
|
||||||
db.session.refresh(created_tool)
|
db.session.refresh(created_tool)
|
||||||
|
assert created_tool is not None
|
||||||
assert created_tool.name == updated_tool_name
|
assert created_tool.name == updated_tool_name
|
||||||
assert created_tool.label == updated_tool_label
|
assert created_tool.label == updated_tool_label
|
||||||
assert created_tool.icon == json.dumps(updated_tool_icon)
|
assert created_tool.icon == json.dumps(updated_tool_icon)
|
||||||
assert created_tool.description == updated_tool_description
|
assert created_tool.description == updated_tool_description
|
||||||
assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters)
|
assert created_tool.parameter_configuration == json.dumps([p.model_dump() for p in updated_tool_parameters])
|
||||||
assert created_tool.privacy_policy == updated_tool_privacy_policy
|
assert created_tool.privacy_policy == updated_tool_privacy_policy
|
||||||
assert created_tool.version == workflow.version
|
assert created_tool.version == workflow.version
|
||||||
assert created_tool.updated_at is not None
|
assert created_tool.updated_at is not None
|
||||||
@@ -750,13 +757,15 @@ class TestWorkflowToolManageService:
|
|||||||
|
|
||||||
# Setup workflow tool parameters with FILE type
|
# Setup workflow tool parameters with FILE type
|
||||||
file_parameters = [
|
file_parameters = [
|
||||||
{
|
WorkflowToolParameterConfiguration.model_validate(
|
||||||
"name": "document",
|
{
|
||||||
"description": "Upload a document",
|
"name": "document",
|
||||||
"form": "form",
|
"description": "Upload a document",
|
||||||
"type": "file",
|
"form": "form",
|
||||||
"required": False,
|
"type": "file",
|
||||||
}
|
"required": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Execute the method under test
|
# Execute the method under test
|
||||||
@@ -823,13 +832,15 @@ class TestWorkflowToolManageService:
|
|||||||
|
|
||||||
# Setup workflow tool parameters with FILES type
|
# Setup workflow tool parameters with FILES type
|
||||||
files_parameters = [
|
files_parameters = [
|
||||||
{
|
WorkflowToolParameterConfiguration.model_validate(
|
||||||
"name": "documents",
|
{
|
||||||
"description": "Upload multiple documents",
|
"name": "documents",
|
||||||
"form": "form",
|
"description": "Upload multiple documents",
|
||||||
"type": "files",
|
"form": "form",
|
||||||
"required": False,
|
"type": "files",
|
||||||
}
|
"required": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Execute the method under test
|
# Execute the method under test
|
||||||
|
|||||||
Reference in New Issue
Block a user